xlstm_jax.trainer.callbacks.lr_monitor

Attributes

LOGGER

Classes

LearningRateMonitorConfig

Configuration for the LearningRateMonitor callback.

LearningRateMonitor

Callback to monitor the learning rate.

Module Contents

xlstm_jax.trainer.callbacks.lr_monitor.LOGGER

class xlstm_jax.trainer.callbacks.lr_monitor.LearningRateMonitorConfig

Bases: xlstm_jax.trainer.callbacks.callback.CallbackConfig

Configuration for the LearningRateMonitor callback.

every_n_epochs: int = -1

Log the learning rate every n epochs. Set to -1 to disable.

every_n_steps: int = 50

Log the learning rate every n steps. By default, logs every 50 steps.

main_process_only: bool = True

Log the learning rate only in the main process.

log_lr_key: str = 'optimizer/lr'

Key to use for logging the learning rate.
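
A minimal stand-in for this configuration can help when reasoning about the defaults. It assumes, without confirmation from this page, that `every_n_steps` and `every_n_epochs` act as modulo filters applied before the `on_filtered_*` hooks fire, and that a non-positive value such as -1 disables the corresponding filter. `LRMonitorConfigSketch` and `should_log` are hypothetical names, not part of `xlstm_jax`.

```python
from dataclasses import dataclass


@dataclass
class LRMonitorConfigSketch:
    """Hypothetical stand-in mirroring the documented fields and defaults."""

    every_n_epochs: int = -1  # -1 disables epoch-level logging
    every_n_steps: int = 50  # log every 50 steps by default
    main_process_only: bool = True  # only the main process logs
    log_lr_key: str = "optimizer/lr"  # metric key for the learning rate


def should_log(every_n: int, idx: int) -> bool:
    """Assumed filter: fire when idx is a multiple of every_n; every_n <= 0 disables it."""
    return every_n > 0 and idx % every_n == 0


cfg = LRMonitorConfigSketch()
logged_steps = [s for s in range(1, 201) if should_log(cfg.every_n_steps, s)]
```

Under these assumptions, the default configuration logs at steps 50, 100, 150 and 200 within the first 200 steps, and never at epoch end.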

create(trainer, data_module=None)

Creates the LearningRateMonitor callback.

Parameters:
  • trainer (Any) – Trainer object.

  • data_module (optional) – Data module object.

Returns:

LearningRateMonitor object.

Return type:

LearningRateMonitor
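
The signature of `create` mirrors the `LearningRateMonitor` constructor, which suggests the config acts as a factory that passes itself to the callback it builds. The sketch below illustrates that pattern with hypothetical stand-in classes (`ConfigFactorySketch`, `CallbackSketch`); it is not the library's source.

```python
from dataclasses import dataclass
from typing import Any, Optional


class CallbackSketch:
    """Hypothetical callback stand-in with the documented constructor signature."""

    def __init__(self, config: Any, trainer: Any, data_module: Optional[Any] = None):
        self.config = config
        self.trainer = trainer
        self.data_module = data_module


@dataclass
class ConfigFactorySketch:
    """Hypothetical config stand-in; only two documented fields are mirrored."""

    every_n_steps: int = 50
    log_lr_key: str = "optimizer/lr"

    def create(self, trainer: Any, data_module: Optional[Any] = None) -> CallbackSketch:
        # Factory pattern: the config passes itself as the callback's first argument.
        return CallbackSketch(self, trainer, data_module)


callback = ConfigFactorySketch().create(trainer=object())
```

A plausible benefit of this split is that the config, being plain data, can be built and overridden before a trainer exists, while the callback is only instantiated once a trainer is available.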

class xlstm_jax.trainer.callbacks.lr_monitor.LearningRateMonitor(config, trainer, data_module=None)

Bases: xlstm_jax.trainer.callbacks.callback.Callback

Callback to monitor the learning rate.

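Parameters:
  • config (LearningRateMonitorConfig) – Configuration of the callback.

  • trainer (Any) – Trainer object.

  • data_module (optional) – Data module object.

lr_scheduler
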
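This page does not document the type of `lr_scheduler`. In JAX training code built on Optax, a learning-rate schedule is typically a callable that maps a step count to a learning rate, so the sketch below assumes that shape. `warmup_cosine_sketch` is a hypothetical schedule in plain Python; its shape (linear warmup, then cosine decay) resembles Optax's `warmup_cosine_decay_schedule`, but it is not taken from `xlstm_jax`.

```python
import math
from typing import Callable


def warmup_cosine_sketch(
    peak_lr: float, warmup_steps: int, decay_steps: int, end_lr: float = 0.0
) -> Callable[[int], float]:
    """Hypothetical schedule: linear warmup to peak_lr, then cosine decay to end_lr."""

    def schedule(step: int) -> float:
        if step < warmup_steps:
            # Linear ramp from 0 to peak_lr during warmup.
            return peak_lr * step / warmup_steps
        # Fraction of the decay phase completed, clipped to [0, 1].
        progress = min((step - warmup_steps) / max(decay_steps - warmup_steps, 1), 1.0)
        return end_lr + 0.5 * (peak_lr - end_lr) * (1.0 + math.cos(math.pi * progress))

    return schedule


lr_scheduler = warmup_cosine_sketch(peak_lr=1e-3, warmup_steps=100, decay_steps=1000)
```

With these arguments the sketch rises linearly to 1e-3 at step 100 and decays to 0 by step 1000, so a monitor logging every 50 steps would record both the warmup ramp and the decay.
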
on_filtered_training_step(step_metrics, epoch_idx, step_idx)

Logs the learning rate after a step.

Parameters:
  • step_metrics (xlstm_jax.common_types.Metrics) – Metrics of the current step. Unused in this callback.

  • epoch_idx (int) – Index of the current epoch. Unused in this callback.

  • step_idx (int) – Index of the current step.

on_filtered_training_epoch_end(train_metrics, epoch_idx)

Logs the learning rate after an epoch.

Parameters:
  • train_metrics (xlstm_jax.common_types.Metrics) – Metrics of the current epoch. Unused in this callback.

  • epoch_idx (int) – Index of the current epoch. Unused in this callback.

_log_lr(step_idx)

Logs the learning rate.

Parameters:
  • step_idx (int) – Index of the current step.
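
A self-contained sketch of how the two hooks might delegate to `_log_lr`. Several assumptions are not confirmed by this page: `lr_scheduler` is a step-to-learning-rate callable read from the trainer; the epoch-end hook, whose `epoch_idx` is unused, logs at the trainer's current step (modeled by a hypothetical `global_step` attribute); and logging is modeled as appending `(step, {key: lr})` to a list instead of calling the trainer's real logger. `LRMonitorSketch` and `TrainerStub` are hypothetical.

```python
from typing import Any, Callable, Dict, List, Tuple


class TrainerStub:
    """Hypothetical trainer exposing only what the sketch needs."""

    def __init__(self, lr_scheduler: Callable[[int], float]):
        self.lr_scheduler = lr_scheduler
        self.global_step = 0  # hypothetical attribute, not from the documented API
        self.logged: List[Tuple[int, Dict[str, float]]] = []  # stand-in for a logger


class LRMonitorSketch:
    def __init__(self, trainer: TrainerStub, log_lr_key: str = "optimizer/lr"):
        self.trainer = trainer
        self.lr_scheduler = trainer.lr_scheduler  # assumed source of the schedule
        self.log_lr_key = log_lr_key

    def on_filtered_training_step(self, step_metrics: Any, epoch_idx: int, step_idx: int) -> None:
        # Only step_idx is used; metrics and epoch index are ignored.
        self._log_lr(step_idx)

    def on_filtered_training_epoch_end(self, train_metrics: Any, epoch_idx: int) -> None:
        # epoch_idx is unused; log at the trainer's current step instead (assumption).
        self._log_lr(self.trainer.global_step)

    def _log_lr(self, step_idx: int) -> None:
        # Evaluate the schedule at this step and record it under the configured key.
        lr = float(self.lr_scheduler(step_idx))
        self.trainer.logged.append((step_idx, {self.log_lr_key: lr}))


trainer = TrainerStub(lr_scheduler=lambda step: 1e-3 / (1 + step))
monitor = LRMonitorSketch(trainer)
monitor.on_filtered_training_step(step_metrics={}, epoch_idx=0, step_idx=50)
trainer.global_step = 120
monitor.on_filtered_training_epoch_end(train_metrics={}, epoch_idx=0)
```

Converting the scheduler output with `float(...)` keeps the logged value a plain Python number even if a real schedule returns a JAX array.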