xlstm_jax.trainer.callbacks.lr_monitor
Attributes
LOGGER
Classes
| LearningRateMonitorConfig | Configuration for the LearningRateMonitor callback. |
| LearningRateMonitor | Callback to monitor the learning rate. |
Module Contents
- xlstm_jax.trainer.callbacks.lr_monitor.LOGGER
- class xlstm_jax.trainer.callbacks.lr_monitor.LearningRateMonitorConfig
Bases:
xlstm_jax.trainer.callbacks.callback.CallbackConfig
Configuration for the LearningRateMonitor callback.
- create(trainer, data_module=None)
Creates the LearningRateMonitor callback.
- Parameters:
trainer (Any) – Trainer object.
data_module (optional) – Data module object.
- Returns:
LearningRateMonitor object.
- Return type:
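The sketch below shows the general "config object with a `create` factory" pattern described above, in plain Python. It is a hypothetical illustration and not the xlstm_jax implementation. The classes `SketchCallbackConfig` and `SketchCallback` are stand-ins, and the field `every_n_steps` is an assumed example setting, not a documented option of `LearningRateMonitorConfig`.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class SketchCallbackConfig:
    # Hypothetical stand-in for a callback config.
    # Assumption: `every_n_steps` is an illustrative field only.
    every_n_steps: int = 50

    def create(self, trainer: Any, data_module: Any = None) -> "SketchCallback":
        # The config acts as a factory: it builds the callback it describes,
        # passing itself along so the callback can read its settings.
        return SketchCallback(self, trainer, data_module)


class SketchCallback:
    def __init__(self, config: SketchCallbackConfig, trainer: Any, data_module: Any = None):
        # Keep references to the config, the trainer and the optional data module.
        self.config = config
        self.trainer = trainer
        self.data_module = data_module


trainer = object()  # placeholder for a real trainer object
callback = SketchCallbackConfig(every_n_steps=10).create(trainer)
print(type(callback).__name__, callback.config.every_n_steps)  # SketchCallback 10
```

With this design, a trainer can hold a list of plain config objects and instantiate every callback after the trainer itself exists, because each config knows which callback class to build.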
- class xlstm_jax.trainer.callbacks.lr_monitor.LearningRateMonitor(config, trainer, data_module=None)
Bases:
xlstm_jax.trainer.callbacks.callback.Callback
Callback to monitor the learning rate.
- Parameters:
config (LearningRateMonitorConfig)
trainer (Any)
data_module (xlstm_jax.trainer.data_module.DataloaderModule | None)
- lr_scheduler
- on_filtered_training_step(step_metrics, epoch_idx, step_idx)
Logs the learning rate after a step.
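The minimal pure-Python sketch below shows what logging the learning rate after a step typically involves. This page does not confirm several of its details, which are assumptions. First, `lr_scheduler` is treated as a callable that maps a step index to a learning rate, the same shape as an Optax schedule. Second, "filtered" is read to mean that the hook runs only on selected steps. The schedule `warmup_then_constant`, the class `LRMonitorSketch`, the `log_key` name and the in-memory `logged` list are hypothetical stand-ins for the trainer's real scheduler and logger.

```python
from typing import Any, Callable


def warmup_then_constant(step: int, peak_lr: float = 1e-3, warmup_steps: int = 100) -> float:
    # Hypothetical schedule: linear warmup to `peak_lr`, then constant.
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    return peak_lr


class LRMonitorSketch:
    """Records the learning rate returned by a step-indexed schedule (illustrative only)."""

    def __init__(self, lr_scheduler: Callable[[int], float], log_key: str = "lr"):
        # Assumption: the scheduler is a callable step index -> learning rate.
        self.lr_scheduler = lr_scheduler
        self.log_key = log_key  # assumed metric name
        self.logged: list[dict[str, Any]] = []  # stand-in for a real logger

    def on_filtered_training_step(self, step_metrics: dict, epoch_idx: int, step_idx: int) -> None:
        # Evaluate the schedule at the current step and record it with the step index.
        lr = float(self.lr_scheduler(step_idx))
        self.logged.append({"step": step_idx, self.log_key: lr})


monitor = LRMonitorSketch(warmup_then_constant)
for step in range(0, 200, 50):  # pretend the trainer invokes the hook every 50 steps
    monitor.on_filtered_training_step({}, epoch_idx=0, step_idx=step)
print([entry["lr"] for entry in monitor.logged])
```

The callback evaluates the schedule directly at the step index. It does not read the rate back from the optimizer state, so the logged value is exactly what the schedule defines for that step.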