xlstm_jax.trainer.logger.base_logger

Attributes

Classes

LoggerConfig

Configuration for the logger.

LoggerToolsConfig

Base config class for logger tools.

Logger

Logger class to log metrics, images, etc.

LoggerTool

Base class for logger tools.

Module Contents

xlstm_jax.trainer.logger.base_logger.LOGGER
class xlstm_jax.trainer.logger.base_logger.LoggerConfig

Bases: xlstm_jax.configs.ConfigDict

Configuration for the logger.

log_every_n_steps

The interval, in steps, at which logs are written.

log_path

The path where the logs should be written. If None, logs are not written to disk.

log_tools

A list of LoggerToolsConfig objects that specify the tools used to log the metrics. The tools are created by the Logger class.

cmd_logging_name

The name of the command-line logging output file, without suffix. The suffix .log is added automatically.

log_every_n_steps: int = 1
log_path: pathlib.Path | None = None
log_tools: list[LoggerToolsConfig] = []
cmd_logging_name: str = 'output'
property log_dir: str

Returns the log directory as a string.

Return type:

str
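To see how these fields and the log_dir property fit together, they can be mirrored in a plain dataclass. This is a minimal sketch, not the xlstm_jax class: LoggerConfigSketch and its cmd_log_file property are hypothetical, log_dir is assumed to be the string form of log_path, and placing the command-line log at log_path / f"{cmd_logging_name}.log" is an assumption based on the field descriptions above.

```python
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional


@dataclass
class LoggerConfigSketch:
    """Hypothetical stand-in mirroring the documented LoggerConfig fields."""

    log_every_n_steps: int = 1
    log_path: Optional[Path] = None
    # Would hold LoggerToolsConfig objects; default_factory avoids a shared mutable default.
    log_tools: list = field(default_factory=list)
    cmd_logging_name: str = "output"

    @property
    def log_dir(self) -> str:
        # Assumption: the string form of log_path.
        return str(self.log_path)

    @property
    def cmd_log_file(self) -> Optional[Path]:
        # Assumption: the command-line log is written inside log_path with a ".log" suffix.
        if self.log_path is None:
            return None  # no log_path means nothing is written to disk
        return self.log_path / f"{self.cmd_logging_name}.log"


config = LoggerConfigSketch(log_every_n_steps=50, log_path=Path("/tmp/run1"))
print(config.log_dir)
print(config.cmd_log_file)
```

A plain dataclass rejects a literal [] default with a ValueError, which is why the sketch uses default_factory for log_tools.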

get(key, default=None)
Parameters:

key (str)

to_dict()

Converts the config to a dictionary.

Helpful for saving to disk or logging.

static from_dict(config_class, data, strict_classname_parsing=False, ignore_extensive_attributes=True, none_to_zero_for_ints=False)

Utility for parsing dictionaries back into a nested dataclass structure, including arbitrary classes and types.

Currently, this is tailored to the current logging system, whose to_dict output is hard to invert.

Parameters:
  • config_class (type) – Typically a dataclass, but can be any other type as well. If it is another type, the parser tries to create an object via config_class(**data) if data is a dictionary, or config_class(data) otherwise.

  • data (Any) – Typically a dictionary that contains attributes of the dataclass. Can be any other kind of data.

  • strict_classname_parsing (bool) – Parse class names strictly.

  • ignore_extensive_attributes (bool) – Ignore attributes that are not defined in the dataclass.

  • none_to_zero_for_ints (bool) – Convert None to 0 for integer types.

Returns:

An object of type config_class that contains the data as attributes.

Return type:

Any
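The recursive parsing idea behind from_dict can be illustrated with a small standalone function. This is a hypothetical sketch (from_dict_sketch, ToolConfig, and RunConfig are illustrative names), not the library implementation: it covers nested dataclasses, the config_class(**data) / config_class(data) fallback, ignore_extensive_attributes, and none_to_zero_for_ints, and it omits strict_classname_parsing entirely.

```python
import dataclasses
from pathlib import Path
from typing import Any


def from_dict_sketch(
    config_class: type,
    data: Any,
    ignore_extensive_attributes: bool = True,
    none_to_zero_for_ints: bool = False,
) -> Any:
    """Hypothetical sketch of recursive dict -> dataclass parsing (not the library code)."""
    if dataclasses.is_dataclass(config_class) and isinstance(data, dict):
        fields = {f.name: f for f in dataclasses.fields(config_class)}
        unknown = set(data) - set(fields)
        if unknown and not ignore_extensive_attributes:
            raise ValueError(f"Unknown attributes for {config_class.__name__}: {sorted(unknown)}")
        kwargs = {}
        for name, f in fields.items():
            if name not in data:
                continue  # keep the dataclass default
            value, field_type = data[name], f.type
            if value is None:
                if none_to_zero_for_ints and field_type is int:
                    value = 0
            elif isinstance(field_type, type) and not isinstance(value, field_type):
                # Recurse: nested dataclasses from dicts, simple types such as Path from str.
                value = from_dict_sketch(field_type, value, ignore_extensive_attributes, none_to_zero_for_ints)
            kwargs[name] = value
        return config_class(**kwargs)
    # Non-dataclass types: mirror the documented fallback.
    return config_class(**data) if isinstance(data, dict) else config_class(data)


@dataclasses.dataclass
class ToolConfig:
    project: str = "demo"
    log_every: int = 1


@dataclasses.dataclass
class RunConfig:
    name: str = "run"
    log_path: Path = Path("logs")
    tool: ToolConfig = dataclasses.field(default_factory=ToolConfig)


parsed = from_dict_sketch(
    RunConfig,
    {"name": "exp1", "log_path": "/tmp/exp1", "tool": {"log_every": None}, "extra": 1},
    none_to_zero_for_ints=True,
)
print(parsed)
```

In the sketch, the string "/tmp/exp1" is turned back into a Path using only the field annotation, which illustrates the kind of type information a plain-dict round trip can lose.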

class xlstm_jax.trainer.logger.base_logger.LoggerToolsConfig

Bases: xlstm_jax.configs.ConfigDict

Base config class for logger tools.

These are tools that can be used to log metrics, images, etc. They are created inside the Logger class.

abstract create(logger)

Creates the logger tool.

Parameters:

logger (Logger)

Return type:

LoggerTool
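The config-creates-tool pattern can be sketched with abc and dataclasses. All class names here (ToolSketch, ToolConfigSketch, PrintTool, PrintToolConfig, LoggerSketch) are hypothetical stand-ins, and it is an assumption that the Logger builds its tools by calling create on every entry of log_tools.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


class ToolSketch(ABC):
    """Hypothetical stand-in for LoggerTool."""

    @abstractmethod
    def log_metrics(self, metrics: dict, step: int, epoch: int, mode: str) -> None: ...


class ToolConfigSketch(ABC):
    """Hypothetical stand-in for LoggerToolsConfig: each config knows how to build its tool."""

    @abstractmethod
    def create(self, logger: object) -> ToolSketch: ...


class PrintTool(ToolSketch):
    def __init__(self, config: "PrintToolConfig", logger: object):
        self.config = config
        self.logger = logger  # keep a handle to the owning logger

    def log_metrics(self, metrics, step, epoch, mode):
        print(f"{self.config.prefix}[{mode}] epoch={epoch} step={step} {metrics}")


@dataclass
class PrintToolConfig(ToolConfigSketch):
    prefix: str = ""

    def create(self, logger):
        return PrintTool(self, logger)


class LoggerSketch:
    def __init__(self, tool_configs: list):
        # The logger builds its tools from their configs and passes itself in.
        self.tools = [cfg.create(self) for cfg in tool_configs]


logger = LoggerSketch([PrintToolConfig(prefix="run1 ")])
logger.tools[0].log_metrics({"loss": 0.5}, step=10, epoch=0, mode="train")
# prints: run1 [train] epoch=0 step=10 {'loss': 0.5}
```

Keeping construction in the config means the logger never needs to know which concrete tool classes exist; it only iterates over its configs.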

get(key, default=None)
Parameters:

key (str)

to_dict()

Converts the config to a dictionary.

Helpful for saving to disk or logging.

static from_dict(config_class, data, strict_classname_parsing=False, ignore_extensive_attributes=True, none_to_zero_for_ints=False)

Utility for parsing dictionaries back into a nested dataclass structure, including arbitrary classes and types.

Currently, this is tailored to the current logging system, whose to_dict output is hard to invert.

Parameters:
  • config_class (type) – Typically a dataclass, but can be any other type as well. If it is another type, the parser tries to create an object via config_class(**data) if data is a dictionary, or config_class(data) otherwise.

  • data (Any) – Typically a dictionary that contains attributes of the dataclass. Can be any other kind of data.

  • strict_classname_parsing (bool) – Parse class names strictly.

  • ignore_extensive_attributes (bool) – Ignore attributes that are not defined in the dataclass.

  • none_to_zero_for_ints (bool) – Convert None to 0 for integer types.

Returns:

An object of type config_class that contains the data as attributes.

Return type:

Any

class xlstm_jax.trainer.logger.base_logger.Logger(config, metric_postprocess_fn=None)

Logger class to log metrics, images, etc.

Parameters:
config
log_path
metric_postprocess_fn = None
epoch = 0
step = 0
found_nans = False
last_step
last_step_time = None
epoch_start_time_stack = []
mode_stack = []
property mode: Literal['default', 'train', 'val', 'test']

Current logging mode. Can be “default”, “train”, “val”, or “test”.

Returns:

The current logging mode.

Return type:

str

log_config(config)

Logs the configuration.

Parameters:

config (xlstm_jax.configs.ConfigDict | dict[str, xlstm_jax.configs.ConfigDict]) – The configuration to log. Can also be a dictionary of multiple configurations.

on_training_start()

Sets up the logger for training.

start_epoch(epoch, step, mode='train')

Starts a new epoch.

To be called before starting a new training, validation, or test epoch. It can also be called while another epoch is still running. For instance, if a training epoch is interrupted by a validation epoch, the logger switches to validation mode until end_epoch is called, and then switches back to training mode.

Parameters:
  • epoch (int) – The index of the epoch.

  • step (int) – The index of the global training step.

  • mode (Literal['train', 'val', 'test']) – The logging mode. Should be in {“train”, “val”, “test”}. Defaults to “train”.
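The switching behavior described above works like a stack, which matches the mode_stack and epoch_start_time_stack attributes listed for Logger. The following is a minimal sketch under that assumption; ModeStackSketch is hypothetical, it drops the epoch and step arguments, and falling back to "default" when no epoch is open is inferred from the mode property.

```python
import time
from typing import List, Literal

Mode = Literal["train", "val", "test"]


class ModeStackSketch:
    """Hypothetical sketch of nested epochs tracked with two parallel stacks."""

    def __init__(self) -> None:
        self.mode_stack: List[str] = []
        self.epoch_start_time_stack: List[float] = []

    @property
    def mode(self) -> str:
        # Innermost open epoch wins; with no open epoch, fall back to "default".
        return self.mode_stack[-1] if self.mode_stack else "default"

    def start_epoch(self, mode: Mode = "train") -> None:
        self.mode_stack.append(mode)
        self.epoch_start_time_stack.append(time.monotonic())

    def end_epoch(self) -> float:
        # Close the innermost epoch; the enclosing epoch's mode becomes active again.
        self.mode_stack.pop()
        return time.monotonic() - self.epoch_start_time_stack.pop()  # epoch duration in seconds


tracker = ModeStackSketch()
tracker.start_epoch("train")
tracker.start_epoch("val")  # a validation epoch interrupts training
print(tracker.mode)  # val
tracker.end_epoch()
print(tracker.mode)  # train
tracker.end_epoch()
print(tracker.mode)  # default
```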

log_step(metrics, step)

Log metrics for a single step.

Parameters:
  • metrics (xlstm_jax.common_types.Metrics) – The metrics to log. Should follow the structure of the metrics in the metrics.py file.

  • step (int) – The current step.

Returns:

If the metrics are logged in this step, they are returned with all metrics reset. Otherwise, they are returned unchanged.

Return type:

xlstm_jax.common_types.Metrics
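The return-value contract of log_step can be sketched as follows. This sketch assumes, hypothetically, that metrics are stored as (running sum, count) pairs and that a step is a logging step when step % log_every_n_steps == 0; the real Metrics structure and scheduling rule may differ.

```python
from dataclasses import dataclass, field
from typing import Dict, Tuple

Accum = Dict[str, Tuple[float, int]]  # hypothetical: name -> (running sum, count)


@dataclass
class StepLoggerSketch:
    """Hypothetical sketch: accumulate metrics, log their mean every n steps, then reset."""

    log_every_n_steps: int = 1
    logged: list = field(default_factory=list)  # (step, averaged metrics), stands in for the tools

    def log_step(self, metrics: Accum, step: int) -> Accum:
        if step % self.log_every_n_steps != 0:
            return metrics  # not a logging step: return the metrics unchanged
        averaged = {name: total / count for name, (total, count) in metrics.items() if count > 0}
        self.logged.append((step, averaged))
        return {name: (0.0, 0) for name in metrics}  # reset accumulators for the next window


logger = StepLoggerSketch(log_every_n_steps=2)
metrics: Accum = {"loss": (0.0, 0)}
for step, loss in enumerate([4.0, 2.0, 3.0, 1.0], start=1):
    total, count = metrics["loss"]
    metrics = logger.log_step({"loss": (total + loss, count + 1)}, step)
print(logger.logged)  # [(2, {'loss': 3.0}), (4, {'loss': 2.0})]
```

Returning the reset metrics, rather than resetting them in place, lets the caller keep the functional style that JAX training loops typically use.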

_check_for_nans(host_metrics, step=None)

Check if any of the metrics contain NaNs.

If NaNs are found, a warning is logged and the found_nans attribute is set to True.

Parameters:
  • host_metrics (xlstm_jax.common_types.HostMetrics) – The metrics to check.

  • step (int | None) – The step at which the metrics were logged. Used for logging if provided.
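A host-side NaN check along these lines is sketched below. check_for_nans_sketch is hypothetical: it returns a flag instead of setting found_nans, and it only inspects scalar float values (arrays would need something like numpy.isnan).

```python
import logging
import math
from typing import Optional

log = logging.getLogger(__name__)


def check_for_nans_sketch(host_metrics: dict, step: Optional[int] = None) -> bool:
    """Hypothetical sketch: warn and return True if any float metric is NaN."""
    nan_keys = sorted(k for k, v in host_metrics.items() if isinstance(v, float) and math.isnan(v))
    if nan_keys:
        where = f" at step {step}" if step is not None else ""
        log.warning("Found NaNs in metrics %s%s.", nan_keys, where)
    return bool(nan_keys)


found_nans = False  # plays the role of the Logger.found_nans attribute
found_nans |= check_for_nans_sketch({"loss": 0.3, "accuracy": 0.9}, step=1)
found_nans |= check_for_nans_sketch({"loss": float("nan")}, step=2)
print(found_nans)  # True
```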

log_host_metrics(host_metrics, step, mode=None)

Logs a dictionary of metrics on the host.

Can be used by callbacks to log additional metrics.

Parameters:
  • host_metrics (xlstm_jax.common_types.HostMetrics) – The metrics to log.

  • step (int) – The current step.

  • mode (str | None) – The mode / prefix with which to log the metrics. If None, the current mode is used.
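The mode fallback can be sketched as a small dispatch function. RecordingTool and log_host_metrics_sketch are hypothetical; the sketch assumes the resolved mode is simply forwarded to each tool's log_metrics, whose signature is documented for LoggerTool below.

```python
from typing import Optional


class RecordingTool:
    """Hypothetical tool that records calls, standing in for a LoggerTool."""

    def __init__(self) -> None:
        self.calls = []

    def log_metrics(self, metrics, step, epoch, mode):
        self.calls.append((dict(metrics), step, epoch, mode))


def log_host_metrics_sketch(tools, host_metrics: dict, step: int, epoch: int,
                            current_mode: str, mode: Optional[str] = None) -> None:
    """Hypothetical sketch: resolve the mode, then forward the metrics to every tool."""
    resolved_mode = mode if mode is not None else current_mode  # None -> active mode
    for tool in tools:
        tool.log_metrics(host_metrics, step=step, epoch=epoch, mode=resolved_mode)


tool = RecordingTool()
# A callback logs an extra metric under the active mode, then under an explicit prefix.
log_host_metrics_sketch([tool], {"lr": 1e-3}, step=5, epoch=0, current_mode="train")
log_host_metrics_sketch([tool], {"grad_norm": 0.7}, step=5, epoch=0, current_mode="train", mode="optimizer")
print([call[3] for call in tool.calls])  # ['train', 'optimizer']
```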

end_epoch(metrics, step)

Ends the current epoch and logs the epoch metrics.

If any other epoch is still running, the logger will switch back to that epoch.

Parameters:
  • metrics (xlstm_jax.common_types.Metrics) – The metrics that should be logged in this epoch.

  • step (int) – The current step.

Returns:

The originally passed metric dict and potentially any other metrics that should be passed to callbacks later on. Note that the metrics will not be reset.

Return type:

tuple[xlstm_jax.common_types.Metrics, xlstm_jax.common_types.HostMetrics]

finalize(status)

Closes the logger.

Parameters:

status (str) – The status of the training run (e.g. success, failure).

class xlstm_jax.trainer.logger.base_logger.LoggerTool

Base class for logger tools.

log_config(config)

Log the configuration to the tool.

Parameters:

config (xlstm_jax.configs.ConfigDict | dict[str, xlstm_jax.configs.ConfigDict]) – The configuration to log.

abstract log_metrics(metrics, step, epoch, mode)

Log the metrics to the tool.

Parameters:
  • metrics (xlstm_jax.common_types.HostMetrics) – The metrics to log.

  • step (int) – The current step.

  • epoch (int) – The current epoch.

  • mode (str) – The current mode (train, val, test).

finalize(status)

Finalize and close the tool.

Parameters:

status (str)
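A concrete tool only has to implement log_metrics, since log_config and finalize are not marked abstract. The in-memory tool below is a hypothetical sketch against a stand-in base class (LoggerToolSketch), not a subclass of the real LoggerTool; the no-op defaults and the "<mode>/<name>" key format are assumptions.

```python
from abc import ABC, abstractmethod


class LoggerToolSketch(ABC):
    """Hypothetical stand-in mirroring the documented LoggerTool interface."""

    def log_config(self, config: dict) -> None:
        """Optional hook; assumed to be a no-op by default."""

    @abstractmethod
    def log_metrics(self, metrics: dict, step: int, epoch: int, mode: str) -> None:
        """Log the metrics to the tool."""

    def finalize(self, status: str) -> None:
        """Optional hook; assumed to be a no-op by default."""


class InMemoryTool(LoggerToolSketch):
    """Records everything in memory, e.g. for unit tests."""

    def __init__(self) -> None:
        self.config = None
        self.records = []
        self.status = None

    def log_config(self, config):
        self.config = dict(config)

    def log_metrics(self, metrics, step, epoch, mode):
        # Assumed key format: namespace each metric name by mode, e.g. "val/loss".
        self.records.append((step, epoch, {f"{mode}/{name}": v for name, v in metrics.items()}))

    def finalize(self, status):
        self.status = status


tool = InMemoryTool()
tool.log_config({"lr": 1e-3})
tool.log_metrics({"loss": 0.25}, step=100, epoch=1, mode="val")
tool.finalize("success")
print(tool.records)  # [(100, 1, {'val/loss': 0.25})]
```

Such a recorder can be handy in unit tests, where asserting on records is simpler than parsing log files.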