Python Module Index l | x l lmeval_extended_evaluation x xlstm_jax xlstm_jax.common_types xlstm_jax.configs xlstm_jax.dataset xlstm_jax.dataset.batch xlstm_jax.dataset.configs xlstm_jax.dataset.grain_batch_rampup xlstm_jax.dataset.grain_data_processing xlstm_jax.dataset.grain_iterator xlstm_jax.dataset.grain_transforms xlstm_jax.dataset.hf_tokenizer xlstm_jax.dataset.input_pipeline_interface xlstm_jax.dataset.lmeval_dataset xlstm_jax.dataset.lmeval_pipeline xlstm_jax.dataset.multihost_dataloading xlstm_jax.dataset.synthetic_dataloading xlstm_jax.define_hydra_schemas xlstm_jax.distributed xlstm_jax.distributed.array_utils xlstm_jax.distributed.data_parallel xlstm_jax.distributed.mesh_utils xlstm_jax.distributed.pipeline_parallel xlstm_jax.distributed.single_gpu xlstm_jax.distributed.tensor_parallel xlstm_jax.distributed.xla_utils xlstm_jax.import_utils xlstm_jax.kernels xlstm_jax.kernels.kernel_utils xlstm_jax.kernels.mlstm_chunkwise xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v3 xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v3._triton_bw xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v3._triton_fw xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v3.triton_fwbw xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v3noslice xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v3noslice._triton_bw xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v3noslice._triton_fw xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v3noslice.triton_fwbw xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._chunkwise_gates xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._combined_bw xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._combined_fw xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._parallel_bw_dK xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._parallel_bw_dQ xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._parallel_bw_dV xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._parallel_fw xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._recurrent_bw xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize._recurrent_fw xlstm_jax.kernels.mlstm_chunkwise.max_triton_fwbw_v5xlchunksize.triton_fwbw xlstm_jax.kernels.mlstm_chunkwise.triton_stablef xlstm_jax.kernels.mlstm_chunkwise.triton_stablef._triton_bw xlstm_jax.kernels.mlstm_chunkwise.triton_stablef._triton_fw xlstm_jax.kernels.mlstm_chunkwise.triton_stablef.triton_fwbw xlstm_jax.kernels.mlstm_recurrent xlstm_jax.kernels.mlstm_recurrent.triton_fused_fw xlstm_jax.kernels.stride_utils xlstm_jax.main_train xlstm_jax.models xlstm_jax.models.configs xlstm_jax.models.llama xlstm_jax.models.llama.attention xlstm_jax.models.llama.feedforward xlstm_jax.models.llama.llama xlstm_jax.models.shared xlstm_jax.models.shared.init xlstm_jax.models.shared.lm_head xlstm_jax.models.shared.utils xlstm_jax.models.xlstm_clean xlstm_jax.models.xlstm_clean.blocks xlstm_jax.models.xlstm_clean.blocks.mlstm xlstm_jax.models.xlstm_clean.blocks.mlstm.backend xlstm_jax.models.xlstm_clean.blocks.mlstm.backend.config xlstm_jax.models.xlstm_clean.blocks.mlstm.backend.config_utils xlstm_jax.models.xlstm_clean.blocks.mlstm.backend.layer_factory xlstm_jax.models.xlstm_clean.blocks.mlstm.backend.simple xlstm_jax.models.xlstm_clean.blocks.mlstm.block xlstm_jax.models.xlstm_clean.blocks.mlstm.cell xlstm_jax.models.xlstm_clean.blocks.mlstm.layer xlstm_jax.models.xlstm_clean.blocks.xlstm_block xlstm_jax.models.xlstm_clean.components xlstm_jax.models.xlstm_clean.components.conv xlstm_jax.models.xlstm_clean.components.feedforward xlstm_jax.models.xlstm_clean.components.init xlstm_jax.models.xlstm_clean.components.linear_headwise xlstm_jax.models.xlstm_clean.components.ln xlstm_jax.models.xlstm_clean.utils xlstm_jax.models.xlstm_clean.xlstm_block_stack xlstm_jax.models.xlstm_clean.xlstm_lm_model xlstm_jax.models.xlstm_parallel xlstm_jax.models.xlstm_parallel.benchmark xlstm_jax.models.xlstm_parallel.blocks xlstm_jax.models.xlstm_parallel.blocks.mlstm xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.attention xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.config xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.config_utils xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.fwbw xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.layer_factory xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.recurrent xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.recurrent_triton xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.simple xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.triton_kernels xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend_utils xlstm_jax.models.xlstm_parallel.blocks.mlstm.block xlstm_jax.models.xlstm_parallel.blocks.mlstm.cell xlstm_jax.models.xlstm_parallel.blocks.mlstm.layer xlstm_jax.models.xlstm_parallel.blocks.mlstm.layer_v1 xlstm_jax.models.xlstm_parallel.blocks.xlstm_block xlstm_jax.models.xlstm_parallel.checkpointing xlstm_jax.models.xlstm_parallel.components xlstm_jax.models.xlstm_parallel.components.conv xlstm_jax.models.xlstm_parallel.components.feedforward xlstm_jax.models.xlstm_parallel.components.init xlstm_jax.models.xlstm_parallel.components.linear_headwise xlstm_jax.models.xlstm_parallel.components.normalization xlstm_jax.models.xlstm_parallel.training xlstm_jax.models.xlstm_parallel.utils xlstm_jax.models.xlstm_parallel.xlstm_block_stack xlstm_jax.models.xlstm_parallel.xlstm_lm_model xlstm_jax.models.xlstm_pytorch xlstm_jax.models.xlstm_pytorch.blocks xlstm_jax.models.xlstm_pytorch.blocks.mlstm xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.config xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.config_utils xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.fwbw xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.layer_factory xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.simple xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.tl_utils xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.triton_chunk xlstm_jax.models.xlstm_pytorch.blocks.mlstm.block xlstm_jax.models.xlstm_pytorch.blocks.mlstm.cell xlstm_jax.models.xlstm_pytorch.blocks.mlstm.layer xlstm_jax.models.xlstm_pytorch.blocks.slstm xlstm_jax.models.xlstm_pytorch.blocks.slstm.block xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell xlstm_jax.models.xlstm_pytorch.blocks.slstm.layer xlstm_jax.models.xlstm_pytorch.blocks.slstm.src xlstm_jax.models.xlstm_pytorch.blocks.slstm.src.cuda_init xlstm_jax.models.xlstm_pytorch.blocks.slstm.src.vanilla xlstm_jax.models.xlstm_pytorch.blocks.slstm.src.vanilla.lstm xlstm_jax.models.xlstm_pytorch.blocks.slstm.src.vanilla.slstm xlstm_jax.models.xlstm_pytorch.blocks.xlstm_block xlstm_jax.models.xlstm_pytorch.components xlstm_jax.models.xlstm_pytorch.components.conv xlstm_jax.models.xlstm_pytorch.components.feedforward xlstm_jax.models.xlstm_pytorch.components.init xlstm_jax.models.xlstm_pytorch.components.linear_headwise xlstm_jax.models.xlstm_pytorch.components.ln xlstm_jax.models.xlstm_pytorch.components.util xlstm_jax.models.xlstm_pytorch.utils xlstm_jax.models.xlstm_pytorch.xlstm_block_stack xlstm_jax.models.xlstm_pytorch.xlstm_lm_model xlstm_jax.resume_training xlstm_jax.start_training xlstm_jax.train_init_fns xlstm_jax.trainer xlstm_jax.trainer.base xlstm_jax.trainer.base.param_utils xlstm_jax.trainer.base.trainer xlstm_jax.trainer.callbacks xlstm_jax.trainer.callbacks.callback xlstm_jax.trainer.callbacks.checkpointing xlstm_jax.trainer.callbacks.extended_evaluation xlstm_jax.trainer.callbacks.lr_monitor xlstm_jax.trainer.callbacks.profiler xlstm_jax.trainer.data_module xlstm_jax.trainer.llm xlstm_jax.trainer.llm.sampling xlstm_jax.trainer.llm.trainer xlstm_jax.trainer.logger xlstm_jax.trainer.logger.base_logger xlstm_jax.trainer.logger.cmd_logging xlstm_jax.trainer.logger.file_logger xlstm_jax.trainer.logger.tensorboard_logger xlstm_jax.trainer.logger.wandb_logger xlstm_jax.trainer.metrics xlstm_jax.trainer.optimizer xlstm_jax.trainer.optimizer.ademamix xlstm_jax.trainer.optimizer.optimizer xlstm_jax.trainer.optimizer.scheduler xlstm_jax.utils xlstm_jax.utils.error_logging_utils xlstm_jax.utils.model_param_handling xlstm_jax.utils.model_param_handling.convert_checkpoint xlstm_jax.utils.model_param_handling.convert_state_dict xlstm_jax.utils.model_param_handling.handle_mlstm_simple xlstm_jax.utils.model_param_handling.load xlstm_jax.utils.model_param_handling.store xlstm_jax.utils.pytree_utils