xlstm_jax.train_init_fns

Attributes

LOGGER

Functions

init_parallel(cfg)

Initialize configuration for parallelism.

init_data_iterator(cfg, mesh)

Initialize data iterators.

init_single_data_iterator(cfg, mesh[, create_split])

Initialize a single data iterator.

init_mixed_data_iterator(cfg, mesh)

Initialize a data iterator with mixed data sources.

get_tokenizer_vocab_size(cfg[, next_multiple_of])

Get the vocabulary size from the tokenizer.

init_model_config(cfg, parallel)

Instantiate the model configuration.

init_logger_config(cfg)

Instantiate logger configuration.

init_scheduler_config(cfg, data_iterator)

Instantiate scheduler configuration.

init_optimizer_config(cfg)

Instantiate optimizer configuration.

init_model_checkpointing(cfg)

Instantiate model checkpointing configuration.

init_lr_monitor_config(cfg)

Instantiate learning rate monitor configuration.

init_profiler_config(cfg)

Instantiate profiler configuration.

init_trainer(cfg, data_iterator, model_config, mesh)

Initialize the LLMTrainer with all sub-configs.

log_info(msg)

Module Contents

xlstm_jax.train_init_fns.LOGGER

xlstm_jax.train_init_fns.init_parallel(cfg)

Initialize configuration for parallelism.

Parameters:

cfg (omegaconf.DictConfig) – Config assembled by Hydra.

Returns:

Initialized parallel configuration.

Return type:

xlstm_jax.models.configs.ParallelConfig

xlstm_jax.train_init_fns.init_data_iterator(cfg, mesh)

Initialize data iterators.

Parameters:
  • cfg (omegaconf.DictConfig) – Config assembled by Hydra.

  • mesh (jax.sharding.Mesh) – The jax device mesh.

Returns:

Training and evaluation data iterators.

Return type:

tuple[xlstm_jax.dataset.DataIterator, xlstm_jax.dataset.DataIterator | dict[str, xlstm_jax.dataset.DataIterator] | None]
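Because the second element of the returned tuple can be a single iterator, a dictionary of named iterators, or None, callers may want to normalize it before looping over evaluation sets. The helper below is a minimal sketch of one way to do that; the function name `normalize_eval_iterators` and the default key `"eval"` are hypothetical and not part of xlstm_jax.

```python
from typing import Any


def normalize_eval_iterators(eval_iter: Any, default_name: str = "eval") -> dict[str, Any]:
    """Turn the three documented return shapes into a dict of named iterators.

    Hypothetical helper, not part of xlstm_jax:
      - None            -> empty dict (no evaluation data)
      - dict[str, iter] -> shallow copy of the dict
      - single iterator -> {default_name: iterator}
    """
    if eval_iter is None:
        return {}
    if isinstance(eval_iter, dict):
        return dict(eval_iter)
    return {default_name: eval_iter}
```

With this in place, code that consumes the evaluation iterators can use a single `for name, it in normalize_eval_iterators(eval_iter).items()` loop regardless of which shape was returned.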

xlstm_jax.train_init_fns.init_single_data_iterator(cfg, mesh, create_split=None)

Initialize a single data iterator.

Parameters:
  • cfg (omegaconf.DictConfig) – Data configuration.

  • mesh (jax.sharding.Mesh) – The jax device mesh.

  • create_split (Literal['train', 'eval', None]) – Whether to create a train or eval config from the config class, using the create_train_eval_configs method. If None, the config is used as is.

Returns:

Data iterator.

Return type:

xlstm_jax.dataset.DataIterator

xlstm_jax.train_init_fns.init_mixed_data_iterator(cfg, mesh)

Initialize a data iterator with mixed data sources.

Parameters:
  • cfg (omegaconf.DictConfig) – Data configuration.

  • mesh (jax.sharding.Mesh) – The jax device mesh.

Returns:

Data iterator.

Return type:

xlstm_jax.dataset.DataIterator

xlstm_jax.train_init_fns.get_tokenizer_vocab_size(cfg, next_multiple_of=1)

Get the vocabulary size from the tokenizer.

Parameters:
  • cfg (omegaconf.DictConfig) – Config assembled by Hydra.

  • next_multiple_of (int) – The vocabulary size will be increased to the next multiple of this number.

Returns:

The vocabulary size, increased to the next multiple of next_multiple_of.

Return type:

int
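The rounding described for next_multiple_of can be illustrated with plain integer arithmetic. This is only a sketch of the arithmetic, not the library's implementation; the function name `round_up_to_multiple` is hypothetical, and the tokenizer lookup itself is omitted.

```python
def round_up_to_multiple(vocab_size: int, next_multiple_of: int = 1) -> int:
    """Round vocab_size up to the nearest multiple of next_multiple_of.

    Hypothetical helper that mirrors the documented behaviour; values that
    are already a multiple are returned unchanged.
    """
    if next_multiple_of < 1:
        raise ValueError("next_multiple_of must be a positive integer")
    # Ceiling division via negated floor division, then scale back up.
    return -(-vocab_size // next_multiple_of) * next_multiple_of


print(round_up_to_multiple(50257, 64))  # 50304
print(round_up_to_multiple(50257))      # 50257 (default of 1 leaves the size unchanged)
```

Padding the vocabulary to a multiple such as 64 or 128 is a common way to keep the embedding dimension evenly divisible across devices.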

xlstm_jax.train_init_fns.init_model_config(cfg, parallel)

Instantiate the model configuration.

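Parameters:
  • cfg (omegaconf.DictConfig) – Config assembled by Hydra.

  • parallel (xlstm_jax.models.configs.ParallelConfig) – Parallel configuration.

Returns: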

Initialized model configuration.

Return type:

xlstm_jax.models.ModelConfig

xlstm_jax.train_init_fns.init_logger_config(cfg)

Instantiate logger configuration.

Parameters:

cfg (omegaconf.DictConfig) – Config assembled by Hydra.

Returns:

Instance of LoggerConfig.

Return type:

xlstm_jax.trainer.logger.LoggerConfig

xlstm_jax.train_init_fns.init_scheduler_config(cfg, data_iterator)

Instantiate scheduler configuration.

Parameters:
  • cfg (omegaconf.DictConfig) – Config assembled by Hydra.

  • data_iterator (xlstm_jax.dataset.DataIterator)

Returns:

Instance of SchedulerConfig following the provided config.

Return type:

xlstm_jax.trainer.optimizer.SchedulerConfig

xlstm_jax.train_init_fns.init_optimizer_config(cfg)

Instantiate optimizer configuration.

Parameters:

cfg (omegaconf.DictConfig) – Full Hydra config.

Returns:

Instance of OptimizerConfig.

Return type:

xlstm_jax.trainer.optimizer.OptimizerConfig

xlstm_jax.train_init_fns.init_model_checkpointing(cfg)

Instantiate model checkpointing configuration.

Parameters:

cfg (omegaconf.DictConfig) – Full Hydra config.

Returns:

Instance of ModelCheckpointConfig.

Return type:

xlstm_jax.trainer.callbacks.ModelCheckpointConfig

xlstm_jax.train_init_fns.init_lr_monitor_config(cfg)

Instantiate learning rate monitor configuration.

Parameters:

cfg (omegaconf.DictConfig) – Full Hydra config.

Returns:

Instance of LearningRateMonitorConfig.

Return type:

xlstm_jax.trainer.callbacks.LearningRateMonitorConfig

xlstm_jax.train_init_fns.init_profiler_config(cfg)

Instantiate profiler configuration.

Parameters:

cfg (omegaconf.DictConfig) – Full Hydra config.

Returns:

Instance of JaxProfilerConfig.

Return type:

xlstm_jax.trainer.callbacks.JaxProfilerConfig

xlstm_jax.train_init_fns.init_trainer(cfg, data_iterator, model_config, mesh)

Initialize the LLMTrainer with all sub-configs.

Parameters:
  • cfg (omegaconf.DictConfig) – Full Hydra config.

  • data_iterator (xlstm_jax.dataset.DataIterator) – A data iterator.

  • model_config (xlstm_jax.models.ModelConfig) – A model config.

  • mesh (jax.sharding.Mesh) – A device mesh.

Returns:

Instance of LLM trainer.

Return type:

xlstm_jax.trainer.llm.trainer.LLMTrainer
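The signatures on this page suggest how the top-level functions might be chained in a training script. The sketch below is an assumption about call order, not the project's actual entry point: `build_trainer` is a hypothetical name, `init_fns` stands in for the imported xlstm_jax.train_init_fns module, and the device mesh is taken as an argument because this page does not document how it is constructed.

```python
from typing import Any


def build_trainer(init_fns: Any, cfg: Any, mesh: Any) -> tuple[Any, Any, Any]:
    """Chain the documented init functions in one plausible order.

    Hypothetical helper. `init_fns` is expected to expose init_parallel,
    init_data_iterator, init_model_config and init_trainer with the
    signatures listed on this page.
    """
    # Parallelism settings are needed to build the model configuration.
    parallel = init_fns.init_parallel(cfg)
    # Returns the training iterator plus an eval iterator, dict of iterators, or None.
    train_iter, eval_iter = init_fns.init_data_iterator(cfg, mesh)
    model_config = init_fns.init_model_config(cfg, parallel)
    trainer = init_fns.init_trainer(cfg, train_iter, model_config, mesh)
    return trainer, train_iter, eval_iter
```

Passing the module in as an argument keeps the sketch importable without xlstm_jax installed; in a real script one would presumably pass `xlstm_jax.train_init_fns` itself and a Hydra-assembled `cfg`.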

xlstm_jax.train_init_fns.log_info(msg)

Parameters:

msg (str)