xlstm_jax.trainer.callbacks.checkpointing#

Attributes#

LOGGER

Classes#

ModelCheckpointConfig

Configuration for the ModelCheckpoint callback.

ModelCheckpoint

Callback to save model parameters and mutable variables to the logging directory.

Functions#

load_pretrained_model(checkpoint_path, trainer[, ...])

Loads a pretrained model from a checkpoint.

Module Contents#

xlstm_jax.trainer.callbacks.checkpointing.LOGGER#
class xlstm_jax.trainer.callbacks.checkpointing.ModelCheckpointConfig#

Bases: xlstm_jax.trainer.callbacks.callback.CallbackConfig

Configuration for the ModelCheckpoint callback.

By default, the checkpoint saves the model parameters, training step, random number generator state, and metadata to the logging directory. The metadata includes the trainer, model, and optimizer configurations.

max_to_keep#

Number of checkpoints to keep. If None, keeps all checkpoints. Otherwise, keeps the most recent max_to_keep checkpoints. If monitor is set, keeps the best max_to_keep checkpoints instead of the most recent.

monitor#

Metric to monitor for saving the model. Should be a key of the evaluation metrics. If None, checkpoints are sorted by recency.

mode#

One of {"min", "max"}. Determines how the monitored metric is ranked. If "min", the checkpoints with the smallest values of the monitored metric are treated as best. If "max", the checkpoints with the largest values are treated as best.

save_optimizer_state#

Whether to save the optimizer state.

save_dataloader_state#

Whether to save the dataloader state.

enable_async_checkpointing#

Whether to enable asynchronous checkpointing. See the Orbax documentation for more information.

log_path#

Directory in which the checkpoints are saved, as a subfolder. If None, the logging directory of the trainer is used.

max_to_keep: int | None = 1#
monitor: str | None = None#
mode: str = 'min'#
save_optimizer_state: bool = True#
save_dataloader_state: bool = True#
enable_async_checkpointing: bool = True#
log_path: pathlib.Path | None = None#
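The interaction between max_to_keep, monitor, and mode described above can be illustrated with a small pure-Python sketch. It is not the library's implementation (the callback delegates retention to an Orbax checkpoint manager); the function name `checkpoints_to_keep`, the `history` dictionary, and the `"loss"` metric key are hypothetical and exist only for illustration.

```python
from typing import Optional


def checkpoints_to_keep(
    metrics_by_step: dict[int, dict[str, float]],
    max_to_keep: Optional[int] = 1,
    monitor: Optional[str] = None,
    mode: str = "min",
) -> list[int]:
    """Return the steps retained under the documented policy (illustrative only)."""
    steps = sorted(metrics_by_step)
    if max_to_keep is None:
        # No limit: every checkpoint is kept.
        return steps
    if monitor is None:
        # No monitored metric: keep the most recent `max_to_keep` steps.
        return steps[max(len(steps) - max_to_keep, 0):]
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    # Rank by the monitored metric, best first, then keep the top entries.
    ranked = sorted(
        steps,
        key=lambda step: metrics_by_step[step][monitor],
        reverse=(mode == "max"),
    )
    return sorted(ranked[:max_to_keep])


# Hypothetical evaluation history: step index -> evaluation metrics.
history = {100: {"loss": 0.5}, 200: {"loss": 0.9}, 300: {"loss": 0.7}}

recent = checkpoints_to_keep(history, max_to_keep=2)
best_min = checkpoints_to_keep(history, max_to_keep=2, monitor="loss", mode="min")
best_max = checkpoints_to_keep(history, max_to_keep=2, monitor="loss", mode="max")
```

With this history, ranking by recency keeps steps 200 and 300, whereas monitoring `"loss"` with mode "min" keeps steps 100 and 300, so the oldest checkpoint survives because it has the lowest loss.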
create(trainer, data_module=None)#

Creates the ModelCheckpoint callback.

Parameters:
  • trainer (Any) – Trainer object.

  • data_module (optional) – Data module object.

Returns:

ModelCheckpoint object.

Return type:

ModelCheckpoint

class xlstm_jax.trainer.callbacks.checkpointing.ModelCheckpoint(config, trainer, data_module=None)#

Bases: xlstm_jax.trainer.callbacks.callback.Callback

Callback to save model parameters and mutable variables to the logging directory.

Sets up an Orbax checkpoint manager to save model parameters, training step, random number generator state, and metadata to the logging directory.

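Parameters:
  • config (ModelCheckpointConfig) – Configuration for the callback.

  • trainer (Any) – Trainer object.

  • data_module (optional) – Data module object.

checkpoint_path#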
dataloader_path#
metadata#
manager#
on_filtered_validation_epoch_end(eval_metrics, epoch_idx, step_idx)#

Saves the model at the end of the validation epoch.

Parameters:
  • eval_metrics (xlstm_jax.common_types.Metrics) – Dictionary of evaluation metrics. If monitor is set, the checkpoint is ranked by the value of that metric in this dictionary, and an error is raised if the metric is not found. The metrics are saved along with the model.

  • epoch_idx (int) – Index of the current epoch.

  • step_idx (int) – Index of the current step.

save_model(eval_metrics, step_idx)#

Saves model state dict to the logging directory.

Parameters:
  • eval_metrics (xlstm_jax.common_types.Metrics) – Dictionary of evaluation metrics. If monitor is set, the checkpoint is ranked by the value of that metric in this dictionary, and an error is raised if the metric is not found. The metrics are saved along with the model.

  • step_idx (int) – Index of the current step.
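
The monitored-metric lookup described for eval_metrics might look roughly like the sketch below. The helper name `monitored_value` is hypothetical, and the choice of `KeyError` is an assumption: the documentation only states that an error is raised when the metric is missing.

```python
from typing import Optional


def monitored_value(
    eval_metrics: dict[str, float], monitor: Optional[str]
) -> Optional[float]:
    """Return the value used to rank a checkpoint, or None if nothing is monitored."""
    if monitor is None:
        # No monitored metric: checkpoints are ranked by recency instead.
        return None
    if monitor not in eval_metrics:
        # Assumption: the exact exception type used by the library is not documented.
        raise KeyError(
            f"Monitored metric {monitor!r} not found; available keys: {sorted(eval_metrics)}"
        )
    return float(eval_metrics[monitor])
```

Failing loudly here is useful because a misspelled monitor key would otherwise silently fall back to a ranking the user did not ask for.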

save_dataloader(step_idx)#

Saves the dataloader state to the logging directory.

Parameters:

step_idx (int) – Index of the current step.

load_model(step_idx=-1, load_best=False, delete_params_before_loading=False)#

Loads model parameters and variables from the logging directory.

Parameters:
  • step_idx (int) – Index of the step to load. If -1, loads the latest step by default.

  • load_best (bool) – If True and step_idx is -1, loads the best checkpoint based on the monitored metric instead of the latest checkpoint.

  • delete_params_before_loading (bool) – If True, deletes the current parameters in the trainer state before loading the new parameters.

Returns:

Dictionary of loaded model parameters and additional variables.

Return type:

dict[str, Any]

load_dataloader(step_idx=-1, load_best=False)#

Loads the dataloader state from the logging directory.

Parameters:
  • step_idx (int) – Index of the step to load. If -1, loads the latest step by default.

  • load_best (bool) – If True and step_idx is -1, loads the best checkpoint based on the monitored metric instead of the latest checkpoint.

Returns:

Dictionary of loaded dataloader states.

Return type:

dict[str, Any]

resolve_step_idx(step_idx, load_best)#

Resolves the step index to load.

Parameters:
  • step_idx (int) – Index of the step to load. If -1, loads the latest step by default.

  • load_best (bool) – If True and step_idx is -1, loads the best checkpoint based on the monitored metric instead of the latest checkpoint.

Returns:

The resolved step index.

Return type:

int
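
The resolution rule described by the two parameters can be sketched as follows. This is a minimal illustration, not the method's actual code: the function name `resolve_step_idx_sketch` and the explicit `latest_step` and `best_step` arguments are hypothetical stand-ins for values the checkpoint manager would provide.

```python
def resolve_step_idx_sketch(
    step_idx: int, load_best: bool, latest_step: int, best_step: int
) -> int:
    """Map a requested step index to a concrete checkpoint step (illustrative only)."""
    if step_idx != -1:
        # An explicit step index is used as-is; load_best only applies to -1.
        return step_idx
    if load_best:
        # Best checkpoint according to the monitored metric.
        return best_step
    # Default: the most recent checkpoint.
    return latest_step
```

Note that load_best has no effect when an explicit step index is passed, which matches the documented condition "If True and step_idx is -1".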

finalize(status=None)#

Closes the checkpoint manager.

Parameters:

status (str | None) – The status of the training run (e.g. success, failure).

xlstm_jax.trainer.callbacks.checkpointing.load_pretrained_model(checkpoint_path, trainer, step_idx=-1, load_optimizer=True, load_best=False, delete_params_before_loading=False)#

Loads a pretrained model from a checkpoint.

Parameters:
  • checkpoint_path (pathlib.Path) – Path to the checkpoint directory.

  • trainer (Any) – Trainer object.

  • step_idx (int) – Index of the step to load. If -1, loads the latest step by default.

  • load_optimizer (bool) – If True, the optimizer state is loaded from the checkpoint.

  • load_best (bool) – If True and step_idx is -1, loads the best checkpoint based on the monitored metric instead of the latest checkpoint.

  • delete_params_before_loading (bool) – If True, deletes the current parameters in the trainer state before loading the new parameters.

Returns:

A tuple of three elements: the dictionary of loaded model parameters and additional variables, the dictionary of loaded dataloader states, and the resolved step index that was loaded.

Return type:

tuple[dict[str, Any], dict[str, Any], int]
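
Because the function returns a plain three-element tuple, callers may find it convenient to give the elements names. The sketch below wraps such a tuple in a hypothetical `LoadedCheckpoint` named tuple; the wrapper is not part of the library, and `raw_result` is a stand-in with placeholder contents rather than real checkpoint data.

```python
from typing import Any, NamedTuple


class LoadedCheckpoint(NamedTuple):
    """Hypothetical wrapper naming the three documented return values."""

    state: dict[str, Any]             # model parameters and additional variables
    dataloader_state: dict[str, Any]  # loaded dataloader states
    step_idx: int                     # resolved step index that was loaded


# Stand-in for the tuple returned by load_pretrained_model (placeholder contents).
raw_result = ({"params": {}}, {"train": {}}, 1000)

loaded = LoadedCheckpoint(*raw_result)
```

Accessing `loaded.step_idx` instead of `raw_result[2]` makes call sites less sensitive to the order of the tuple elements.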