

xlstm-jax documentation

  • Installation
    • Repository Installation
    • Conda environment
  • Dataset Preparation
    • Convert Huggingface Datasets to ArrayRecord
    • Splitting DCLM dataset
    • Preprocess Validation Datasets
    • DCLM dataset
  • Training large language models in xlstm-jax
    • Training without Hydra
    • Training with a Hydra configuration
      • Default configurations
      • Experiment configuration
  • Configuring Experiments with Hydra
    • Configuration Structure
      • The Config Dataclasses
      • The Config yaml files
    • How to Run Experiments
      • Using Experiment Files
      • Type Checking of the Configurations
    • How to run experiments on a SLURM cluster
    • How to Resume an Experiment?
  • Distributed Training
    • Distributed Computing in JAX
      • JIT vs Shard Map
      • Multi-Host Training
    • Parallelization Strategies
      • Data Parallelism
      • Fully-Sharded Data Parallelism
      • Pipeline Parallelism
      • Tensor Parallelism
      • Miscellaneous
  • API Reference
    • xlstm_jax
      • Submodules
    • lmeval_extended_evaluation
      • Attributes
      • Classes
      • Functions
      • Module Contents

Indices and tables

  • Index

  • Module Index

  • Search Page


By NXAI GmbH

© Copyright 2024, NXAI GmbH.