xlstm_jax.trainer.callbacks.extended_evaluation#

Attributes#

Classes#

EvalState

EvalState with additional mutable variables and RNG.

ExtendedEvaluationConfig

Configuration for the extended evaluation callback.

ExtendedEvaluation

Callback that runs additional evaluations.

Functions#

device_metrics_aggregation(trainer, metrics)

Aggregates metrics beyond a single scalar value and a count.

Module Contents#

xlstm_jax.trainer.callbacks.extended_evaluation.LOGGER#
class xlstm_jax.trainer.callbacks.extended_evaluation.EvalState#

Bases: flax.struct.PyTreeNode

EvalState with additional mutable variables and RNG.

Parameters:
  • step – Counter starts at 0 and is incremented by every call to .apply_gradients().

  • apply_fn – Usually set to model.apply(). Kept in this dataclass for convenience to have a shorter params list for the train_step() function in your training loop.

  • params – The parameters to be updated by tx and used by apply_fn.

step: int | jax.Array#
rng: jax.Array#
apply_fn: collections.abc.Callable#
params: flax.core.FrozenDict[str, Any]#
mutable_variables: Any = None#
classmethod from_train_state(*, train_state)#
Parameters:

train_state (xlstm_jax.common_types.TrainState)

Return type:

EvalState

classmethod create(*, apply_fn, params, **kwargs)#
Return type:

EvalState
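The relationship between a training state and an EvalState can be sketched as follows. This is an illustration only: it uses the standard library dataclasses module as a stand-in for flax.struct.PyTreeNode, and the names ToyTrainState and ToyEvalState, the opt_state field, and the field copying inside from_train_state are assumptions, not the library's implementation.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class ToyTrainState:
    """Hypothetical stand-in for xlstm_jax.common_types.TrainState."""

    step: int
    rng: Any
    apply_fn: Callable
    params: dict
    mutable_variables: Any = None
    opt_state: Any = None  # optimizer state, not needed for evaluation


@dataclass(frozen=True)
class ToyEvalState:
    """Hypothetical stand-in mirroring the EvalState fields listed above."""

    step: int
    rng: Any
    apply_fn: Callable
    params: dict
    mutable_variables: Any = None

    @classmethod
    def from_train_state(cls, *, train_state: ToyTrainState) -> "ToyEvalState":
        # Assumption: only the fields needed for evaluation are copied over.
        return cls(
            step=train_state.step,
            rng=train_state.rng,
            apply_fn=train_state.apply_fn,
            params=train_state.params,
            mutable_variables=train_state.mutable_variables,
        )


train_state = ToyTrainState(
    step=10,
    rng=0,
    apply_fn=lambda params, x: params["w"] * x,
    params={"w": 2.0},
    opt_state={"momentum": 0.0},
)
eval_state = ToyEvalState.from_train_state(train_state=train_state)
```

In this sketch the evaluation state keeps everything required to run the model forward but drops the optimizer-related state.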

class xlstm_jax.trainer.callbacks.extended_evaluation.ExtendedEvaluationConfig#

Bases: xlstm_jax.trainer.callbacks.callback.CallbackConfig

Configuration for the extended evaluation callback.

create(trainer, data_module=None)#

Creates an ExtendedEvaluation callback.

Return type:

ExtendedEvaluation

xlstm_jax.trainer.callbacks.extended_evaluation.device_metrics_aggregation(trainer, metrics)#

Aggregates metrics beyond a single scalar value and a count.

Also includes single_noreduce metrics, which are gathered by concatenation instead of being reduced.

Parameters:
  • trainer (Any) – Trainer (for aggregation axes)

  • metrics (xlstm_jax.common_types.Metrics) – the sharded metrics

Returns:

The reduced/gathered metrics.

Return type:

xlstm_jax.common_types.Metrics
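The distinction between reduced and concatenated metrics can be illustrated with a plain-Python sketch. It is not the library's implementation: the per-device list, the "value"/"count" keys, the helper toy_device_aggregation, and the way non-reduced entries are selected are all assumptions made for illustration. The real function operates on sharded JAX arrays along the trainer's aggregation axes.

```python
def toy_device_aggregation(per_device_metrics, noreduce_keys):
    """Combine a list of per-device metric dicts into one dict.

    Hypothetical helper: scalar metrics are summed across devices,
    metrics named in `noreduce_keys` are concatenated instead.
    """
    combined = {}
    for device_metrics in per_device_metrics:
        for name, entry in device_metrics.items():
            if name in noreduce_keys:
                # Keep every per-example value by concatenating the lists.
                acc = combined.setdefault(name, {"value": []})
                acc["value"].extend(entry["value"])
            else:
                # Reduce to a running sum and a running count.
                acc = combined.setdefault(name, {"value": 0.0, "count": 0})
                acc["value"] += entry["value"]
                acc["count"] += entry["count"]
    return combined


per_device = [
    {"loss": {"value": 6.0, "count": 2}, "doc_ids": {"value": [0, 1]}},
    {"loss": {"value": 4.0, "count": 2}, "doc_ids": {"value": [2, 3]}},
]
reduced = toy_device_aggregation(per_device, noreduce_keys={"doc_ids"})
```

Here the loss ends up as a single sum and count, while the doc_ids entry retains one value per example across both devices.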

class xlstm_jax.trainer.callbacks.extended_evaluation.ExtendedEvaluation(config, trainer, data_module=None)#

Bases: xlstm_jax.trainer.callbacks.callback.Callback

Callback that runs additional evaluations.

config#
trainer#
exmp_batch#
eval_step = None#
_eval_metric_shapes = None#
create_modified_exemplary_batch(exmp_batch)#

Create a modified exemplary batch for evaluation. This is useful for attaching additional information or metadata to the batch for post-processing.

Parameters:

exmp_batch (xlstm_jax.dataset.batch.Batch) – “Original” training exemplary batch

Returns:

Modified exemplary batch for evaluation, might be the unmodified original.

Return type:

xlstm_jax.dataset.batch.Batch

eval_function(params, apply_fn, batch, rng)#

The extended evaluation function calculating metrics.

This function must be overridden by a subclass.

Parameters:
  • params (Any) – The model parameters.

  • apply_fn (Any) – The apply function of the state.

  • batch (xlstm_jax.dataset.batch.Batch) – The current batch.

  • rng (jax.Array) – The random number generator key.

Returns:

A tuple of metrics and mutable variables.

Return type:

tuple[xlstm_jax.common_types.Metrics, xlstm_jax.common_types.PyTree]
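A subclass override might look like the following sketch. The base class here is a hypothetical stand-in, not the real ExtendedEvaluation, and the dictionary batch, the toy model, and the "value"/"count" metric layout are assumptions chosen to keep the example self-contained.

```python
class ToyExtendedEvaluation:
    """Hypothetical stand-in for the ExtendedEvaluation base class."""

    def eval_function(self, params, apply_fn, batch, rng):
        raise NotImplementedError("Subclasses must override eval_function.")


class ToySquaredErrorEvaluation(ToyExtendedEvaluation):
    """Example subclass computing a summed squared error per batch."""

    def eval_function(self, params, apply_fn, batch, rng):
        # Run the model; this toy model does not use the RNG.
        predictions = apply_fn(params, batch["inputs"])
        errors = [(p - t) ** 2 for p, t in zip(predictions, batch["targets"])]
        # Assumed metric layout: a summed value plus the count it covers.
        metrics = {"squared_error": {"value": sum(errors), "count": len(errors)}}
        mutable_variables = None  # nothing to update in this toy model
        return metrics, mutable_variables


def toy_apply_fn(params, inputs):
    """Toy linear model: scales every input by params['w']."""
    return [params["w"] * x for x in inputs]


batch = {"inputs": [1.0, 2.0], "targets": [2.0, 5.0]}
metrics, mutable_variables = ToySquaredErrorEvaluation().eval_function(
    {"w": 2.0}, toy_apply_fn, batch, rng=None
)
```

The return value follows the documented contract: a tuple of metrics and mutable variables.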

create_jitted_functions()#

Create jitted version of the evaluation function.

create_evaluation_step_function()#

Create and return a function for the extended evaluation step.

The function takes as input the training state, a batch from the val/test loader, and optionally the metrics accumulated so far. As the return type indicates, it is expected to return the updated logging metrics.

Returns:

Step function calculating metrics for one batch.

Return type:

collections.abc.Callable[[xlstm_jax.common_types.TrainState, xlstm_jax.dataset.batch.Batch, xlstm_jax.common_types.ImmutableMetrics | None], xlstm_jax.common_types.ImmutableMetrics]

init_eval_metrics(batch=None, alternative_eval_step=None)#

Initialize the evaluation metrics with zeros.

We infer the evaluation metric shapes from the eval_step function. This prevents a double compilation of eval_step, where the first step would have to run with metrics set to None and the following steps with the actual metric shapes.

Parameters:
  • batch (xlstm_jax.dataset.batch.Batch | None) – An input to the model with which the shapes are inferred. If None, the exmp_batch is used.

  • alternative_eval_step (collections.abc.Callable | None) – An optional alternative eval step (used instead of self.eval_step). This is needed for a more complex eval step that internally computes multiple steps (e.g. infinite eval in lmeval_extended_evaluation.py).

Returns:

A dictionary of metrics with the same shape as the eval metrics.

Return type:

flax.core.FrozenDict
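The idea of starting from zero-filled metrics can be sketched in plain Python. This is only an analogy: the real method works with JAX array shapes, whereas the hypothetical helper zeros_like_metrics below mirrors a nested dictionary, and toy_eval_step and the "value"/"count" layout are assumptions.

```python
def zeros_like_metrics(structure):
    """Return a copy of a nested metrics structure with every leaf set to zero."""
    if isinstance(structure, dict):
        return {key: zeros_like_metrics(val) for key, val in structure.items()}
    if isinstance(structure, list):
        return [zeros_like_metrics(val) for val in structure]
    # Leaf: keep the type (int stays int, float stays float) but zero the value.
    return type(structure)(0)


def toy_eval_step(batch, metrics):
    """Hypothetical step that always receives a fully populated metrics dict."""
    return {
        "loss": {
            "value": metrics["loss"]["value"] + sum(batch),
            "count": metrics["loss"]["count"] + len(batch),
        }
    }


# Stands in for the metric structure inferred from the eval step.
metric_structure = {"loss": {"value": 0.5, "count": 3}}
init_metrics = zeros_like_metrics(metric_structure)
after_first_step = toy_eval_step([1.0, 2.0], init_metrics)
```

Because the step never sees None, every call receives inputs of the same structure, which is the property that avoids a second compilation in the jitted setting.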

aggregate_metrics(aggregated_metrics, eval_metrics)#

Aggregate metrics over multiple batches.

This is needed for “expensive” metrics that go beyond a scalar value and an accumulation count. These are then aggregated in CPU memory. The individual batch metrics might already be an actual aggregate for scalar values.

Parameters:
  • aggregated_metrics (xlstm_jax.common_types.HostMetrics) – Old aggregated metrics

  • eval_metrics (xlstm_jax.common_types.ImmutableMetrics) – Single batch metrics

Returns:

The aggregated metrics, including the new batch.

Return type:

xlstm_jax.common_types.HostMetrics

eval_model(data_loader, mode='test', epoch_idx=0)#

Evaluate the model on a dataset.

Parameters:
  • data_loader (collections.abc.Iterator) – Data loader of the dataset to evaluate on.

  • mode (str) – The evaluation mode, either ‘val’ or ‘test’.

  • epoch_idx (int) – Current epoch index.

Returns:

A dictionary of the evaluation metrics, averaged over data points in the dataset.

Return type:

xlstm_jax.common_types.HostMetrics
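How the pieces of this class could fit together in an evaluation loop is sketched below. The function toy_eval_model and its helper arguments are hypothetical, and the order in which the start and end hooks fire relative to the loop is an assumption; the sketch does not reproduce the library's eval_model.

```python
def toy_eval_model(data_loader, eval_step, aggregate, finalize, on_start, on_end):
    """Hypothetical evaluation loop built from user-supplied pieces."""
    on_start()
    aggregated = {}
    for batch in data_loader:
        batch_metrics = eval_step(batch)
        aggregated = aggregate(aggregated, batch_metrics)
    final_metrics = finalize(aggregated)
    on_end(final_metrics)
    return final_metrics


def toy_eval_step(batch):
    """Compute a summed loss and a count for one batch."""
    return {"loss": {"value": sum(batch), "count": len(batch)}}


def toy_aggregate(aggregated, batch_metrics):
    """Add the batch sums and counts onto the running totals."""
    merged = dict(aggregated)
    for name, entry in batch_metrics.items():
        old = merged.get(name, {"value": 0.0, "count": 0})
        merged[name] = {
            "value": old["value"] + entry["value"],
            "count": old["count"] + entry["count"],
        }
    return merged


def toy_finalize(aggregated):
    """Turn sums and counts into means."""
    return {name: e["value"] / e["count"] for name, e in aggregated.items()}


events = []
final_metrics = toy_eval_model(
    data_loader=iter([[1.0, 2.0], [3.0]]),
    eval_step=toy_eval_step,
    aggregate=toy_aggregate,
    finalize=toy_finalize,
    on_start=lambda: events.append("start"),
    on_end=lambda metrics: events.append(("end", metrics)),
)
```

Note that the mean is taken over data points, not over batches, so batches of different sizes are weighted correctly.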

finalize_metrics(aggregated_metrics)#

Calculate final metrics from aggregated_metrics (e.g. mean = sum / count).

Parameters:

aggregated_metrics (xlstm_jax.common_types.HostMetrics) – Aggregated metrics over the whole epoch

Returns:

Final metrics that are to be reported / logged.

Return type:

xlstm_jax.common_types.HostMetrics
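A minimal sketch of such a finalization step, assuming each scalar metric is stored as a running sum under "value" together with a "count", and that entries without a count are passed through unchanged. The helper name toy_finalize_metrics and this layout are illustrative assumptions, not the library's data format.

```python
def toy_finalize_metrics(aggregated_metrics):
    """Convert aggregated sums and counts into reportable values."""
    final_metrics = {}
    for name, entry in aggregated_metrics.items():
        if "count" in entry:
            # Scalar metric stored as running sum and count: report the mean.
            count = entry["count"]
            final_metrics[name] = entry["value"] / count if count > 0 else float("nan")
        else:
            # Non-scalar metric without a count: pass through unchanged.
            final_metrics[name] = entry["value"]
    return final_metrics


aggregated = {
    "loss": {"value": 12.0, "count": 4},
    "doc_ids": {"value": [0, 1, 2, 3]},
}
final_metrics = toy_finalize_metrics(aggregated)
```

Guarding against a zero count avoids a division error when an evaluation set turns out to be empty.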

on_extended_evaluation_start()#

Callback for extended evaluation start.

on_extended_evaluation_end(final_metrics)#

Callback for extended evaluation end with final_metrics attached.

Parameters:

final_metrics (xlstm_jax.common_types.HostMetrics)