xlstm_jax.trainer.metrics
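Attributes
| LOGGER | |

Functions
| update_metrics(global_metrics, step_metrics, default_log_modes=None) | Update metrics with new values. |
| aggregate_metrics(aggregated_metrics, batch_metrics) | This function aggregates multiple metrics for the single_noreduce and single_noreduce_wcount cases. |
| _empty_val(value) | |
| _update_single_metric(global_metrics, key, value, count, log_modes=None) | Update a single metric. |
| get_metrics(global_metrics, reset_metrics=True) | Calculates metrics to log from global metrics. |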
Module Contents
- xlstm_jax.trainer.metrics.LOGGER
- xlstm_jax.trainer.metrics.update_metrics(global_metrics, step_metrics, default_log_modes=None)
Update metrics with new values.
- Parameters:
global_metrics (xlstm_jax.common_types.Metrics | None) – Global metrics to update. If None, a new dictionary is created.
step_metrics (xlstm_jax.common_types.StepMetrics) – Metrics to update with.
default_log_modes (collections.abc.Sequence[xlstm_jax.common_types.LogMode] | None) – The default log modes for the metrics. If None, only the mean is logged. Otherwise, each of the specified modes is logged, and the log mode is appended to the metric key.
- Returns:
Updated global metrics.
- Return type:
xlstm_jax.common_types.ImmutableMetrics
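The accumulation described above can be sketched with plain dictionaries. This is a minimal sketch, not the library's implementation: the layout of a global metric (a running `value` sum, a `count`, and its `log_modes`), the convention that a step metric is either a bare value or a `(value, count)` pair, and the name `update_metrics_sketch` are all assumptions made for illustration.

```python
from __future__ import annotations

from collections.abc import Mapping, Sequence
from typing import Any

# Hypothetical layout (assumption): key -> {"value": running sum, "count": running count, "log_modes": [...]}
Metrics = dict[str, dict[str, Any]]


def update_metrics_sketch(
    global_metrics: Metrics | None,
    step_metrics: Mapping[str, Any],
    default_log_modes: Sequence[str] | None = None,
) -> Metrics:
    """Fold one step's metrics into the running global metrics without mutating the input."""
    # Start from an empty dict when no global metrics exist yet; copy the entries otherwise.
    updated: Metrics = {key: dict(entry) for key, entry in (global_metrics or {}).items()}
    # Without explicit log modes, only the mean is logged.
    log_modes = list(default_log_modes) if default_log_modes is not None else ["mean"]
    for key, metric in step_metrics.items():
        # Assumption: a bare value counts as one sample; a tuple is (sum over samples, sample count).
        value, count = metric if isinstance(metric, tuple) else (metric, 1)
        entry = updated.setdefault(key, {"value": 0.0, "count": 0, "log_modes": log_modes})
        entry["value"] += value
        entry["count"] += count
    return updated


metrics = update_metrics_sketch(None, {"loss": (4.0, 2)})
metrics = update_metrics_sketch(metrics, {"loss": (2.0, 2), "accuracy": 0.5})
print(metrics["loss"]["value"] / metrics["loss"]["count"])  # 1.5
```

Returning a fresh dictionary rather than mutating the argument mirrors the immutable return type documented above.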
- xlstm_jax.trainer.metrics.aggregate_metrics(aggregated_metrics, batch_metrics)
This function aggregates multiple metrics for the single_noreduce and single_noreduce_wcount cases. For single_noreduce, batches of single values are concatenated, and the count is the number of samples. For single_noreduce_wcount, batches of values are concatenated together with a count for each sample; this is needed, e.g., for the log-likelihood per sequence. Concatenation happens in CPU memory after the values have been converted.
In all other cases, the function moves batch_metrics to CPU memory and returns them.
- Parameters:
aggregated_metrics (xlstm_jax.common_types.HostMetrics) – Previously aggregated metrics, to which the batch metrics may be appended.
batch_metrics (xlstm_jax.common_types.ImmutableMetrics) – Metrics from a batch.
- Returns:
Newly aggregated metrics.
- Return type:
xlstm_jax.common_types.HostMetrics
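The two concatenation cases can be sketched with Python lists standing in for host arrays. This is a minimal sketch under stated assumptions: each metric is a hypothetical dict with `value`, `count`, and a single `log_mode` field, the function name `aggregate_metrics_sketch` is illustrative, and the device-to-host transfer is omitted because the inputs are already plain Python objects.

```python
from __future__ import annotations

from typing import Any

# Hypothetical layout (assumption): key -> {"value": ..., "count": ..., "log_mode": str}
HostMetrics = dict[str, dict[str, Any]]


def aggregate_metrics_sketch(aggregated: HostMetrics, batch: HostMetrics) -> HostMetrics:
    """Concatenate per-sample values across batches for the *_noreduce log modes."""
    result = dict(aggregated)  # shallow copy; the input dict is left untouched
    for key, metric in batch.items():
        mode = metric["log_mode"]
        if mode == "single_noreduce":
            # One value per sample: concatenate, and the count is the number of samples so far.
            prev = result.get(key, {"value": [], "count": 0})
            values = prev["value"] + list(metric["value"])
            result[key] = {"value": values, "count": len(values), "log_mode": mode}
        elif mode == "single_noreduce_wcount":
            # Values and their per-sample counts are both concatenated (e.g. per-sequence log-likelihood).
            prev = result.get(key, {"value": [], "count": []})
            result[key] = {
                "value": prev["value"] + list(metric["value"]),
                "count": prev["count"] + list(metric["count"]),
                "log_mode": mode,
            }
        else:
            # All other modes: the batch metric is kept as-is.
            result[key] = metric
    return result


agg = aggregate_metrics_sketch({}, {"nll": {"value": [1.0, 2.0], "count": 2, "log_mode": "single_noreduce"}})
agg = aggregate_metrics_sketch(agg, {"nll": {"value": [3.0], "count": 1, "log_mode": "single_noreduce"}})
print(agg["nll"]["value"], agg["nll"]["count"])  # [1.0, 2.0, 3.0] 3
```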
- xlstm_jax.trainer.metrics._empty_val(value)
- Parameters:
value (Any)
- Return type:
Any
- xlstm_jax.trainer.metrics._update_single_metric(global_metrics, key, value, count, log_modes=None)
Update a single metric.
- Parameters:
global_metrics (xlstm_jax.common_types.MutableMetrics) – Global metrics to update.
key (str) – Key of the metric to update.
value (Any) – Value of the metric to update.
count (Any) – Count of the metric to update.
log_modes (collections.abc.Sequence[xlstm_jax.common_types.LogMode] | None) – The log modes for the metric.
- Return type:
xlstm_jax.common_types.MutableMetrics
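Because the global metrics here are mutable, a single-metric update can modify the dictionary in place and return the same object. The sketch below illustrates that pattern under assumptions: the entry layout, the name `update_single_metric_sketch`, and the `"max"` log mode (tracking the largest per-step mean) are hypothetical and only show how a log mode could change what is accumulated.

```python
from __future__ import annotations

from collections.abc import Sequence
from typing import Any

# Hypothetical layout (assumption): key -> {"value": running sum, "count": running count,
# "log_modes": [...], optional "max"}
MutableMetrics = dict[str, dict[str, Any]]


def update_single_metric_sketch(
    global_metrics: MutableMetrics,
    key: str,
    value: float,
    count: int,
    log_modes: Sequence[str] | None = None,
) -> MutableMetrics:
    """Add one (value, count) observation to global_metrics[key], mutating the dict in place."""
    if key not in global_metrics:
        # First observation of this key: create an empty accumulator (mean only by default).
        # log_modes is only consulted here; later calls reuse the stored modes.
        modes = list(log_modes) if log_modes is not None else ["mean"]
        global_metrics[key] = {"value": 0.0, "count": 0, "log_modes": modes}
        if "max" in modes:
            global_metrics[key]["max"] = float("-inf")
    entry = global_metrics[key]
    entry["value"] += value
    entry["count"] += count
    if "max" in entry["log_modes"] and count > 0:
        # Hypothetical "max" mode: remember the largest per-step mean seen so far.
        entry["max"] = max(entry["max"], value / count)
    return global_metrics  # the same object that was passed in


g: MutableMetrics = {}
update_single_metric_sketch(g, "loss", 3.0, 2, ["mean", "max"])
update_single_metric_sketch(g, "loss", 1.0, 2)
print(g["loss"]["value"], g["loss"]["count"], g["loss"]["max"])  # 4.0 4 1.5
```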
- xlstm_jax.trainer.metrics.get_metrics(global_metrics, reset_metrics=True)
Calculates metrics to log from global metrics.
Supports resetting the global metrics after logging. For example, if the global metrics are logged every epoch, they can be reset once the metrics to log have been computed, so that the next epoch starts with empty metrics.
- Parameters:
global_metrics (xlstm_jax.common_types.Metrics) – Global metrics to log.
reset_metrics (bool) – Whether to reset the metrics after logging.
- Returns:
The updated (reset) global metrics if reset_metrics is True, otherwise the original global metrics, together with the metrics to log on the host device.
- Return type:
tuple[xlstm_jax.common_types.ImmutableMetrics, xlstm_jax.common_types.HostMetrics]
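The conversion from running statistics to loggable values, and the optional reset, can be sketched as follows. This is a minimal sketch, not the library's code: the entry layout, the name `get_metrics_sketch`, the hypothetical `"max"` mode, the underscore used when appending the log mode to the key, and modeling a reset as an empty dictionary are all assumptions.

```python
from __future__ import annotations

from typing import Any

# Hypothetical layout (assumption): key -> {"value": running sum, "count": running count,
# optional "max", optional "log_modes"}
Metrics = dict[str, dict[str, Any]]


def get_metrics_sketch(
    global_metrics: Metrics, reset_metrics: bool = True
) -> tuple[Metrics, dict[str, float]]:
    """Convert running statistics into values to log; optionally reset the global metrics."""
    host_metrics: dict[str, float] = {}
    for key, entry in global_metrics.items():
        mean = entry["value"] / max(entry["count"], 1)  # guard against a zero count
        log_modes = entry.get("log_modes")
        if log_modes is None:
            # No explicit log modes: log only the mean, under the bare key.
            host_metrics[key] = mean
            continue
        for mode in log_modes:
            if mode == "mean":
                logged = mean
            elif mode == "max":  # hypothetical mode, read from a stored running max
                logged = entry["max"]
            else:
                raise ValueError(f"log mode not covered by this sketch: {mode}")
            # Assumption: the log mode is appended to the key with an underscore.
            host_metrics[f"{key}_{mode}"] = logged
    # Assumption: resetting is modeled as starting over from an empty dict.
    new_global = {} if reset_metrics else global_metrics
    return new_global, host_metrics


state = {
    "loss": {"value": 6.0, "count": 4},
    "grad_norm": {"value": 3.0, "count": 2, "max": 2.5, "log_modes": ["mean", "max"]},
}
state, logged = get_metrics_sketch(state)
print(logged)  # {'loss': 1.5, 'grad_norm_mean': 1.5, 'grad_norm_max': 2.5}
print(state)  # {}
```

Calling it once per epoch with the default `reset_metrics=True` gives the per-epoch behavior described above; passing `reset_metrics=False` keeps accumulating across logging calls.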