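xlstm_jax.main_train
Functions
main_train(cfg, checkpoint_step=None, load_dataloaders=True, load_optimizer=True) |
The main training function. Initializes the mesh, data iterators, model config, and trainer, then starts training. |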
Module Contents
- xlstm_jax.main_train.main_train(cfg, checkpoint_step=None, load_dataloaders=True, load_optimizer=True)
- The main training function. It initializes the mesh, data iterators, model config, and trainer, then starts training. Training can optionally resume from a checkpoint, in which case the training state is loaded from the checkpoint at the supplied step index.
- To make error logs visible in our custom logger, we use the with_error_handling decorator.
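The error-logging decorator pattern mentioned here can be sketched as below. This is a minimal illustration and not the actual with_error_handling from xlstm_jax: the name `log_errors`, the logger name `train_sketch`, and the choice to re-raise after logging are all assumptions made for this example.

```python
import functools
import logging

# Hypothetical logger; the real project uses its own custom logger.
LOGGER = logging.getLogger("train_sketch")


def log_errors(fn):
    """Log any exception raised by `fn` to LOGGER, then re-raise it."""

    @functools.wraps(fn)  # keep the wrapped function's name and docstring
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception:
            # Logger.exception logs at ERROR level and attaches the traceback.
            LOGGER.exception("Unhandled error in %s", fn.__name__)
            raise

    return wrapper


@log_errors
def failing_step():
    """Stand-in for a training entry point that fails."""
    raise ValueError("bad config")
```

Re-raising keeps the failure visible to the caller (and to the process exit code) while still routing the traceback through the logger.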
- Parameters:
cfg (omegaconf.DictConfig) – The full configuration.
checkpoint_step (optional) – Step index of the checkpoint to load. Defaults to None, in which case training starts from scratch.
load_dataloaders (optional) – Whether to load the data loaders. Defaults to True.
load_optimizer (optional) – Whether to load the optimizer. Defaults to True.
- Returns:
The final metrics of the training.
- Return type:
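The call shapes implied by the signature above might look like the following sketch. To keep it runnable without the package installed, it uses a stand-in function, `main_train_stub`, with the same parameters and defaults; the stub, the placeholder config dictionary, and the step value 5000 are hypothetical and only echo the arguments back instead of training anything.

```python
def main_train_stub(cfg, checkpoint_step=None, load_dataloaders=True, load_optimizer=True):
    """Stand-in with the documented signature; returns its arguments."""
    return {
        "cfg": cfg,
        "checkpoint_step": checkpoint_step,
        "load_dataloaders": load_dataloaders,
        "load_optimizer": load_optimizer,
    }


# Placeholder for the full configuration (an omegaconf.DictConfig in the real API).
cfg = {"experiment": "demo"}

# Start from scratch: checkpoint_step stays None.
fresh = main_train_stub(cfg)

# Resume from a checkpoint step, here with load_optimizer switched off.
resumed = main_train_stub(cfg, checkpoint_step=5000, load_optimizer=False)
```

With the package installed, the same calls would presumably target xlstm_jax.main_train.main_train and pass an omegaconf.DictConfig as cfg.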