xlstm_jax.trainer.callbacks.profiler#
Attributes#
Classes#
| JaxProfilerConfig | Configuration for the JaxProfiler callback. |
| JaxProfiler | Callback to profile model training steps. |
Module Contents#
- xlstm_jax.trainer.callbacks.profiler.LOGGER#
- class xlstm_jax.trainer.callbacks.profiler.JaxProfilerConfig#
Bases:
xlstm_jax.trainer.callbacks.callback.CallbackConfig
Configuration for the JaxProfiler callback.
- every_n_epochs#
Unused in this callback.
- every_n_steps#
Unused in this callback.
- main_process_only#
If True, the profiler is only active in the main process. Otherwise, one profile per process is created.
- profile_every_n_minutes#
Interval, in minutes, between profiling runs. If set below 0, profiling runs only once, at the beginning of training.
- profile_first_step#
The training step at which the first profile starts.
- profile_n_steps#
Number of steps to profile.
- profile_log_dir#
Directory in which to save the profiler logs. Defaults to "tensorboard", the directory where the TensorBoard logs are also saved.
- create(trainer, data_module=None)#
Creates the JaxProfiler callback.
- Parameters:
trainer (Any) – Trainer object.
data_module (optional) – Data module object.
- Returns:
JaxProfiler object.
- Return type:
JaxProfiler
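The fields above can be illustrated with a minimal sketch. `ProfilerSettings` below is a hypothetical stand-in, not the library's `JaxProfilerConfig`; it only mirrors the documented field names. All values except the documented "tensorboard" default are illustrative, and resolving `profile_log_dir` relative to a run directory is an assumption.

```python
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class ProfilerSettings:
    """Hypothetical stand-in that mirrors the documented field names."""

    main_process_only: bool = True         # illustrative value, not a known default
    profile_every_n_minutes: float = 60.0  # illustrative value
    profile_first_step: int = 10           # illustrative value
    profile_n_steps: int = 5               # illustrative value
    profile_log_dir: str = "tensorboard"   # documented default

    def profiles_once(self) -> bool:
        # Documented rule: an interval below 0 means a single profile
        # at the beginning of training.
        return self.profile_every_n_minutes < 0

    def log_path(self, run_dir: Path) -> Path:
        # Assumption: the profiler directory is resolved relative to the
        # run's logging directory.
        return Path(run_dir) / self.profile_log_dir


one_shot = ProfilerSettings(profile_every_n_minutes=-1)
periodic = ProfilerSettings(profile_every_n_minutes=30)
```

With these stand-in settings, `one_shot` describes a single profile of five steps starting at step 10, while `periodic` would repeat the profile whenever 30 minutes have passed since the previous one.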
- class xlstm_jax.trainer.callbacks.profiler.JaxProfiler(config, trainer, data_module=None)#
Bases:
xlstm_jax.trainer.callbacks.callback.Callback
Callback to profile model training steps.
- Parameters:
config (JaxProfilerConfig)
trainer (Any)
data_module (xlstm_jax.trainer.data_module.DataloaderModule | None)
- log_path: pathlib.Path#
- profile_every_n_minutes#
- profile_first_step#
- profile_n_steps#
- profiler_active = False#
- profiler_last_time = None#
- on_training_start()#
Called at the beginning of training.
Starts tracking the time to determine when to start the profiler.
- on_training_step(step_metrics, epoch_idx, step_idx)#
Called at the end of each training step.
Starts the profiler if the current step is the first profiling step (profile_first_step) or if the time since the last profile exceeds the configured interval (profile_every_n_minutes). If the profiler is already active, stops it once the configured number of steps (profile_n_steps) has been profiled.
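The start/stop rule described for `on_training_step` can be sketched as a small state machine. This is a simplified model under stated assumptions, not the library's implementation: `StepProfilerSchedule` and `ManualClock` are hypothetical names, the timer is assumed to restart when a trace stops, and the trace is assumed to cover the `n_steps` steps that follow the step at which it was started.

```python
import time
from typing import Callable, List, Optional


class ManualClock:
    """Hypothetical test clock; advance `now` by hand (in seconds)."""

    def __init__(self) -> None:
        self.now = 0.0

    def __call__(self) -> float:
        return self.now


class StepProfilerSchedule:
    """Hypothetical model of the documented start/stop decisions."""

    def __init__(self, first_step: int, n_steps: int, every_n_minutes: float,
                 clock: Callable[[], float] = time.monotonic) -> None:
        self.first_step = first_step
        self.n_steps = n_steps
        self.every_n_minutes = every_n_minutes
        self.clock = clock
        self.active = False
        self.start_step: Optional[int] = None
        self.last_time: Optional[float] = None
        self.events: List[str] = []  # record of decisions, for inspection

    def on_training_start(self) -> None:
        # Start tracking time so the interval check has a reference point.
        self.last_time = self.clock()

    def on_training_step(self, step_idx: int) -> None:
        if self.active:
            # Stop once n_steps further steps have completed.
            if step_idx >= self.start_step + self.n_steps:
                self.active = False
                self.last_time = self.clock()  # assumption: timer restarts here
                self.events.append(f"stop@{step_idx}")
            return
        due_first = step_idx == self.first_step
        due_timer = (
            self.every_n_minutes >= 0  # a negative interval disables repeats
            and step_idx > self.first_step
            and self.clock() - self.last_time > self.every_n_minutes * 60
        )
        if due_first or due_timer:
            self.active = True
            self.start_step = step_idx
            self.events.append(f"start@{step_idx}")


def simulate(schedule: StepProfilerSchedule, clock: ManualClock,
             total_steps: int, seconds_per_step: float) -> List[str]:
    """Drive the schedule with a fixed, simulated step duration."""
    schedule.on_training_start()
    for step in range(1, total_steps + 1):
        clock.now += seconds_per_step
        schedule.on_training_step(step)
    return schedule.events
```

Injecting the clock keeps the sketch deterministic, so the interval logic can be checked without waiting for real minutes to pass.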
- on_training_epoch_end(train_metrics, epoch_idx)#
Called at the end of each training epoch.
Stops the profiler if it is still active, so that operations outside the training steps are not traced.
- Parameters:
train_metrics (xlstm_jax.common_types.Metrics) – Metrics of the current epoch.
epoch_idx (int) – Index of the current epoch.
- on_validation_epoch_start(epoch_idx, step_idx)#
Called at the beginning of validation.
If the profiler is active, stops it so that the validation steps are not traced.
- start_trace(step_idx)#
Start the profiler trace.
If the profiler is already active, a warning is logged.
- Parameters:
step_idx (int) – Index of the current training step.
- stop_trace()#
Stop the profiler trace.
If the profiler is not active, nothing is done.
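The guarded `start_trace` / `stop_trace` behaviour, together with the epoch-end and validation-start hooks, can be sketched as follows. `TraceGuard` is a hypothetical wrapper, not the library's `JaxProfiler`. It assumes that a second `start_trace` call only logs a warning and does not restart the trace. The default backend uses the real `jax.profiler.start_trace` and `jax.profiler.stop_trace` functions, imported lazily so the guard logic can be exercised without JAX.

```python
import logging
from pathlib import Path
from typing import Callable, Union

LOGGER = logging.getLogger(__name__)


def _jax_start(log_dir: str) -> None:
    import jax.profiler  # lazy import: only needed when tracing for real
    jax.profiler.start_trace(log_dir)


def _jax_stop() -> None:
    import jax.profiler
    jax.profiler.stop_trace()


class TraceGuard:
    """Hypothetical wrapper that keeps start/stop calls balanced."""

    def __init__(self, log_path: Union[str, Path],
                 start_fn: Callable[[str], None] = _jax_start,
                 stop_fn: Callable[[], None] = _jax_stop) -> None:
        self.log_path = Path(log_path)
        self.profiler_active = False
        self._start_fn = start_fn
        self._stop_fn = stop_fn

    def start_trace(self, step_idx: int) -> None:
        if self.profiler_active:
            # Assumption: warn and keep the running trace.
            LOGGER.warning("Profiler already active at step %d.", step_idx)
            return
        self._start_fn(str(self.log_path))
        self.profiler_active = True

    def stop_trace(self) -> None:
        if not self.profiler_active:
            return  # documented: nothing is done when inactive
        self._stop_fn()
        self.profiler_active = False

    def on_training_epoch_end(self) -> None:
        # Avoid tracing work that happens between epochs.
        self.stop_trace()

    def on_validation_epoch_start(self) -> None:
        # Avoid tracing the validation steps.
        self.stop_trace()
```

Passing `start_fn` and `stop_fn` explicitly lets the guard be tested with plain Python callables. With the defaults, the trace is written to `log_path` and can be inspected in TensorBoard's profiler view.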