Python Module Index

l | x
 
l
lmeval_extended_evaluation
 
x
xlstm_jax
    xlstm_jax.common_types
    xlstm_jax.configs
    xlstm_jax.dataset
    xlstm_jax.dataset.batch
    xlstm_jax.dataset.configs
    xlstm_jax.dataset.grain_batch_rampup
    xlstm_jax.dataset.grain_data_processing
    xlstm_jax.dataset.grain_iterator
    xlstm_jax.dataset.grain_transforms
    xlstm_jax.dataset.hf_tokenizer
    xlstm_jax.dataset.input_pipeline_interface
    xlstm_jax.dataset.lmeval_dataset
    xlstm_jax.dataset.lmeval_pipeline
    xlstm_jax.dataset.multihost_dataloading
    xlstm_jax.dataset.synthetic_dataloading
    xlstm_jax.define_hydra_schemas
    xlstm_jax.distributed
    xlstm_jax.distributed.array_utils
    xlstm_jax.distributed.data_parallel
    xlstm_jax.distributed.mesh_utils
    xlstm_jax.distributed.pipeline_parallel
    xlstm_jax.distributed.single_gpu
    xlstm_jax.distributed.tensor_parallel
    xlstm_jax.distributed.xla_utils
    xlstm_jax.import_utils
    xlstm_jax.kernels
    xlstm_jax.kernels.kernel_utils
    xlstm_jax.kernels.mlstm_chunkwise
    xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v3
    xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v3._triton_bw
    xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v3._triton_fw
    xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v3.triton_fwbw
    xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v3noslice
    xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v3noslice._triton_bw
    xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v3noslice._triton_fw
    xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v3noslice.triton_fwbw
    xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize
    xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._chunkwise_gates
    xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._combined_bw
    xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._combined_fw
    xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._parallel_bw_dK
    xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._parallel_bw_dQ
    xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._parallel_bw_dV
    xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._parallel_fw
    xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._recurrent_bw
    xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._recurrent_fw
    xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize.triton_fwbw
    xlstm_jax.kernels.mlstm_chunkwise.triton_stablef
    xlstm_jax.kernels.mlstm_chunkwise.triton_stablef._triton_bw
    xlstm_jax.kernels.mlstm_chunkwise.triton_stablef._triton_fw
    xlstm_jax.kernels.mlstm_chunkwise.triton_stablef.triton_fwbw
    xlstm_jax.kernels.mlstm_recurrent
    xlstm_jax.kernels.mlstm_recurrent.triton_fused_fw
    xlstm_jax.kernels.stride_utils
    xlstm_jax.main_train
    xlstm_jax.models
    xlstm_jax.models.configs
    xlstm_jax.models.llama
    xlstm_jax.models.llama.attention
    xlstm_jax.models.llama.feedforward
    xlstm_jax.models.llama.llama
    xlstm_jax.models.shared
    xlstm_jax.models.shared.init
    xlstm_jax.models.shared.lm_head
    xlstm_jax.models.shared.utils
    xlstm_jax.models.xlstm_clean
    xlstm_jax.models.xlstm_clean.blocks
    xlstm_jax.models.xlstm_clean.blocks.mlstm
    xlstm_jax.models.xlstm_clean.blocks.mlstm.backend
    xlstm_jax.models.xlstm_clean.blocks.mlstm.backend.config
    xlstm_jax.models.xlstm_clean.blocks.mlstm.backend.config_utils
    xlstm_jax.models.xlstm_clean.blocks.mlstm.backend.layer_factory
    xlstm_jax.models.xlstm_clean.blocks.mlstm.backend.simple
    xlstm_jax.models.xlstm_clean.blocks.mlstm.block
    xlstm_jax.models.xlstm_clean.blocks.mlstm.cell
    xlstm_jax.models.xlstm_clean.blocks.mlstm.layer
    xlstm_jax.models.xlstm_clean.blocks.xlstm_block
    xlstm_jax.models.xlstm_clean.components
    xlstm_jax.models.xlstm_clean.components.conv
    xlstm_jax.models.xlstm_clean.components.feedforward
    xlstm_jax.models.xlstm_clean.components.init
    xlstm_jax.models.xlstm_clean.components.linear_headwise
    xlstm_jax.models.xlstm_clean.components.ln
    xlstm_jax.models.xlstm_clean.utils
    xlstm_jax.models.xlstm_clean.xlstm_block_stack
    xlstm_jax.models.xlstm_clean.xlstm_lm_model
    xlstm_jax.models.xlstm_parallel
    xlstm_jax.models.xlstm_parallel.benchmark
    xlstm_jax.models.xlstm_parallel.blocks
    xlstm_jax.models.xlstm_parallel.blocks.mlstm
    xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend
    xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.attention
    xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.config
    xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.config_utils
    xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.fwbw
    xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.layer_factory
    xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.recurrent
    xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.recurrent_triton
    xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.simple
    xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.triton_kernels
    xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend_utils
    xlstm_jax.models.xlstm_parallel.blocks.mlstm.block
    xlstm_jax.models.xlstm_parallel.blocks.mlstm.cell
    xlstm_jax.models.xlstm_parallel.blocks.mlstm.layer
    xlstm_jax.models.xlstm_parallel.blocks.mlstm.layer_v1
    xlstm_jax.models.xlstm_parallel.blocks.xlstm_block
    xlstm_jax.models.xlstm_parallel.checkpointing
    xlstm_jax.models.xlstm_parallel.components
    xlstm_jax.models.xlstm_parallel.components.conv
    xlstm_jax.models.xlstm_parallel.components.feedforward
    xlstm_jax.models.xlstm_parallel.components.init
    xlstm_jax.models.xlstm_parallel.components.linear_headwise
    xlstm_jax.models.xlstm_parallel.components.normalization
    xlstm_jax.models.xlstm_parallel.training
    xlstm_jax.models.xlstm_parallel.utils
    xlstm_jax.models.xlstm_parallel.xlstm_block_stack
    xlstm_jax.models.xlstm_parallel.xlstm_lm_model
    xlstm_jax.models.xlstm_pytorch
    xlstm_jax.models.xlstm_pytorch.blocks
    xlstm_jax.models.xlstm_pytorch.blocks.mlstm
    xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend
    xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.config
    xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.config_utils
    xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.fwbw
    xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.layer_factory
    xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.simple
    xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.tl_utils
    xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.triton_chunk
    xlstm_jax.models.xlstm_pytorch.blocks.mlstm.block
    xlstm_jax.models.xlstm_pytorch.blocks.mlstm.cell
    xlstm_jax.models.xlstm_pytorch.blocks.mlstm.layer
    xlstm_jax.models.xlstm_pytorch.blocks.slstm
    xlstm_jax.models.xlstm_pytorch.blocks.slstm.block
    xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell
    xlstm_jax.models.xlstm_pytorch.blocks.slstm.layer
    xlstm_jax.models.xlstm_pytorch.blocks.slstm.src
    xlstm_jax.models.xlstm_pytorch.blocks.slstm.src.cuda_init
    xlstm_jax.models.xlstm_pytorch.blocks.slstm.src.vanilla
    xlstm_jax.models.xlstm_pytorch.blocks.slstm.src.vanilla.lstm
    xlstm_jax.models.xlstm_pytorch.blocks.slstm.src.vanilla.slstm
    xlstm_jax.models.xlstm_pytorch.blocks.xlstm_block
    xlstm_jax.models.xlstm_pytorch.components
    xlstm_jax.models.xlstm_pytorch.components.conv
    xlstm_jax.models.xlstm_pytorch.components.feedforward
    xlstm_jax.models.xlstm_pytorch.components.init
    xlstm_jax.models.xlstm_pytorch.components.linear_headwise
    xlstm_jax.models.xlstm_pytorch.components.ln
    xlstm_jax.models.xlstm_pytorch.components.util
    xlstm_jax.models.xlstm_pytorch.utils
    xlstm_jax.models.xlstm_pytorch.xlstm_block_stack
    xlstm_jax.models.xlstm_pytorch.xlstm_lm_model
    xlstm_jax.resume_training
    xlstm_jax.start_training
    xlstm_jax.train_init_fns
    xlstm_jax.trainer
    xlstm_jax.trainer.base
    xlstm_jax.trainer.base.param_utils
    xlstm_jax.trainer.base.trainer
    xlstm_jax.trainer.callbacks
    xlstm_jax.trainer.callbacks.callback
    xlstm_jax.trainer.callbacks.checkpointing
    xlstm_jax.trainer.callbacks.extended_evaluation
    xlstm_jax.trainer.callbacks.lr_monitor
    xlstm_jax.trainer.callbacks.profiler
    xlstm_jax.trainer.data_module
    xlstm_jax.trainer.llm
    xlstm_jax.trainer.llm.sampling
    xlstm_jax.trainer.llm.trainer
    xlstm_jax.trainer.logger
    xlstm_jax.trainer.logger.base_logger
    xlstm_jax.trainer.logger.cmd_logging
    xlstm_jax.trainer.logger.file_logger
    xlstm_jax.trainer.logger.tensorboard_logger
    xlstm_jax.trainer.logger.wandb_logger
    xlstm_jax.trainer.metrics
    xlstm_jax.trainer.optimizer
    xlstm_jax.trainer.optimizer.ademamix
    xlstm_jax.trainer.optimizer.optimizer
    xlstm_jax.trainer.optimizer.scheduler
    xlstm_jax.utils
    xlstm_jax.utils.error_logging_utils
    xlstm_jax.utils.model_param_handling
    xlstm_jax.utils.model_param_handling.convert_checkpoint
    xlstm_jax.utils.model_param_handling.convert_state_dict
    xlstm_jax.utils.model_param_handling.handle_mlstm_simple
    xlstm_jax.utils.model_param_handling.load
    xlstm_jax.utils.model_param_handling.store
    xlstm_jax.utils.pytree_utils