xlstm-jax documentation

xlstm_jax

Submodules

  • xlstm_jax.common_types
  • xlstm_jax.configs
  • xlstm_jax.dataset
  • xlstm_jax.define_hydra_schemas
  • xlstm_jax.distributed
  • xlstm_jax.import_utils
  • xlstm_jax.kernels
  • xlstm_jax.main_train
  • xlstm_jax.models
  • xlstm_jax.resume_training
  • xlstm_jax.start_training
  • xlstm_jax.train_init_fns
  • xlstm_jax.trainer
  • xlstm_jax.utils

By NXAI GmbH

© Copyright 2024, NXAI GmbH.