xlstm_jax.trainer.callbacks.callback
Classes
- CallbackConfig – Base configuration of a callback.
- Callback – Base class for callbacks.
Module Contents
- class xlstm_jax.trainer.callbacks.callback.CallbackConfig
Bases: xlstm_jax.configs.ConfigDict
Base configuration of a callback.
- every_n_epochs: int = 1
If the callback implements functions on a per-epoch basis (e.g. on_training_epoch_start, on_training_epoch_end, on_validation_epoch_start), this parameter sets the interval, in epochs, at which these functions are called.
- every_n_steps: int = -1
If the callback implements functions on a per-step basis (e.g. on_training_step), this parameter sets the interval, in steps, at which these functions are called.
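A minimal sketch of how these two intervals could gate the filtered hooks. It assumes the check is a modulo test on the 1-based epoch and step indices, and that a non-positive value (such as the every_n_steps default of -1) disables the filtered hooks; neither detail is confirmed by this page, and both helper names are hypothetical.

```python
def is_active_epoch(epoch_idx: int, every_n_epochs: int) -> bool:
    """Hypothetical helper: should the filtered per-epoch hooks fire?

    Assumes 1-based epoch indices and that every_n_epochs <= 0 disables them.
    """
    if every_n_epochs <= 0:
        return False
    return epoch_idx % every_n_epochs == 0


def is_active_step(step_idx: int, every_n_steps: int) -> bool:
    """Hypothetical helper: the same rule applied to the filtered per-step hooks."""
    if every_n_steps <= 0:
        return False
    return step_idx % every_n_steps == 0


# With every_n_epochs=2, the filtered epoch hooks fire on epochs 2, 4 and 6.
print([e for e in range(1, 7) if is_active_epoch(e, every_n_epochs=2)])
# Under the assumption above, the default every_n_steps=-1 never fires.
print([s for s in range(1, 7) if is_active_step(s, every_n_steps=-1)])
```

With the default every_n_epochs=1, every epoch qualifies, so the filtered and unfiltered epoch hooks would coincide.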
- create(trainer, data_module=None)
Creates the callback object.
- Parameters:
trainer (Any) – Trainer object.
data_module (optional) – Data module object.
- Returns:
Callback object.
- Return type:
Callback
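The config-creates-callback pattern can be sketched with plain dataclasses. PrintLossConfig and PrintLossCallback are hypothetical, simplified stand-ins (not part of xlstm_jax), assuming that a concrete config subclass overrides create to build its matching callback and passes itself in as config.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class PrintLossConfig:
    """Hypothetical stand-in for a CallbackConfig subclass."""

    every_n_epochs: int = 1
    every_n_steps: int = -1
    prefix: str = "loss"  # illustrative extra option

    def create(self, trainer: Any, data_module: Any = None) -> "PrintLossCallback":
        # The config builds its own callback and hands itself over as `config`.
        return PrintLossCallback(config=self, trainer=trainer, data_module=data_module)


class PrintLossCallback:
    """Hypothetical stand-in for a Callback subclass."""

    def __init__(self, config: PrintLossConfig, trainer: Any, data_module: Any = None):
        self.config = config
        self.trainer = trainer
        self.data_module = data_module


config = PrintLossConfig(every_n_epochs=2)
callback = config.create(trainer="dummy-trainer")
print(type(callback).__name__, callback.config.every_n_epochs)
```

One benefit of this design is that a trainer can hold a list of lightweight configs and instantiate all callbacks only once the trainer itself exists.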
- to_dict()
Converts the config to a dictionary.
Helpful for saving to disk or logging.
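As a rough illustration of what such a conversion involves, the sketch below turns a nested dataclass config into a plain dictionary with dataclasses.asdict. It is an assumption-based stand-in, not the library's to_dict; the config classes are hypothetical, and the library version is described as not cleanly invertible, which suggests it does more than this.

```python
import json
from dataclasses import asdict, dataclass, field


@dataclass
class OptimizerConfig:  # hypothetical nested config
    lr: float = 1e-3


@dataclass
class TrainerConfig:  # hypothetical top-level config
    epochs: int = 10
    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)


def to_dict(config) -> dict:
    # asdict recurses into nested dataclasses, lists, tuples and dicts.
    return asdict(config)


cfg_dict = to_dict(TrainerConfig(epochs=3))
print(json.dumps(cfg_dict))  # plain JSON, e.g. for saving to disk or logging
```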
- static from_dict(config_class, data, strict_classname_parsing=False, ignore_extensive_attributes=True, none_to_zero_for_ints=False)
Utility for parsing dictionaries back into a nested dataclass structure, including arbitrary classes and types.
Currently, it is tailored to the existing logging system, whose to_dict output is only partially invertible.
- Parameters:
config_class (type) – Typically a dataclass, but can be any other type as well. If it is another type, the parser tries to create an object via config_class(**data) if data is a dictionary, or via config_class(data) otherwise.
data (Any) – Typically a dictionary that contains attributes of the dataclass. Can be any other kind of data.
strict_classname_parsing (bool) – Parse class names strictly.
ignore_extensive_attributes (bool) – Ignore attributes that are not defined in the dataclass.
none_to_zero_for_ints (bool) – Convert None to 0 for integer types.
- Returns:
An object of type config_class that contains the data as attributes.
- Return type:
Any
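The core idea of parsing a dictionary back into nested dataclasses can be sketched in a few lines. This simplified from_dict is an illustration under stated assumptions, not the xlstm_jax implementation: it handles only nested dataclasses and plain values, mirrors only the ignore_extensive_attributes and none_to_zero_for_ints options, and omits class-name parsing entirely. The config classes are hypothetical.

```python
import dataclasses
from typing import Any, get_type_hints


def from_dict(config_class: type, data: Any,
              ignore_extensive_attributes: bool = True,
              none_to_zero_for_ints: bool = False) -> Any:
    """Simplified sketch: rebuild a (nested) dataclass from a dictionary."""
    if not dataclasses.is_dataclass(config_class):
        # Non-dataclass types: keyword arguments for dicts, a single argument otherwise.
        return config_class(**data) if isinstance(data, dict) else config_class(data)
    hints = get_type_hints(config_class)
    field_names = {f.name for f in dataclasses.fields(config_class)}
    kwargs = {}
    for key, value in data.items():
        if key not in field_names:
            if ignore_extensive_attributes:
                continue  # drop keys the dataclass does not define
            raise ValueError(f"Unknown attribute {key!r} for {config_class.__name__}")
        field_type = hints[key]
        if value is None and field_type is int and none_to_zero_for_ints:
            value = 0
        elif dataclasses.is_dataclass(field_type) and isinstance(value, dict):
            value = from_dict(field_type, value, ignore_extensive_attributes,
                              none_to_zero_for_ints)
        kwargs[key] = value
    return config_class(**kwargs)


@dataclasses.dataclass
class OptimizerConfig:  # hypothetical
    lr: float = 1e-3
    warmup_steps: int = 100


@dataclasses.dataclass
class TrainerConfig:  # hypothetical
    epochs: int = 10
    optimizer: OptimizerConfig = dataclasses.field(default_factory=OptimizerConfig)


raw = {"epochs": 3, "optimizer": {"lr": 0.01, "warmup_steps": None}, "run_name": "demo"}
cfg = from_dict(TrainerConfig, raw, none_to_zero_for_ints=True)
print(cfg)
```

Here the unknown run_name key is dropped because ignore_extensive_attributes defaults to True, and the None warm-up value becomes 0.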
- class xlstm_jax.trainer.callbacks.callback.Callback(config, trainer, data_module=None)
Base class for callbacks.
Callbacks are used to perform additional actions during training, validation, and testing. We provide a set of predefined hook functions that subclasses can override to implement custom behavior. These functions are called at the beginning and end of training, validation, and testing, at the beginning and end of each epoch, and at the end of each training step.
Note: epoch and step indices start at 1 (i.e., the first epoch has index 1, not 0).
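To make the call order and the 1-based indexing concrete, the following sketch drives a recording callback through a toy loop. The loop is a hypothetical simplification, not the xlstm_jax trainer: running validation after every epoch and counting step_idx globally across epochs are assumptions. Only the hook names and their signatures come from this page.

```python
from typing import Any, List, Tuple


class RecordingCallback:
    """Records which hooks are called, using the hook names documented for Callback."""

    def __init__(self) -> None:
        self.calls: List[Tuple[Any, ...]] = []

    def on_training_start(self):
        self.calls.append(("training_start",))

    def on_training_epoch_start(self, epoch_idx):
        self.calls.append(("epoch_start", epoch_idx))

    def on_training_step(self, step_metrics, epoch_idx, step_idx):
        self.calls.append(("step", epoch_idx, step_idx))

    def on_training_epoch_end(self, train_metrics, epoch_idx):
        self.calls.append(("epoch_end", epoch_idx))

    def on_validation_epoch_start(self, epoch_idx, step_idx):
        self.calls.append(("val_start", epoch_idx, step_idx))

    def on_validation_epoch_end(self, eval_metrics, epoch_idx, step_idx):
        self.calls.append(("val_end", epoch_idx, step_idx))

    def on_training_end(self):
        self.calls.append(("training_end",))


def toy_fit(callback, num_epochs: int, steps_per_epoch: int) -> None:
    """Hypothetical training loop; epoch and step indices start at 1."""
    callback.on_training_start()
    step_idx = 0
    for epoch_idx in range(1, num_epochs + 1):
        callback.on_training_epoch_start(epoch_idx)
        for _ in range(steps_per_epoch):
            step_idx += 1  # global step counter (assumption); the first step is 1
            callback.on_training_step({"loss": 0.0}, epoch_idx, step_idx)
        callback.on_training_epoch_end({"loss": 0.0}, epoch_idx)
        callback.on_validation_epoch_start(epoch_idx, step_idx)
        callback.on_validation_epoch_end({"loss": 0.0}, epoch_idx, step_idx)
    callback.on_training_end()


cb = RecordingCallback()
toy_fit(cb, num_epochs=2, steps_per_epoch=2)
for call in cb.calls:
    print(call)
```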
- Parameters:
config (CallbackConfig) – Configuration dictionary.
trainer (Any) – Trainer object.
data_module (optional) – Data module object.
- config
- trainer
- data_module = None
- _every_n_epochs
- _every_n_steps
- _main_process_only
- _active_on_epochs
- _active_on_steps
- on_training_start()
Called at the beginning of training.
- on_training_end()
Called at the end of training.
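A typical use of the training-level hooks is measuring wall-clock time. TimerCallback is hypothetical; in xlstm_jax it would subclass Callback, but it is written as a plain class here so the sketch runs without the library installed.

```python
import time
from typing import Optional


class TimerCallback:
    """Hypothetical callback measuring total training time.

    In xlstm_jax this would subclass Callback; a plain class keeps the sketch self-contained.
    """

    def __init__(self) -> None:
        self.start_time: Optional[float] = None
        self.elapsed: Optional[float] = None

    def on_training_start(self) -> None:
        self.start_time = time.perf_counter()

    def on_training_end(self) -> None:
        assert self.start_time is not None, "on_training_start was not called"
        self.elapsed = time.perf_counter() - self.start_time
        print(f"Training took {self.elapsed:.3f} s")


timer = TimerCallback()
timer.on_training_start()
time.sleep(0.01)  # stand-in for the actual training loop
timer.on_training_end()
```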
- on_training_epoch_start(epoch_idx)
Called at the beginning of each training epoch.
- Parameters:
epoch_idx (int) – Index of the current epoch.
- on_filtered_training_epoch_start(epoch_idx)
Called at the beginning of every n-th training epoch, where n is every_n_epochs. To be implemented by subclasses.
- Parameters:
epoch_idx (int) – Index of the current epoch.
- on_training_epoch_end(train_metrics, epoch_idx)
Called at the end of each training epoch.
- Parameters:
train_metrics (xlstm_jax.common_types.Metrics) – Dictionary of training metrics of the current epoch.
epoch_idx (int) – Index of the current epoch.
- on_filtered_training_epoch_end(train_metrics, epoch_idx)
Called at the end of every n-th training epoch, where n is every_n_epochs. To be implemented by subclasses.
- Parameters:
train_metrics (xlstm_jax.common_types.Metrics) – Dictionary of training metrics of the current epoch.
epoch_idx (int) – Index of the current epoch.
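The relationship between the unfiltered and filtered epoch hooks can be sketched as follows. The dispatch in on_training_epoch_end is an assumption about how the base class forwards calls (a modulo check on every_n_epochs), SketchCallback and MetricsHistory are hypothetical, and metrics are modeled as a plain dict standing in for xlstm_jax.common_types.Metrics.

```python
from typing import Any, Dict


class SketchCallback:
    """Simplified stand-in for Callback's epoch-end dispatch (assumed behavior)."""

    def __init__(self, every_n_epochs: int = 1) -> None:
        self._every_n_epochs = every_n_epochs

    def on_training_epoch_end(self, train_metrics: Dict[str, Any], epoch_idx: int) -> None:
        # Forward to the filtered hook only on every n-th (1-based) epoch.
        if self._every_n_epochs > 0 and epoch_idx % self._every_n_epochs == 0:
            self.on_filtered_training_epoch_end(train_metrics, epoch_idx)

    def on_filtered_training_epoch_end(self, train_metrics: Dict[str, Any], epoch_idx: int) -> None:
        pass  # to be implemented by subclasses


class MetricsHistory(SketchCallback):
    """Hypothetical subclass that stores the loss of every n-th epoch."""

    def __init__(self, every_n_epochs: int = 1) -> None:
        super().__init__(every_n_epochs)
        self.history: Dict[int, float] = {}

    def on_filtered_training_epoch_end(self, train_metrics, epoch_idx):
        self.history[epoch_idx] = float(train_metrics["loss"])


cb = MetricsHistory(every_n_epochs=2)
for epoch_idx, loss in enumerate([0.9, 0.7, 0.6, 0.5], start=1):
    cb.on_training_epoch_end({"loss": loss}, epoch_idx)
print(cb.history)
```

Overriding only the filtered hook keeps the interval logic in one place, so a subclass does not need to repeat the every_n_epochs check.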
- on_training_step(step_metrics, epoch_idx, step_idx)
Called at the end of each training step.
- on_filtered_training_step(step_metrics, epoch_idx, step_idx)
Called at the end of every n-th training step, where n is every_n_steps. To be implemented by subclasses.
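For the per-step hooks, a common pattern is to log aggregated values instead of every single step. The hypothetical RunningLossCallback below accumulates the loss in on_training_step, which runs on every step, and reports the mean in on_filtered_training_step. The dispatch between the two is written out inline under the same modulo assumption as above, and step_metrics is modeled as a plain dict with a "loss" key.

```python
from typing import Any, Dict, List, Tuple


class RunningLossCallback:
    """Hypothetical callback: averages the loss over every n training steps."""

    def __init__(self, every_n_steps: int) -> None:
        self._every_n_steps = every_n_steps
        self._buffer: List[float] = []
        self.logged: List[Tuple[int, float]] = []

    def on_training_step(self, step_metrics: Dict[str, Any], epoch_idx: int, step_idx: int) -> None:
        # Runs on every step: collect the loss.
        self._buffer.append(float(step_metrics["loss"]))
        # Assumed dispatch: forward to the filtered hook on every n-th step.
        if self._every_n_steps > 0 and step_idx % self._every_n_steps == 0:
            self.on_filtered_training_step(step_metrics, epoch_idx, step_idx)

    def on_filtered_training_step(self, step_metrics, epoch_idx, step_idx) -> None:
        # Report the mean over the buffered steps, then start a new window.
        mean_loss = sum(self._buffer) / len(self._buffer)
        self.logged.append((step_idx, mean_loss))
        self._buffer.clear()


cb = RunningLossCallback(every_n_steps=2)
for step_idx, loss in enumerate([4.0, 2.0, 3.0, 1.0], start=1):
    cb.on_training_step({"loss": loss}, epoch_idx=1, step_idx=step_idx)
print(cb.logged)
```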
- on_validation_epoch_start(epoch_idx, step_idx)
Called at the beginning of validation.
- on_filtered_validation_epoch_start(epoch_idx, step_idx)
Called at the beginning of validation, but only on every n-th epoch, where n is every_n_epochs. To be implemented by subclasses.
- on_validation_epoch_end(eval_metrics, epoch_idx, step_idx)
Called at the end of each validation epoch.
- on_filtered_validation_epoch_end(eval_metrics, epoch_idx, step_idx)
Called at the end of validation, but only on every n-th epoch, where n is every_n_epochs. To be implemented by subclasses.
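The filtered validation hook is a natural place for bookkeeping such as tracking the best validation score. BestMetricTracker is hypothetical, the "loss" key in eval_metrics is an assumption, and the hook is called directly here rather than by a trainer.

```python
import math
from typing import Any, Dict, Optional


class BestMetricTracker:
    """Hypothetical callback: remembers the lowest validation loss and when it occurred."""

    def __init__(self, metric_key: str = "loss") -> None:
        self.metric_key = metric_key
        self.best_value = math.inf
        self.best_epoch: Optional[int] = None
        self.best_step: Optional[int] = None

    def on_filtered_validation_epoch_end(
        self, eval_metrics: Dict[str, Any], epoch_idx: int, step_idx: int
    ) -> None:
        value = float(eval_metrics[self.metric_key])
        if value < self.best_value:  # lower is better for a loss
            self.best_value, self.best_epoch, self.best_step = value, epoch_idx, step_idx


tracker = BestMetricTracker()
for epoch_idx, val_loss in enumerate([1.2, 0.8, 0.9], start=1):
    tracker.on_filtered_validation_epoch_end({"loss": val_loss}, epoch_idx, step_idx=100 * epoch_idx)
print(tracker.best_value, tracker.best_epoch, tracker.best_step)
```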
- on_test_epoch_start(epoch_idx)
Called at the beginning of testing.
To be implemented by subclasses.
- Parameters:
epoch_idx (int) – Index of the current epoch.
- on_test_epoch_end(test_metrics, epoch_idx)
Called at the end of each test epoch. To be implemented by subclasses.
- Parameters:
test_metrics (xlstm_jax.common_types.Metrics) – Dictionary of test metrics of the current epoch.
epoch_idx (int) – Index of the current epoch.
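Since this page lists no filtered variants of the test hooks, a test-time callback overrides on_test_epoch_start and on_test_epoch_end directly. The hypothetical FinalReportCallback below collects the test metrics into a JSON-serializable report; the metric names are placeholders.

```python
import json
from typing import Any, Dict, Optional


class FinalReportCallback:
    """Hypothetical callback that builds a report from the test metrics."""

    def __init__(self) -> None:
        self.report: Optional[Dict[str, Any]] = None

    def on_test_epoch_start(self, epoch_idx: int) -> None:
        self.report = {"epoch": epoch_idx}

    def on_test_epoch_end(self, test_metrics: Dict[str, Any], epoch_idx: int) -> None:
        assert self.report is not None, "on_test_epoch_start was not called"
        # Convert values to plain floats so the report is JSON-serializable.
        self.report["metrics"] = {k: float(v) for k, v in test_metrics.items()}

    def to_json(self) -> str:
        return json.dumps(self.report, sort_keys=True)


cb = FinalReportCallback()
cb.on_test_epoch_start(epoch_idx=3)
cb.on_test_epoch_end({"loss": 0.5, "accuracy": 0.75}, epoch_idx=3)
print(cb.to_json())
```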
- set_dataset(data_module)
Sets the data module.
- Parameters:
data_module (xlstm_jax.trainer.data_module.DataloaderModule) – Data module object.