- xlstm_block_stack (xlstm_jax.models.xlstm_pytorch.xlstm_lm_model.xLSTMLMModel attribute)
-
xlstm_jax
-
xlstm_jax.common_types
-
xlstm_jax.configs
-
xlstm_jax.dataset
-
xlstm_jax.dataset.batch
-
xlstm_jax.dataset.configs
-
xlstm_jax.dataset.grain_batch_rampup
-
xlstm_jax.dataset.grain_data_processing
-
xlstm_jax.dataset.grain_iterator
-
xlstm_jax.dataset.grain_transforms
-
xlstm_jax.dataset.hf_tokenizer
-
xlstm_jax.dataset.input_pipeline_interface
-
xlstm_jax.dataset.lmeval_dataset
-
xlstm_jax.dataset.lmeval_pipeline
-
xlstm_jax.dataset.multihost_dataloading
-
xlstm_jax.dataset.synthetic_dataloading
-
xlstm_jax.define_hydra_schemas
-
xlstm_jax.distributed
-
xlstm_jax.distributed.array_utils
-
xlstm_jax.distributed.data_parallel
-
xlstm_jax.distributed.mesh_utils
-
xlstm_jax.distributed.pipeline_parallel
-
xlstm_jax.distributed.single_gpu
-
xlstm_jax.distributed.tensor_parallel
-
xlstm_jax.distributed.xla_utils
-
xlstm_jax.import_utils
-
xlstm_jax.kernels
-
xlstm_jax.kernels.kernel_utils
-
xlstm_jax.kernels.mlstm_chunkwise
-
xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v3
-
xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v3._triton_bw
-
xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v3._triton_fw
-
xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v3.triton_fwbw
-
xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v3noslice
-
xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v3noslice._triton_bw
-
xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v3noslice._triton_fw
-
xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v3noslice.triton_fwbw
-
xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize
-
xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._chunkwise_gates
-
xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._combined_bw
-
xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._combined_fw
-
xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._parallel_bw_dK
-
xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._parallel_bw_dQ
-
xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._parallel_bw_dV
-
xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._parallel_fw
-
xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._recurrent_bw
-
xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._recurrent_fw
-
xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize.triton_fwbw
-
xlstm_jax.kernels.mlstm_chunkwise.triton_stablef
-
xlstm_jax.kernels.mlstm_chunkwise.triton_stablef._triton_bw
-
xlstm_jax.kernels.mlstm_chunkwise.triton_stablef._triton_fw
-
xlstm_jax.kernels.mlstm_chunkwise.triton_stablef.triton_fwbw
-
xlstm_jax.kernels.mlstm_recurrent
-
xlstm_jax.kernels.mlstm_recurrent.triton_fused_fw
-
xlstm_jax.kernels.stride_utils
-
xlstm_jax.main_train
-
xlstm_jax.models
-
xlstm_jax.models.configs
-
xlstm_jax.models.llama
-
xlstm_jax.models.llama.attention
-
xlstm_jax.models.llama.feedforward
-
xlstm_jax.models.llama.llama
-
xlstm_jax.models.shared
-
xlstm_jax.models.shared.init
-
xlstm_jax.models.shared.lm_head
-
xlstm_jax.models.shared.utils
-
xlstm_jax.models.xlstm_clean
-
xlstm_jax.models.xlstm_clean.blocks
-
xlstm_jax.models.xlstm_clean.blocks.mlstm
-
xlstm_jax.models.xlstm_clean.blocks.mlstm.backend
-
xlstm_jax.models.xlstm_clean.blocks.mlstm.backend.config
-
xlstm_jax.models.xlstm_clean.blocks.mlstm.backend.config_utils
-
xlstm_jax.models.xlstm_clean.blocks.mlstm.backend.layer_factory
-
xlstm_jax.models.xlstm_clean.blocks.mlstm.backend.simple
-
xlstm_jax.models.xlstm_clean.blocks.mlstm.block
-
xlstm_jax.models.xlstm_clean.blocks.mlstm.cell
-
xlstm_jax.models.xlstm_clean.blocks.mlstm.layer
-
xlstm_jax.models.xlstm_clean.blocks.xlstm_block
-
xlstm_jax.models.xlstm_clean.components
-
xlstm_jax.models.xlstm_clean.components.conv
-
xlstm_jax.models.xlstm_clean.components.feedforward
-
xlstm_jax.models.xlstm_clean.components.init
-
xlstm_jax.models.xlstm_clean.components.linear_headwise
-
xlstm_jax.models.xlstm_clean.components.ln
-
xlstm_jax.models.xlstm_clean.utils
-
xlstm_jax.models.xlstm_clean.xlstm_block_stack
-
xlstm_jax.models.xlstm_clean.xlstm_lm_model
-
xlstm_jax.models.xlstm_parallel
-
xlstm_jax.models.xlstm_parallel.benchmark
-
xlstm_jax.models.xlstm_parallel.blocks
-
xlstm_jax.models.xlstm_parallel.blocks.mlstm
-
xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend
-
xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.attention
-
xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.config
-
xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.config_utils
-
xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.fwbw
-
xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.layer_factory
-
xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.recurrent
-
xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.recurrent_triton
-
xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.simple
-
xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.triton_kernels
-
xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend_utils
-
xlstm_jax.models.xlstm_parallel.blocks.mlstm.block
-
xlstm_jax.models.xlstm_parallel.blocks.mlstm.cell
-
xlstm_jax.models.xlstm_parallel.blocks.mlstm.layer
-
xlstm_jax.models.xlstm_parallel.blocks.mlstm.layer_v1
-
xlstm_jax.models.xlstm_parallel.blocks.xlstm_block
-
xlstm_jax.models.xlstm_parallel.checkpointing
-
xlstm_jax.models.xlstm_parallel.components
-
xlstm_jax.models.xlstm_parallel.components.conv
-
xlstm_jax.models.xlstm_parallel.components.feedforward
-
xlstm_jax.models.xlstm_parallel.components.init
-
xlstm_jax.models.xlstm_parallel.components.linear_headwise
-
xlstm_jax.models.xlstm_parallel.components.normalization
-
xlstm_jax.models.xlstm_parallel.training
-
xlstm_jax.models.xlstm_parallel.utils
-
xlstm_jax.models.xlstm_parallel.xlstm_block_stack
-
xlstm_jax.models.xlstm_parallel.xlstm_lm_model
-
xlstm_jax.models.xlstm_pytorch
-
xlstm_jax.models.xlstm_pytorch.blocks
-
xlstm_jax.models.xlstm_pytorch.blocks.mlstm
-
xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend
-
xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.config
-
xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.config_utils
-
xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.fwbw
-
xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.layer_factory
-
xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.simple
-
xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.tl_utils
-
xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.triton_chunk
-
xlstm_jax.models.xlstm_pytorch.blocks.mlstm.block
-
xlstm_jax.models.xlstm_pytorch.blocks.mlstm.cell
-
xlstm_jax.models.xlstm_pytorch.blocks.mlstm.layer
-
xlstm_jax.models.xlstm_pytorch.blocks.slstm
-
xlstm_jax.models.xlstm_pytorch.blocks.slstm.block
-
xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell
-
xlstm_jax.models.xlstm_pytorch.blocks.slstm.layer
-
xlstm_jax.models.xlstm_pytorch.blocks.slstm.src
-
xlstm_jax.models.xlstm_pytorch.blocks.slstm.src.cuda_init
-
xlstm_jax.models.xlstm_pytorch.blocks.slstm.src.vanilla
-
xlstm_jax.models.xlstm_pytorch.blocks.slstm.src.vanilla.lstm
-
xlstm_jax.models.xlstm_pytorch.blocks.slstm.src.vanilla.slstm
-
xlstm_jax.models.xlstm_pytorch.blocks.xlstm_block
-
xlstm_jax.models.xlstm_pytorch.components
-
xlstm_jax.models.xlstm_pytorch.components.conv
-
xlstm_jax.models.xlstm_pytorch.components.feedforward
-
xlstm_jax.models.xlstm_pytorch.components.init
-
xlstm_jax.models.xlstm_pytorch.components.linear_headwise
-
xlstm_jax.models.xlstm_pytorch.components.ln
-
xlstm_jax.models.xlstm_pytorch.components.util
-
xlstm_jax.models.xlstm_pytorch.utils
-
xlstm_jax.models.xlstm_pytorch.xlstm_block_stack
-
xlstm_jax.models.xlstm_pytorch.xlstm_lm_model
-
xlstm_jax.resume_training
-
xlstm_jax.start_training
-
xlstm_jax.train_init_fns
-
xlstm_jax.trainer
-
xlstm_jax.trainer.base
-
xlstm_jax.trainer.base.param_utils
-
xlstm_jax.trainer.base.trainer
-
xlstm_jax.trainer.callbacks
-
xlstm_jax.trainer.callbacks.callback
-
xlstm_jax.trainer.callbacks.checkpointing
-
xlstm_jax.trainer.callbacks.extended_evaluation
-
xlstm_jax.trainer.callbacks.lr_monitor
-
xlstm_jax.trainer.callbacks.profiler
-
xlstm_jax.trainer.data_module
-
xlstm_jax.trainer.llm
-
xlstm_jax.trainer.llm.sampling
-
xlstm_jax.trainer.llm.trainer
-
xlstm_jax.trainer.logger
-
xlstm_jax.trainer.logger.base_logger
-
xlstm_jax.trainer.logger.cmd_logging
-
xlstm_jax.trainer.logger.file_logger
-
xlstm_jax.trainer.logger.tensorboard_logger
-
xlstm_jax.trainer.logger.wandb_logger
-
xlstm_jax.trainer.metrics
-
xlstm_jax.trainer.optimizer
-
xlstm_jax.trainer.optimizer.ademamix
-
xlstm_jax.trainer.optimizer.optimizer
-
xlstm_jax.trainer.optimizer.scheduler
-
xlstm_jax.utils
-
xlstm_jax.utils.error_logging_utils
-
xlstm_jax.utils.model_param_handling
-
xlstm_jax.utils.model_param_handling.convert_checkpoint
-
xlstm_jax.utils.model_param_handling.convert_state_dict
-
xlstm_jax.utils.model_param_handling.handle_mlstm_simple
-
xlstm_jax.utils.model_param_handling.load
-
xlstm_jax.utils.model_param_handling.store
-
xlstm_jax.utils.pytree_utils
- xlstm_norm (xlstm_jax.models.xlstm_pytorch.blocks.mlstm.block.mLSTMBlock attribute)
- xLSTMBlock (class in xlstm_jax.models.xlstm_clean.blocks.xlstm_block)
- xLSTMBlockConfig (class in xlstm_jax.models.xlstm_clean.blocks.xlstm_block)
- xLSTMBlockStack (class in xlstm_jax.models.xlstm_clean.xlstm_block_stack)
- xLSTMBlockStackConfig (class in xlstm_jax.models.xlstm_clean.xlstm_block_stack)
- xLSTMLMModel (class in xlstm_jax.models.xlstm_clean.xlstm_lm_model)
- xLSTMLMModelConfig (class in xlstm_jax.models.xlstm_clean.xlstm_lm_model)
|