xlstm_jax.trainer.logger.base_logger
Attributes
Classes
LoggerConfig | Configuration for the logger.
LoggerToolsConfig | Base config class for logger tools.
Logger | Logger class to log metrics, images, etc.
LoggerTool | Base class for logger tools.
Module Contents
- xlstm_jax.trainer.logger.base_logger.LOGGER
- class xlstm_jax.trainer.logger.base_logger.LoggerConfig
Bases:
xlstm_jax.configs.ConfigDict
Configuration for the logger.
- log_every_n_steps
The frequency, in steps, at which logs are written.
- log_path
The path where the logs are written. If None, no logs are written to disk.
- log_tools
A list of LoggerToolsConfig objects specifying the tools used to log the metrics. These tools are created in the Logger class.
- cmd_logging_name
The name of the output file for command-line logging, without a suffix. The suffix .log is added automatically.
- log_path: pathlib.Path | None = None
- log_tools: list[LoggerToolsConfig] = []
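The fields above can be illustrated with a plain `dataclasses.dataclass` standing in for `xlstm_jax.configs.ConfigDict`. This is a minimal sketch under assumptions: the class name `LoggerConfigSketch`, its default values, and the helper `cmd_log_file` are hypothetical and not part of xlstm_jax. It only shows how `log_path` and `cmd_logging_name` could combine into a log file path.

```python
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional


@dataclass
class LoggerConfigSketch:
    """Hypothetical stand-in for LoggerConfig; defaults are illustrative only."""

    log_every_n_steps: int = 1
    log_path: Optional[Path] = None  # None: nothing is written to disk
    log_tools: list = field(default_factory=list)  # tool configs, instantiated later by the logger
    cmd_logging_name: str = "output"  # hypothetical default

    def cmd_log_file(self) -> Optional[Path]:
        """Hypothetical helper: the command-line log file, or None without a log_path."""
        if self.log_path is None:
            return None
        # The ".log" suffix is appended to cmd_logging_name automatically.
        return self.log_path / f"{self.cmd_logging_name}.log"


config = LoggerConfigSketch(log_every_n_steps=50, log_path=Path("/tmp/run"), cmd_logging_name="train_output")
print(config.cmd_log_file())  # /tmp/run/train_output.log on POSIX systems
```

In a dataclass, a mutable default such as a list has to be supplied via `field(default_factory=list)`; a bare `= []` default raises a `ValueError` when the class is created.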
- to_dict()
Converts the config to a dictionary.
Helpful for saving to disk or logging.
- static from_dict(config_class, data, strict_classname_parsing=False, ignore_extensive_attributes=True, none_to_zero_for_ints=False)
Utility for parsing dictionaries back into a nested dataclass structure, including arbitrary classes and types.
Currently, this is tailored to the current logging system, whose to_dict output is hard to invert.
- Parameters:
config_class (type) – Typically a dataclass, but can be any other type as well. If it is another type, the parser tries to create an object via config_class(**data) if data is a dictionary, or via config_class(data) otherwise.
data (Any) – Typically a dictionary that contains attributes of the dataclass. Can be any other kind of data.
strict_classname_parsing (bool) – Parse class names strictly.
ignore_extensive_attributes (bool) – Ignore attributes that are not defined in the dataclass.
none_to_zero_for_ints (bool) – Convert None to 0 for integer types.
- Returns:
An object of type config_class that contains the data as attributes.
- Return type:
Any
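The round trip through `to_dict` and `from_dict` can be illustrated with a simplified re-implementation. This is a hypothetical sketch, not the xlstm_jax code: `from_dict_sketch`, `TrainConfigSketch`, and `ToolEntrySketch` are illustrative names, `dataclasses.asdict` stands in for `to_dict`, and `strict_classname_parsing` (together with any class-name strings the real `to_dict` may emit) is not modeled.

```python
import dataclasses
from typing import Any


def from_dict_sketch(config_class: type, data: Any, ignore_extensive_attributes: bool = True,
                     none_to_zero_for_ints: bool = False) -> Any:
    """Hypothetical, simplified parser from (nested) dicts back into dataclasses."""
    if dataclasses.is_dataclass(config_class) and isinstance(data, dict):
        fields = {f.name: f.type for f in dataclasses.fields(config_class)}
        extra = set(data) - set(fields)
        if extra and not ignore_extensive_attributes:
            raise ValueError(f"Unknown attributes for {config_class.__name__}: {sorted(extra)}")
        kwargs = {}
        for name, field_type in fields.items():
            if name not in data:
                continue  # keep the dataclass default
            value = data[name]
            if value is None and field_type is int and none_to_zero_for_ints:
                value = 0
            elif dataclasses.is_dataclass(field_type):
                # Recurse into nested configs.
                value = from_dict_sketch(field_type, value, ignore_extensive_attributes, none_to_zero_for_ints)
            kwargs[name] = value
        return config_class(**kwargs)
    # Any other type: call it with keyword arguments for dicts, else with the raw value.
    return config_class(**data) if isinstance(data, dict) else config_class(data)


@dataclasses.dataclass
class ToolEntrySketch:
    name: str = "file"


@dataclasses.dataclass
class TrainConfigSketch:
    log_every_n_steps: int = 1
    tool: ToolEntrySketch = dataclasses.field(default_factory=ToolEntrySketch)


data = dataclasses.asdict(TrainConfigSketch(log_every_n_steps=10, tool=ToolEntrySketch("tensorboard")))
data["unused"] = 1  # extra key: dropped because ignore_extensive_attributes=True
restored = from_dict_sketch(TrainConfigSketch, data)
print(restored)  # TrainConfigSketch(log_every_n_steps=10, tool=ToolEntrySketch(name='tensorboard'))
```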
- class xlstm_jax.trainer.logger.base_logger.LoggerToolsConfig
Bases:
xlstm_jax.configs.ConfigDict
Base config class for logger tools.
These are tools that can be used to log metrics, images, etc. They are created inside the Logger class.
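Because the tools are created inside the `Logger`, a tool config can act as a small factory. The sketch below assumes a hypothetical `create()` method; `MemoryToolConfigSketch`, `MemoryToolSketch`, and `ToolOwningLoggerSketch` are illustrative names, and the real `LoggerToolsConfig` subclasses may construct their tools differently.

```python
from dataclasses import dataclass


class MemoryToolSketch:
    """Hypothetical logger tool that keeps logged metrics in memory."""

    def __init__(self, config):
        self.config = config
        self.logged = []


@dataclass
class MemoryToolConfigSketch:
    """Hypothetical LoggerToolsConfig subclass; `create` is an assumed factory method."""

    name: str = "memory"

    def create(self):
        return MemoryToolSketch(self)


class ToolOwningLoggerSketch:
    """Hypothetical logger that instantiates one tool per tool config."""

    def __init__(self, log_tools):
        self.tools = [tool_config.create() for tool_config in log_tools]


tool_logger = ToolOwningLoggerSketch([MemoryToolConfigSketch(), MemoryToolConfigSketch(name="debug")])
print([tool.config.name for tool in tool_logger.tools])  # ['memory', 'debug']
```

One plausible motivation for this split: the configs stay plain data that can be serialized with `to_dict`, while objects holding open files or connections exist only at runtime.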
- to_dict()
Converts the config to a dictionary.
Helpful for saving to disk or logging.
- static from_dict(config_class, data, strict_classname_parsing=False, ignore_extensive_attributes=True, none_to_zero_for_ints=False)
Utility for parsing dictionaries back into a nested dataclass structure, including arbitrary classes and types.
Currently, this is tailored to the current logging system, whose to_dict output is hard to invert.
- Parameters:
config_class (type) – Typically a dataclass, but can be any other type as well. If it is another type, the parser tries to create an object via config_class(**data) if data is a dictionary, or via config_class(data) otherwise.
data (Any) – Typically a dictionary that contains attributes of the dataclass. Can be any other kind of data.
strict_classname_parsing (bool) – Parse class names strictly.
ignore_extensive_attributes (bool) – Ignore attributes that are not defined in the dataclass.
none_to_zero_for_ints (bool) – Convert None to 0 for integer types.
- Returns:
An object of type config_class that contains the data as attributes.
- Return type:
Any
- class xlstm_jax.trainer.logger.base_logger.Logger(config, metric_postprocess_fn=None)
Logger class to log metrics, images, etc.
- Parameters:
config (LoggerConfig)
metric_postprocess_fn (collections.abc.Callable[[xlstm_jax.common_types.HostMetrics], xlstm_jax.common_types.HostMetrics] | None)
- config
- log_path
- metric_postprocess_fn = None
- epoch = 0
- step = 0
- found_nans = False
- last_step
- last_step_time = None
- epoch_start_time_stack = []
- mode_stack = []
- property mode: Literal['default', 'train', 'val', 'test']
Current logging mode. Can be “default”, “train”, “val”, or “test”.
- Returns:
The current logging mode.
- Return type:
Literal['default', 'train', 'val', 'test']
- log_config(config)
Logs the configuration.
- Parameters:
config (xlstm_jax.configs.ConfigDict | dict[str, xlstm_jax.configs.ConfigDict]) – The configuration to log. Can also be a dictionary of multiple configurations.
- on_training_start()
Set up the logger for training.
- start_epoch(epoch, step, mode='train')
Starts a new epoch.
Call this before starting a new training, validation, or test epoch. It can also be called while another epoch is still running. For instance, if a training epoch is interrupted by a validation epoch, the logger switches to validation mode until end_epoch is called, and then switches back to training mode.
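The nested-epoch behavior described above amounts to a stack. A minimal sketch, assuming that the innermost open epoch defines the mode, that the mode falls back to "default" when no epoch is open, and that `epoch_start_time_stack` is used to time epochs. `EpochStackSketch` is hypothetical and omits all metric handling.

```python
import time


class EpochStackSketch:
    """Hypothetical sketch of nested epochs; the innermost open epoch defines the mode."""

    def __init__(self):
        self.epoch = 0
        self.step = 0
        self.mode_stack = []
        self.epoch_start_time_stack = []

    @property
    def mode(self):
        # Outside of any epoch, assume the logger is in "default" mode.
        return self.mode_stack[-1] if self.mode_stack else "default"

    def start_epoch(self, epoch, step, mode="train"):
        self.epoch, self.step = epoch, step
        self.mode_stack.append(mode)
        self.epoch_start_time_stack.append(time.perf_counter())

    def end_epoch(self, step):
        # Closing the innermost epoch restores the mode of the enclosing one.
        self.step = step
        mode = self.mode_stack.pop()
        epoch_time = time.perf_counter() - self.epoch_start_time_stack.pop()
        return mode, epoch_time


epoch_logger = EpochStackSketch()
epoch_logger.start_epoch(epoch=1, step=0, mode="train")
epoch_logger.start_epoch(epoch=1, step=500, mode="val")  # validation interrupts training
print(epoch_logger.mode)  # val
epoch_logger.end_epoch(step=500)
print(epoch_logger.mode)  # train
epoch_logger.end_epoch(step=1000)
print(epoch_logger.mode)  # default
```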
- log_step(metrics, step)
Log metrics for a single step.
- Parameters:
metrics (xlstm_jax.common_types.Metrics) – The metrics to log. Should follow the structure of the metrics in the metrics.py file.
step (int) – The current step.
- Returns:
If the metrics are logged in this step, the returned metrics are reset. Otherwise, the metrics are returned unchanged.
- Return type:
xlstm_jax.common_types.Metrics
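The reset-on-log contract of `log_step` can be sketched with plain Python floats. This is a hypothetical illustration: `update_metrics`, `StepLoggerSketch`, and the `(sum, count)` accumulator format are assumptions rather than the structure defined in metrics.py, and the sketch logs only when `step` is a multiple of `log_every_n_steps`.

```python
def update_metrics(metrics, new_values):
    """Hypothetical accumulator: each metric is stored as a (sum, count) pair."""
    for name, value in new_values.items():
        total, count = metrics.get(name, (0.0, 0))
        metrics[name] = (total + value, count + 1)
    return metrics


class StepLoggerSketch:
    """Hypothetical logger reduced to the log_step contract."""

    def __init__(self, log_every_n_steps):
        self.log_every_n_steps = log_every_n_steps
        self.history = []  # (step, averaged metrics) pairs that were "written"

    def log_step(self, metrics, step):
        if step % self.log_every_n_steps != 0:
            return metrics  # not a logging step: metrics are returned unchanged
        averaged = {name: total / count for name, (total, count) in metrics.items()}
        self.history.append((step, averaged))
        return {}  # logging step: the accumulators are reset


step_logger = StepLoggerSketch(log_every_n_steps=2)
metrics = {}
for step, loss in enumerate([4.0, 2.0, 1.0, 3.0], start=1):
    metrics = update_metrics(metrics, {"loss": loss})
    metrics = step_logger.log_step(metrics, step)
print(step_logger.history)  # [(2, {'loss': 3.0}), (4, {'loss': 2.0})]
```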
- _check_for_nans(host_metrics, step=None)
Check whether any of the metrics contain NaNs.
If NaNs are found, a warning is logged and the found_nans attribute is set to True.
- Parameters:
host_metrics (xlstm_jax.common_types.HostMetrics) – The metrics to check.
step (int | None) – The step at which the metrics were logged. Used for logging if provided.
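A NaN check over host metrics could look like the following sketch. It assumes the host metrics are a flat dict of Python floats and that the module-level `LOGGER` is a standard `logging.Logger`; `NanCheckSketch` is hypothetical. As in the description above, `found_nans` is only ever set to True here, never reset.

```python
import logging
import math

LOGGER = logging.getLogger(__name__)


class NanCheckSketch:
    """Hypothetical holder of the found_nans flag."""

    def __init__(self):
        self.found_nans = False

    def _check_for_nans(self, host_metrics, step=None):
        # Only floats can be NaN; other values (ints, strings) are skipped.
        nan_keys = sorted(key for key, value in host_metrics.items()
                          if isinstance(value, float) and math.isnan(value))
        if nan_keys:
            location = f" at step {step}" if step is not None else ""
            LOGGER.warning("Found NaNs in metrics %s%s.", nan_keys, location)
            self.found_nans = True


nan_checker = NanCheckSketch()
nan_checker._check_for_nans({"loss": 1.0, "grad_norm": float("nan")}, step=10)
print(nan_checker.found_nans)  # True
```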
- log_host_metrics(host_metrics, step, mode=None)
Logs a dictionary of metrics on the host.
Can be used by callbacks to log additional metrics.
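Since callbacks may pass `mode=None`, a natural reading is that the logger then falls back to its current mode and forwards the metrics to every tool. A hypothetical sketch of that fan-out: `RecordingToolSketch` and `HostLoggerSketch` are illustrative names, and the fallback for `mode=None` is an assumption.

```python
class RecordingToolSketch:
    """Hypothetical tool that records every call to log_metrics."""

    def __init__(self):
        self.records = []

    def log_metrics(self, metrics, step, epoch, mode):
        self.records.append((mode, epoch, step, dict(metrics)))


class HostLoggerSketch:
    """Hypothetical logger reduced to mode tracking and tool dispatch."""

    def __init__(self, tools):
        self.tools = tools
        self.epoch = 0
        self.mode_stack = ["train"]  # pretend a training epoch is running

    @property
    def mode(self):
        return self.mode_stack[-1] if self.mode_stack else "default"

    def log_host_metrics(self, host_metrics, step, mode=None):
        # Assumption: mode=None means "use the current logging mode".
        mode = self.mode if mode is None else mode
        for tool in self.tools:
            tool.log_metrics(host_metrics, step=step, epoch=self.epoch, mode=mode)


recorder = RecordingToolSketch()
host_logger = HostLoggerSketch([recorder])
host_logger.log_host_metrics({"lr": 1e-3}, step=5)
host_logger.log_host_metrics({"val_loss": 0.7}, step=5, mode="val")
print(recorder.records)  # [('train', 0, 5, {'lr': 0.001}), ('val', 0, 5, {'val_loss': 0.7})]
```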
- end_epoch(metrics, step)
Ends the current epoch and logs the epoch metrics.
If any other epoch is still running, the logger will switch back to that epoch.
- Parameters:
metrics (xlstm_jax.common_types.Metrics) – The metrics that should be logged in this epoch.
step (int) – The current step.
- Returns:
The originally passed metrics and, potentially, additional metrics to pass to callbacks later on. Note that the metrics are not reset.
- Return type:
tuple[xlstm_jax.common_types.Metrics, xlstm_jax.common_types.HostMetrics]
- class xlstm_jax.trainer.logger.base_logger.LoggerTool
Base class for logger tools.
- log_config(config)
Log the configuration to the tool.
- Parameters:
config (xlstm_jax.configs.ConfigDict | dict[str, xlstm_jax.configs.ConfigDict]) – The configuration to log.
- abstract log_metrics(metrics, step, epoch, mode)
Log the metrics to the tool.
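A tool only has to implement `log_metrics`; `log_config` is not abstract. The sketch below mirrors that contract with `abc.ABC`, assuming the inherited `log_config` is a no-op by default. `LoggerToolSketch` and `StdoutToolSketch` are hypothetical names, not xlstm_jax classes.

```python
from abc import ABC, abstractmethod


class LoggerToolSketch(ABC):
    """Hypothetical mirror of LoggerTool: log_metrics is required, log_config optional."""

    def log_config(self, config):
        """Assumed default: tools that cannot store configurations ignore them."""

    @abstractmethod
    def log_metrics(self, metrics, step, epoch, mode):
        """Write the metrics to the tool's backend."""


class StdoutToolSketch(LoggerToolSketch):
    """Hypothetical tool that prints one line per log call."""

    @staticmethod
    def format_line(metrics, step, epoch, mode):
        values = ", ".join(f"{name}={value:.4f}" for name, value in sorted(metrics.items()))
        return f"[{mode}] epoch {epoch} step {step}: {values}"

    def log_metrics(self, metrics, step, epoch, mode):
        print(self.format_line(metrics, step, epoch, mode))


StdoutToolSketch().log_metrics({"loss": 0.5, "accuracy": 0.9}, step=100, epoch=1, mode="train")
# prints: [train] epoch 1 step 100: accuracy=0.9000, loss=0.5000
```

Instantiating `LoggerToolSketch` directly raises `TypeError` because `log_metrics` is abstract.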