xlstm_jax.trainer.callbacks.extended_evaluation#
Attributes#
LOGGER
Classes#
EvalState | EvalState with additional mutable variables and RNG.
ExtendedEvaluationConfig | Configuration for additional Evaluations callback.
ExtendedEvaluation | Callback that runs additional evaluations.
Functions#
device_metrics_aggregation(trainer, metrics) | Aggregates metrics beyond a single scalar value and a count.
Module Contents#
- xlstm_jax.trainer.callbacks.extended_evaluation.LOGGER#
- class xlstm_jax.trainer.callbacks.extended_evaluation.EvalState#
Bases:
flax.struct.PyTreeNode
EvalState with additional mutable variables and RNG.
- Parameters:
step – Counter starts at 0 and is incremented by every call to .apply_gradients().
apply_fn – Usually set to model.apply(). Kept in this dataclass for convenience to have a shorter params list for the train_step() function in your training loop.
params – The parameters to be updated by tx and used by apply_fn.
- apply_fn: collections.abc.Callable#
- mutable_variables: Any = None#
- classmethod from_train_state(*, train_state)#
- Parameters:
train_state (xlstm_jax.common_types.TrainState)
- Return type:
- classmethod create(*, apply_fn, params, **kwargs)#
- Parameters:
apply_fn (collections.abc.Callable)
params (flax.core.FrozenDict[str, Any])
- Return type:
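To illustrate the from_train_state pattern, here is a minimal sketch that uses plain dataclasses as stand-ins for the Flax structures. ToyTrainState, ToyEvalState, the rng and opt_state fields, and the idea that the evaluation state simply drops the optimizer fields are all assumptions made for illustration, not the library's definitions.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class ToyTrainState:
    """Hypothetical stand-in for a train state; not the xlstm_jax class."""

    step: int
    apply_fn: Callable
    params: Any
    mutable_variables: Any = None
    rng: Any = None
    opt_state: Any = None  # optimizer state, assumed not needed for evaluation


@dataclass(frozen=True)
class ToyEvalState:
    """Hypothetical evaluation state without optimizer fields."""

    step: int
    apply_fn: Callable
    params: Any
    mutable_variables: Any = None
    rng: Any = None

    @classmethod
    def from_train_state(cls, *, train_state: ToyTrainState) -> "ToyEvalState":
        # Copy only the fields that evaluation is assumed to need.
        return cls(
            step=train_state.step,
            apply_fn=train_state.apply_fn,
            params=train_state.params,
            mutable_variables=train_state.mutable_variables,
            rng=train_state.rng,
        )


train_state = ToyTrainState(
    step=3, apply_fn=lambda p, x: p * x, params=2.0, opt_state={"m": 0.0}
)
eval_state = ToyEvalState.from_train_state(train_state=train_state)
```

The keyword-only argument mirrors the documented signature from_train_state(*, train_state).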
- class xlstm_jax.trainer.callbacks.extended_evaluation.ExtendedEvaluationConfig#
Bases:
xlstm_jax.trainer.callbacks.callback.CallbackConfig
Configuration for additional Evaluations callback.
- create(trainer, data_module=None)#
Creates an Evaluation callback.
- Parameters:
trainer (Any) – Trainer object.
data_module (xlstm_jax.trainer.data_module.DataloaderModule | None) – Data module object.
- Return type:
- xlstm_jax.trainer.callbacks.extended_evaluation.device_metrics_aggregation(trainer, metrics)#
Aggregates metrics beyond a single scalar value and a count.
Also includes single_noreduce metrics by concatenation.
- Parameters:
trainer (Any) – Trainer (for aggregation axes)
metrics (xlstm_jax.common_types.Metrics) – the sharded metrics
- Returns:
The reduced/gathered metrics.
- Return type:
xlstm_jax.common_types.Metrics
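The difference between reducing and concatenating can be sketched with plain Python lists. The real function works on sharded arrays along the trainer's aggregation axes. The function name toy_device_aggregation and the metric layout used here (a "value"/"count" pair for scalar metrics, a "values" list for non-reduced ones) are hypothetical.

```python
def toy_device_aggregation(per_device_metrics):
    """Combine a list of per-device metric dicts into one dict.

    Assumed (hypothetical) layout: scalar metrics are
    {"value": float, "count": int}; non-reduced metrics are {"values": list}.
    """
    combined = {}
    for device_metrics in per_device_metrics:
        for name, metric in device_metrics.items():
            if "values" in metric:
                # Non-reduced metric: keep every entry by concatenating.
                target = combined.setdefault(name, {"values": []})
                target["values"].extend(metric["values"])
            else:
                # Scalar metric: sum value and count across devices.
                target = combined.setdefault(name, {"value": 0.0, "count": 0})
                target["value"] += metric["value"]
                target["count"] += metric["count"]
    return combined


per_device = [
    {"loss": {"value": 6.0, "count": 4}, "token_ids": {"values": [1, 2]}},
    {"loss": {"value": 2.0, "count": 4}, "token_ids": {"values": [3, 4]}},
]
reduced = toy_device_aggregation(per_device)
```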
- class xlstm_jax.trainer.callbacks.extended_evaluation.ExtendedEvaluation(config, trainer, data_module=None)#
Bases:
xlstm_jax.trainer.callbacks.callback.Callback
Callback that runs additional evaluations.
- Parameters:
config (ExtendedEvaluationConfig) – The configuration for the Evaluation callback.
trainer (Any) – Trainer
data_module (xlstm_jax.trainer.data_module.DataloaderModule | None) –
DataloaderModule, containing train/val/test data loaders.
- config#
- trainer#
- exmp_batch#
- eval_step = None#
- _eval_metric_shapes = None#
- create_modified_exemplary_batch(exmp_batch)#
Create a modified exemplary batch for evaluation. This is useful for passing additional information or metadata to the batch for post-processing.
- Parameters:
exmp_batch (xlstm_jax.dataset.batch.Batch) – “Original” training exemplary batch
- Returns:
Modified exemplary batch for evaluation, might be the unmodified original.
- Return type:
- eval_function(params, apply_fn, batch, rng)#
The extended evaluation function calculating metrics.
This function needs to be overridden by a subclass.
- Parameters:
params (Any) – The model parameters.
apply_fn (Any) – The apply function of the state.
batch (xlstm_jax.dataset.batch.Batch) – The current batch.
rng (jax.Array) – The random number generator.
- Returns:
A tuple of metrics and mutable variables.
- Return type:
tuple[xlstm_jax.common_types.Metrics, xlstm_jax.common_types.PyTree]
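A subclass override could look roughly like the following sketch. It uses a stand-in base class rather than the real callback, a plain dict instead of a Batch, and a toy apply function. The class names, the batch keys and the "value"/"count" metric layout are assumptions made for illustration.

```python
class ToyExtendedEvaluation:
    """Hypothetical stand-in for the callback base class."""

    def eval_function(self, params, apply_fn, batch, rng):
        raise NotImplementedError("Subclasses must implement eval_function.")


class SquaredErrorEvaluation(ToyExtendedEvaluation):
    """Example subclass computing a summed squared error per batch."""

    def eval_function(self, params, apply_fn, batch, rng):
        predictions = apply_fn(params, batch["inputs"])
        errors = [(p - t) ** 2 for p, t in zip(predictions, batch["targets"])]
        # Assumed metric layout: a summed value plus the number of data points.
        metrics = {"squared_error": {"value": sum(errors), "count": len(errors)}}
        mutable_variables = None  # this toy model has no mutable state
        return metrics, mutable_variables


def linear_apply(params, inputs):
    # Toy model: scale each input by a single weight.
    return [params["w"] * x for x in inputs]


batch = {"inputs": [1.0, 2.0], "targets": [2.0, 5.0]}
metrics, mutable_variables = SquaredErrorEvaluation().eval_function(
    {"w": 2.0}, linear_apply, batch, rng=None
)
```

Returning the sum together with a count, rather than a per-batch mean, keeps the later averaging over the whole dataset exact when batches differ in size.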
- create_jitted_functions()#
Create jitted version of the evaluation function.
- create_evaluation_step_function()#
Create and return a function for the extended evaluation step.
The function takes as input the training state and a batch from the val/test loader. The function is expected to return a dictionary of logging metrics and a new train state.
- Returns:
Step function calculating metrics for one batch.
- Return type:
collections.abc.Callable[[xlstm_jax.common_types.TrainState, xlstm_jax.dataset.batch.Batch, xlstm_jax.common_types.ImmutableMetrics | None], xlstm_jax.common_types.ImmutableMetrics]
- init_eval_metrics(batch=None, alternative_eval_step=None)#
Initialize the evaluation metrics with zeros.
We infer the evaluation metric shape from the eval_step function. This is done to prevent a double-compilation of the eval_step function, where the first step has to be done with metrics None, and the next one with the metrics shape.
- Parameters:
batch (xlstm_jax.dataset.batch.Batch | None) – An input to the model with which the shapes are inferred. If None, the exmp_batch is used.
alternative_eval_step (collections.abc.Callable | None) – An optional alternative eval step (not self.eval_step). This is needed for a more complex eval step that internally computes multiple steps (e.g. infinite eval in lmeval_extended_evaluation.py).
- Returns:
A dictionary of metrics with the same shape as the eval metrics.
- Return type:
flax.core.FrozenDict
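The idea of a zero-filled template can be shown with a plain-Python analogue: given the structure of one step's metrics, build a copy with every number set to zero, so that the step always receives metrics of the same structure. The structure here comes from a hand-written sample, whereas the library infers it from the eval_step function. The helper name and the metric layout are hypothetical.

```python
def zeros_like_structure(tree):
    """Return a copy of a nested dict/list structure with every number zeroed."""
    if isinstance(tree, dict):
        return {key: zeros_like_structure(value) for key, value in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(zeros_like_structure(value) for value in tree)
    # Leaf: keep the numeric type (int stays int, float stays float).
    return type(tree)(0)


# Hand-written sample of what one evaluation step might return (assumed layout).
sample_step_output = {
    "loss": {"value": 3.5, "count": 8},
    "per_position_loss": {"value": [0.5, 1.0, 2.0], "count": [8, 8, 8]},
}
initial_metrics = zeros_like_structure(sample_step_output)
```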
- aggregate_metrics(aggregated_metrics, eval_metrics)#
Aggregate metrics over multiple batches.
This is needed for “expensive” metrics that go beyond a scalar value and an accumulation count. These are then aggregated in CPU memory. The individual batch metrics might already be an actual aggregate for scalar values.
- Parameters:
aggregated_metrics (xlstm_jax.common_types.HostMetrics) – Old aggregated metrics
eval_metrics (xlstm_jax.common_types.ImmutableMetrics) – Single batch metrics
- Returns:
aggregated_metrics including the new batch
- Return type:
xlstm_jax.common_types.HostMetrics
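One possible reading of this behaviour is sketched below: non-scalar metrics from each batch are collected on the host, while scalar metrics that already arrive as running totals are simply replaced by the latest value. The function name, the metric names and the "values" versus "value"/"count" layout are assumptions, not the library's data format.

```python
def toy_aggregate_metrics(aggregated_metrics, eval_metrics):
    """Fold one batch's metrics into host-side aggregates.

    Assumed (hypothetical) layout: scalar metrics arrive as running
    {"value", "count"} totals; other metrics arrive as {"values": list}.
    """
    updated = dict(aggregated_metrics)
    for name, metric in eval_metrics.items():
        if "values" in metric:
            # Per-batch, non-scalar metric: collect it on the host.
            previous = updated.get(name, {"values": []})["values"]
            updated[name] = {"values": previous + list(metric["values"])}
        else:
            # Scalar metric that is already a running total: keep the latest.
            updated[name] = dict(metric)
    return updated


aggregated = {}
batches = [
    {"loss": {"value": 4.0, "count": 2}, "max_logit": {"values": [0.9, 0.7]}},
    {"loss": {"value": 7.0, "count": 4}, "max_logit": {"values": [0.8, 0.6]}},
]
for batch_metrics in batches:
    aggregated = toy_aggregate_metrics(aggregated, batch_metrics)
```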
- eval_model(data_loader, mode='test', epoch_idx=0)#
Evaluate the model on a dataset.
- Parameters:
data_loader (collections.abc.Iterator) – Data loader of the dataset to evaluate on.
mode (str) – Evaluation mode, either ‘val’ or ‘test’.
epoch_idx (int) – Current epoch index.
- Returns:
A dictionary of the evaluation metrics, averaged over data points in the dataset.
- Return type:
xlstm_jax.common_types.HostMetrics
- finalize_metrics(aggregated_metrics)#
Calculate final metrics from aggregated_metrics (e.g. mean = sum / count).
- Parameters:
aggregated_metrics (xlstm_jax.common_types.HostMetrics) – Aggregated metrics over the whole epoch
- Returns:
Final metrics that are to be reported / logged.
- Return type:
xlstm_jax.common_types.HostMetrics
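The mean = sum / count step can be sketched as follows, assuming each aggregated metric is stored as a dict with "value" (the running sum) and "count" keys. That layout and the function name are hypothetical.

```python
def toy_finalize_metrics(aggregated_metrics):
    """Turn aggregated sums and counts into reportable means."""
    final = {}
    for name, metric in aggregated_metrics.items():
        count = metric["count"]
        # Guard against an empty evaluation to avoid dividing by zero.
        final[name] = metric["value"] / count if count else float("nan")
    return final


final_metrics = toy_finalize_metrics(
    {
        "loss": {"value": 7.0, "count": 4},
        "accuracy": {"value": 3, "count": 4},
    }
)
```

Dividing only once at the end weights every data point equally, which would not hold if per-batch means were averaged over batches of different sizes.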
- on_extended_evaluation_start()#
Callback for extended evaluation start.
- on_extended_evaluation_end(final_metrics)#
Callback for extended evaluation end with final_metrics attached.
- Parameters:
final_metrics (xlstm_jax.common_types.HostMetrics)