xlstm_jax.main_train


Functions

main_train(cfg[, checkpoint_step, load_dataloaders, ...])

The main training function. This function initializes the mesh, data iterators, model config, and trainer and then starts training.

Module Contents

xlstm_jax.main_train.main_train(cfg, checkpoint_step=None, load_dataloaders=True, load_optimizer=True)
The main training function. This function initializes the mesh, data iterators, model config, and trainer, and then starts training. Training can optionally be started from a checkpoint, in which case the training state is loaded from the checkpoint with the supplied step index.

In order to see error logs in our custom logger, we use the with_error_handling decorator.
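The general pattern behind an error-logging decorator can be sketched as follows. This is an illustrative sketch only, not the actual with_error_handling implementation: the names log_errors, LOGGER, and failing_step are hypothetical, and the standard-library logging module stands in for the project's custom logger.

```python
import functools
import logging

# Hypothetical logger; the project's custom logger is not shown in this page.
LOGGER = logging.getLogger("train_sketch")


def log_errors(fn):
    """Log any exception raised by fn to LOGGER, then re-raise it."""

    @functools.wraps(fn)  # keep the wrapped function's name and docstring
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception:
            # Logger.exception records the message together with the traceback.
            LOGGER.exception("Unhandled error in %s", fn.__name__)
            raise

    return wrapper


@log_errors
def failing_step():
    """Hypothetical function that always fails, to exercise the decorator."""
    raise ValueError("bad config")
```

Re-raising after logging keeps the failure visible to the caller while ensuring the traceback also reaches the configured log handlers.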

Parameters:
  • cfg (omegaconf.DictConfig) – The full configuration.

  • checkpoint_step (optional) – Step index of checkpoint to be loaded. Defaults to None, in which case training starts from scratch.

  • load_dataloaders (optional) – Whether to load the data loaders. Defaults to True.

  • load_optimizer (optional) – Whether to load the optimizer. Defaults to True.

Returns:
  The final metrics of the training.

Return type:
  dict[str, Any]
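
The two ways of starting training described above (from scratch, or from a checkpoint step) might be wrapped as in the sketch below. It makes several assumptions: train_or_resume and fake_main_train are hypothetical names, the step index is taken to be an integer, and the metric keys returned by the stand-in are invented for illustration. In real use, main_train from this module would be passed as train_fn together with an omegaconf.DictConfig.

```python
from __future__ import annotations

from typing import Any, Callable


def train_or_resume(
    train_fn: Callable[..., dict[str, Any]],
    cfg: Any,
    resume_step: int | None = None,
) -> dict[str, Any]:
    """Start training from scratch, or resume from resume_step if it is given."""
    if resume_step is None:
        # checkpoint_step defaults to None, so training starts from scratch.
        return train_fn(cfg)
    # Resume from the given step, using the documented keyword arguments.
    return train_fn(
        cfg,
        checkpoint_step=resume_step,
        load_dataloaders=True,
        load_optimizer=True,
    )


def fake_main_train(cfg, checkpoint_step=None, load_dataloaders=True, load_optimizer=True):
    """Hypothetical stand-in with the documented signature; it does no training."""
    start = 0 if checkpoint_step is None else checkpoint_step
    return {"start_step": start, "loaded_optimizer": load_optimizer}


# A plain dict stands in for the DictConfig here.
metrics = train_or_resume(fake_main_train, cfg={"lr": 1e-3}, resume_step=500)
print(metrics["start_step"])
```

Passing the training function as an argument keeps the sketch runnable without the library installed; the returned dictionary is read like any other dict[str, Any].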