xlstm_jax.trainer.callbacks.profiler#

Attributes#

LOGGER

Classes#

JaxProfilerConfig

Configuration for the JaxProfiler callback.

JaxProfiler

Callback to profile model training steps.

Module Contents#

xlstm_jax.trainer.callbacks.profiler.LOGGER#
class xlstm_jax.trainer.callbacks.profiler.JaxProfilerConfig#

Bases: xlstm_jax.trainer.callbacks.callback.CallbackConfig

Configuration for the JaxProfiler callback.

every_n_epochs#

Unused in this callback.

every_n_steps#

Unused in this callback.

main_process_only#

If True, the profiler is only active in the main process. Otherwise, one profile per process is created.

profile_every_n_minutes#

Interval, in minutes, between profiling runs. If set below 0, profiling runs only once, at the beginning of training.

profile_first_step#

Index of the training step at which profiling first starts.

profile_n_steps#

Number of steps to profile.

profile_log_dir#

Directory in which to save the profiler logs. By default set to “tensorboard”, where the TensorBoard logs are also saved.

every_n_epochs: int = -1#
every_n_steps: int = -1#
main_process_only: bool = True#
profile_every_n_minutes: int = 60#
profile_first_step: int = 10#
profile_n_steps: int = 5#
profile_log_dir: str = 'tensorboard'#
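
As a rough illustration of how these fields interact, the sketch below uses a hypothetical stand-in dataclass, `ProfilerSettings`, that mirrors the documented field names and defaults. It is not the library class. The helpers `should_profile` and `resolve_log_path` are likewise hypothetical, and they assume that the main process is the one with index 0 and that `profile_log_dir` is resolved relative to a run directory.

```python
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class ProfilerSettings:
    """Hypothetical stand-in mirroring the documented fields and defaults."""

    main_process_only: bool = True
    profile_every_n_minutes: int = 60
    profile_first_step: int = 10
    profile_n_steps: int = 5
    profile_log_dir: str = "tensorboard"


def should_profile(settings: ProfilerSettings, process_index: int) -> bool:
    # Assumption: the main process is the one with index 0.
    return (not settings.main_process_only) or process_index == 0


def resolve_log_path(settings: ProfilerSettings, run_dir: Path) -> Path:
    # Assumption: profile_log_dir is resolved relative to a run directory.
    return run_dir / settings.profile_log_dir


# One profile per process, taken once starting at step 20 and covering 3 steps.
settings = ProfilerSettings(
    main_process_only=False,
    profile_every_n_minutes=-1,
    profile_first_step=20,
    profile_n_steps=3,
)
```

With the defaults, only process 0 would write a profile. With `main_process_only=False`, every process would.
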
create(trainer, data_module=None)#

Creates the JaxProfiler callback.

Parameters:
  • trainer (Any) – Trainer object.

  • data_module (optional) – Data module object.

Returns:

JaxProfiler object.

Return type:

JaxProfiler

class xlstm_jax.trainer.callbacks.profiler.JaxProfiler(config, trainer, data_module=None)#

Bases: xlstm_jax.trainer.callbacks.callback.Callback

Callback to profile model training steps.

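Parameters:
  • config (JaxProfilerConfig) – Configuration for the callback.

  • trainer (Any) – Trainer object.

  • data_module (optional) – Data module object.
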
log_path: pathlib.Path#
profile_every_n_minutes#
profile_first_step#
profile_n_steps#
profiler_active = False#
profiler_last_time = None#
on_training_start()#

Called at the beginning of training.

Starts tracking the time to determine when to start the profiler.

on_training_step(step_metrics, epoch_idx, step_idx)#

Called at the end of each training step.

Starts the profiler if the current step is profile_first_step, or if the time since the last profiling run exceeds profile_every_n_minutes. If the profiler is already active, it is stopped once profile_n_steps steps have been profiled.

Parameters:
  • step_metrics (xlstm_jax.common_types.Metrics) – Dictionary of training metrics of the current step.

  • epoch_idx (int) – Index of the current epoch.

  • step_idx (int) – Index of the current step.
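
The start/stop rule described above can be sketched in plain Python. This is a minimal reconstruction from the description, not the library's implementation: the class `ProfileSchedule`, its attribute names, the injectable `clock`, and the exact comparison operators are all assumptions made for illustration.

```python
import time
from typing import Callable, Optional


class ProfileSchedule:
    """Hypothetical helper that decides when a trace should start or stop."""

    def __init__(
        self,
        first_step: int = 10,
        n_steps: int = 5,
        every_n_minutes: int = 60,
        clock: Callable[[], float] = time.time,
    ) -> None:
        self.first_step = first_step
        self.n_steps = n_steps
        self.every_n_minutes = every_n_minutes
        self.clock = clock
        self.active = False
        self.last_time: Optional[float] = None
        self.start_step: Optional[int] = None

    def on_training_start(self) -> None:
        # Start tracking time so the first interval is measured from here.
        self.last_time = self.clock()

    def on_step(self, step_idx: int) -> Optional[str]:
        """Return "start", "stop", or None for the step that just finished."""
        if self.active:
            # Stop once the requested number of steps has been traced.
            if step_idx >= self.start_step + self.n_steps:
                self.active = False
                self.last_time = self.clock()
                return "stop"
            return None
        due_by_step = step_idx == self.first_step
        # Assumption: a negative interval disables time-based profiling.
        due_by_time = (
            self.every_n_minutes >= 0
            and self.last_time is not None
            and self.clock() - self.last_time > self.every_n_minutes * 60
        )
        if due_by_step or due_by_time:
            self.active = True
            self.start_step = step_idx
            return "start"
        return None
```

Because the hook runs at the end of a step, a trace started at step `s` in this sketch covers steps `s + 1` through `s + n_steps`.
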

on_training_epoch_end(train_metrics, epoch_idx)#

Called at the end of each training epoch.

Stops the profiler if it is still active, so that operations outside the training steps are not traced.

Parameters:
  • train_metrics (xlstm_jax.common_types.Metrics) – Metrics of the current epoch.

  • epoch_idx (int) – Index of the current epoch.

on_validation_epoch_start(epoch_idx, step_idx)#

Called at the beginning of validation.

If the profiler is active, stops it so that validation steps are not traced.

Parameters:
  • epoch_idx (int) – Index of the current training epoch.

  • step_idx (int) – Index of the current training step.
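
To see why the trace is stopped at the epoch and validation boundaries, consider the toy simulation below. Everything in it is hypothetical (`RecordingProfiler`, `run_epoch`, and the labels). It only records which phases run while a trace is active, assuming that the callback hooks fire in the order train steps, epoch end, validation start.

```python
from typing import List


class RecordingProfiler:
    """Hypothetical stand-in that records what runs while tracing is active."""

    def __init__(self) -> None:
        self.active = False
        self.traced: List[str] = []

    def start(self) -> None:
        self.active = True

    def stop(self) -> None:
        # Stopping an inactive profiler is a no-op.
        self.active = False

    def record(self, label: str) -> None:
        if self.active:
            self.traced.append(label)


def run_epoch(
    profiler: RecordingProfiler,
    n_train_steps: int,
    n_val_steps: int,
    start_at: int,
    stop_at_boundary: bool,
) -> List[str]:
    for step in range(1, n_train_steps + 1):
        profiler.record(f"train_{step}")
        if step == start_at:
            profiler.start()  # the hook runs at the end of the step
    if stop_at_boundary:
        profiler.stop()  # role of the epoch-end / validation-start hooks
    for val_step in range(n_val_steps):
        profiler.record(f"val_{val_step}")
    return profiler.traced


with_stop = run_epoch(RecordingProfiler(), 4, 2, start_at=3, stop_at_boundary=True)
without_stop = run_epoch(RecordingProfiler(), 4, 2, start_at=3, stop_at_boundary=False)
```

Without the boundary stop, the validation steps end up in the trace alongside the training step. With it, only training work is captured.
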

start_trace(step_idx)#

Start the profiler trace.

If the profiler is already active, a warning is logged.

Parameters:

step_idx (int) – Index of the current training step.

stop_trace()#

Stop the profiler trace.

If the profiler is not active, nothing is done.
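
The guarded behavior of start_trace and stop_trace could look roughly like the sketch below. It is an illustration under stated assumptions, not the library code: `TraceGuard`, `start_fn`, and `stop_fn` are hypothetical names, and the sketch assumes that a second start request is ignored after the warning. In a JAX program, `jax.profiler.start_trace` (which takes a log directory) and `jax.profiler.stop_trace` could be passed as the two functions.

```python
import logging
from pathlib import Path
from typing import Callable

LOGGER = logging.getLogger(__name__)


class TraceGuard:
    """Hypothetical wrapper that keeps start/stop calls balanced."""

    def __init__(
        self,
        log_path: Path,
        start_fn: Callable[[str], None],
        stop_fn: Callable[[], None],
    ) -> None:
        self.log_path = log_path
        self.start_fn = start_fn
        self.stop_fn = stop_fn
        self.profiler_active = False

    def start_trace(self, step_idx: int) -> None:
        if self.profiler_active:
            # Assumption: warn and keep the running trace untouched.
            LOGGER.warning("Profiler already active at step %d.", step_idx)
            return
        self.start_fn(str(self.log_path))
        self.profiler_active = True

    def stop_trace(self) -> None:
        if not self.profiler_active:
            return  # nothing to stop
        self.stop_fn()
        self.profiler_active = False


# Usage with recording stand-ins instead of a real profiler backend.
calls = []
guard = TraceGuard(
    Path("tensorboard"),
    start_fn=lambda log_dir: calls.append(("start", log_dir)),
    stop_fn=lambda: calls.append(("stop",)),
)
guard.stop_trace()     # inactive: ignored
guard.start_trace(10)  # starts the trace
guard.start_trace(11)  # already active: warning only
guard.stop_trace()     # stops the trace
guard.stop_trace()     # inactive again: ignored
```

Tracking the active flag in one place means the epoch-end and validation-start hooks can call stop_trace unconditionally.
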