xlstm_jax.train_init_fns#
Attributes#
Functions#
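| init_parallel(cfg) | Initialize configuration for parallelism. |
| init_data_iterator(cfg, mesh) | Initialize data iterators. |
| init_single_data_iterator(cfg, mesh, create_split=None) | Initialize a single data iterator. |
| init_mixed_data_iterator(cfg, mesh) | Initialize a data iterator with mixed data sources. |
| get_tokenizer_vocab_size(cfg, next_multiple_of=1) | Get the vocabulary size from the tokenizer. |
| init_model_config(cfg, parallel) | Instantiate the model configuration. |
| init_logger_config(cfg) | Instantiate logger configuration. |
| init_scheduler_config(cfg, data_iterator) | Instantiate scheduler configuration. |
| init_optimizer_config(cfg) | Instantiate optimizer configuration. |
| init_model_checkpointing(cfg) | Instantiate model checkpointing configuration. |
| init_lr_monitor_config(cfg) | Instantiate learning rate monitor configuration. |
| init_profiler_config(cfg) | Instantiate profiler configuration. |
| init_trainer(cfg, data_iterator, model_config, mesh) | Initializes the LLMTrainer with all sub-configs. |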
Module Contents#
- xlstm_jax.train_init_fns.LOGGER#
- xlstm_jax.train_init_fns.init_parallel(cfg)#
Initialize configuration for parallelism.
- Parameters:
cfg (omegaconf.DictConfig) – Config assembled by Hydra.
- Returns:
Initialized parallel configuration.
- Return type:
xlstm_jax.models.configs.ParallelConfig
- xlstm_jax.train_init_fns.init_data_iterator(cfg, mesh)#
Initialize data iterators.
- Parameters:
cfg (omegaconf.DictConfig) – Config assembled by Hydra.
mesh (jax.sharding.Mesh) – The JAX device mesh.
- Returns:
Training and evaluation data iterators.
- Return type:
tuple[xlstm_jax.dataset.DataIterator, xlstm_jax.dataset.DataIterator | dict[str, xlstm_jax.dataset.DataIterator] | None]
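The evaluation half of the returned tuple can take three shapes: a single iterator, a dict of named iterators, or None. A caller may find it convenient to normalize these into one dict before looping over evaluation sets. The following is a minimal sketch under that assumption; the helper name `normalize_eval_iterators` and the default key `"eval"` are hypothetical and not part of xlstm_jax.

```python
from collections.abc import Mapping
from typing import Any


def normalize_eval_iterators(eval_iterator: Any, default_name: str = "eval") -> dict[str, Any]:
    """Return evaluation iterators as a name -> iterator dict.

    Hypothetical helper for illustration; not part of xlstm_jax.
    """
    if eval_iterator is None:
        # No evaluation data configured.
        return {}
    if isinstance(eval_iterator, Mapping):
        # Already a dict of named iterators; copy so the caller's dict is untouched.
        return dict(eval_iterator)
    # A single iterator: wrap it under an assumed default name.
    return {default_name: eval_iterator}
```

With this in place, code that consumes the second element of the tuple returned by `init_data_iterator(cfg, mesh)` can iterate over `normalize_eval_iterators(...).items()` without branching on the three cases.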
- xlstm_jax.train_init_fns.init_single_data_iterator(cfg, mesh, create_split=None)#
Initialize a single data iterator.
- Parameters:
cfg (omegaconf.DictConfig) – Data configuration.
mesh (jax.sharding.Mesh) – The JAX device mesh.
create_split (Literal['train', 'eval', None]) – Whether to create a train or eval config from the config class, using the create_train_eval_configs method. If None, the config is used as is.
- Returns:
Data iterator.
- Return type:
xlstm_jax.dataset.DataIterator
- xlstm_jax.train_init_fns.init_mixed_data_iterator(cfg, mesh)#
Initialize a data iterator with mixed data sources.
- Parameters:
cfg (omegaconf.DictConfig) – Data configuration.
mesh (jax.sharding.Mesh) – The JAX device mesh.
- Returns:
Data iterator.
- Return type:
xlstm_jax.dataset.DataIterator
- xlstm_jax.train_init_fns.get_tokenizer_vocab_size(cfg, next_multiple_of=1)#
Get the vocabulary size from the tokenizer.
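The `next_multiple_of` argument suggests that the returned size can be rounded up to a multiple of a given value, although the docstring does not spell this out. A minimal sketch of such rounding, assuming that reading is correct; `round_up_to_multiple` is a hypothetical helper, not the library's implementation.

```python
def round_up_to_multiple(size: int, multiple: int = 1) -> int:
    """Round `size` up to the nearest multiple of `multiple`.

    Hypothetical helper illustrating one plausible meaning of
    `next_multiple_of`; not taken from xlstm_jax.
    """
    if multiple < 1:
        raise ValueError("multiple must be a positive integer")
    # Ceiling division using integer arithmetic only.
    return -(-size // multiple) * multiple


# Example: a 50257-token vocabulary padded to a multiple of 64.
padded = round_up_to_multiple(50257, 64)  # 50304
```

With the default `multiple=1` the size is returned unchanged, which matches the default `next_multiple_of=1` in the signature.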
- xlstm_jax.train_init_fns.init_model_config(cfg, parallel)#
Instantiate the model configuration.
- Parameters:
cfg (omegaconf.DictConfig) – Config assembled by Hydra.
parallel (xlstm_jax.models.configs.ParallelConfig) – Parallel configuration.
- Returns:
Initialized model configuration.
- Return type:
xlstm_jax.models.ModelConfig
- xlstm_jax.train_init_fns.init_logger_config(cfg)#
Instantiate logger configuration.
- Parameters:
cfg (omegaconf.DictConfig) – Config assembled by Hydra.
- Returns:
Instance of LoggerConfig.
- Return type:
xlstm_jax.trainer.logger.LoggerConfig
- xlstm_jax.train_init_fns.init_scheduler_config(cfg, data_iterator)#
Instantiate scheduler configuration.
- Parameters:
cfg (omegaconf.DictConfig) – Config assembled by Hydra.
data_iterator (xlstm_jax.dataset.DataIterator)
- Returns:
Instance of SchedulerConfig following the provided config.
- Return type:
xlstm_jax.trainer.optimizer.SchedulerConfig
- xlstm_jax.train_init_fns.init_optimizer_config(cfg)#
Instantiate optimizer configuration.
- Parameters:
cfg (omegaconf.DictConfig) – Full Hydra config.
- Returns:
Instance of OptimizerConfig.
- Return type:
xlstm_jax.trainer.optimizer.OptimizerConfig
- xlstm_jax.train_init_fns.init_model_checkpointing(cfg)#
Instantiate model checkpointing configuration.
- Parameters:
cfg (omegaconf.DictConfig) – Full Hydra config.
- Returns:
Instance of ModelCheckpointConfig.
- Return type:
xlstm_jax.trainer.callbacks.ModelCheckpointConfig
- xlstm_jax.train_init_fns.init_lr_monitor_config(cfg)#
Instantiate learning rate monitor configuration.
- Parameters:
cfg (omegaconf.DictConfig) – Full Hydra config.
- Returns:
Instance of LearningRateMonitorConfig.
- Return type:
xlstm_jax.trainer.callbacks.LearningRateMonitorConfig
- xlstm_jax.train_init_fns.init_profiler_config(cfg)#
Instantiate profiler configuration.
- Parameters:
cfg (omegaconf.DictConfig) – Full Hydra config.
- Returns:
Instance of JaxProfilerConfig.
- Return type:
xlstm_jax.trainer.callbacks.JaxProfilerConfig
- xlstm_jax.train_init_fns.init_trainer(cfg, data_iterator, model_config, mesh)#
Initializes the LLMTrainer with all sub-configs.
- Parameters:
cfg (omegaconf.DictConfig) – Full Hydra config.
data_iterator (xlstm_jax.dataset.DataIterator) – A data iterator.
model_config (xlstm_jax.models.ModelConfig) – A model config.
mesh (jax.sharding.Mesh) – A device mesh.
- Returns:
Instance of LLMTrainer.
- Return type:
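
Read together, the signatures above imply a dependency order: the parallel configuration feeds the model configuration, the mesh feeds the data iterators, and both feed the trainer. A minimal sketch of that wiring follows. It assumes the functions are passed in as a namespace `fns` (for example the imported `xlstm_jax.train_init_fns` module), that the caller supplies the device mesh (its construction is not documented in this module), and that the trainer receives the training iterator. `build_training_setup` is a hypothetical name.

```python
from typing import Any


def build_training_setup(fns: Any, cfg: Any, mesh: Any) -> dict[str, Any]:
    """Call the init functions in dependency order and collect the results.

    Hypothetical orchestration sketch; only the documented signatures
    of the init functions are relied upon.
    """
    # Parallelism settings are needed by the model configuration.
    parallel = fns.init_parallel(cfg)
    # Returns (train iterator, eval iterator | dict of eval iterators | None).
    train_iterator, eval_iterator = fns.init_data_iterator(cfg, mesh)
    model_config = fns.init_model_config(cfg, parallel)
    # Assumption: the trainer is built from the training iterator.
    trainer = fns.init_trainer(cfg, train_iterator, model_config, mesh)
    return {
        "parallel": parallel,
        "train_iterator": train_iterator,
        "eval_iterator": eval_iterator,
        "model_config": model_config,
        "trainer": trainer,
    }
```

Passing the functions in as a namespace keeps the sketch independent of how the package is imported and makes the call order easy to check with stand-in functions.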