lmeval_extended_evaluation
Attributes
- LOGGER
Classes
- LMEvalEvaluationConfig
- LMEvalEvaluation – LMEvalEvaluation Callback
Functions
- log_info(msg) – Logs an info message on the host device.
- fuse_document_results(results_dict) – Fuse log-likelihood results of capped sequences (max_length) to document log-likelihoods.
Module Contents
- lmeval_extended_evaluation.LOGGER
- lmeval_extended_evaluation.log_info(msg)
Logs an info message on the host device.
- Parameters:
msg (str) – Message to be logged.
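A minimal sketch of what logging "on the host device" could look like in a multi-host JAX run. It assumes that the phrase means only the main process (`jax.process_index() == 0`) emits the message, so a job with many hosts logs each message once; the actual implementation may use a different criterion.

```python
import logging

import jax

LOGGER = logging.getLogger(__name__)


def log_info(msg: str) -> None:
    """Log `msg` once per job instead of once per host (assumed behavior)."""
    # Assumption: the "host device" is the main process of a multi-host run.
    if jax.process_index() == 0:
        LOGGER.info(msg)
```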
- lmeval_extended_evaluation.fuse_document_results(results_dict)
Fuse log-likelihood results of capped sequences (max_length) into document log-likelihoods.
Aggregates results from (potentially) many batches of (potentially) many different tasks. All results must have matching document indices; log-likelihoods are aggregated over multiple sequence indices, weighted by their counts.
If the exact sequence is the result of greedy decoding (i.e. if all single-token accuracies of the non-masked parts are 1), “greedy” accuracies are aggregated as well.
- Returns:
Log-likelihoods and greedy (boolean) accuracies for documents, ordered by index.
- Parameters:
results_dict (dict[str, numpy.ndarray | tuple[numpy.ndarray, numpy.ndarray]]) – Dictionary of all results in concatenated form.
- Return type:
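The fusion step can be sketched in NumPy as a scatter-add keyed by document index. This is a hypothetical illustration, not the module's code: the separate arrays `doc_idx`, `chunk_loglik`, `chunk_count` and `chunk_greedy` stand in for whatever keys `results_dict` actually uses, and it assumes each chunk stores its mean per-token log-likelihood, so weighting by the token count recovers the summed log-likelihood.

```python
import numpy as np


def fuse_document_results_sketch(doc_idx, chunk_loglik, chunk_count, chunk_greedy):
    """Fuse per-chunk results into per-document results (hypothetical sketch).

    Assumes chunk_loglik holds the mean per-token log-likelihood of each chunk
    and chunk_count the number of non-masked tokens in that chunk.
    """
    doc_idx = np.asarray(doc_idx)
    # Sorted unique document indices and, for each chunk, its document's position.
    docs, inverse = np.unique(doc_idx, return_inverse=True)
    inverse = inverse.reshape(-1)
    # Count-weighted sum of chunk means gives the total document log-likelihood.
    weighted = np.asarray(chunk_loglik, dtype=np.float64) * np.asarray(chunk_count)
    doc_loglik = np.zeros(len(docs), dtype=np.float64)
    np.add.at(doc_loglik, inverse, weighted)
    # A document is greedy only if none of its chunks is non-greedy.
    non_greedy = np.zeros(len(docs), dtype=np.int64)
    np.add.at(non_greedy, inverse, (~np.asarray(chunk_greedy, dtype=bool)).astype(np.int64))
    return docs, doc_loglik, non_greedy == 0
```

For example, chunks belonging to documents `[3, 1, 3]` with mean log-likelihoods `[-0.5, -1.0, -0.25]`, token counts `[4, 2, 8]` and greedy flags `[True, False, True]` fuse into documents `[1, 3]` with log-likelihoods `[-2.0, -4.0]` and greedy flags `[False, True]`.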
- class lmeval_extended_evaluation.LMEvalEvaluationConfig
Bases:
xlstm_jax.trainer.callbacks.extended_evaluation.ExtendedEvaluationConfig
- limit_requests: int | None = None
Optionally limits the number of requests to a smaller number, for debugging purposes.
- bootstrap_iters: int = 1000
Number of bootstrap iterations used to compute the standard errors of the metrics. The LM Eval Harness default is 100000; the value is limited here for speed.
- create(trainer, data_module=None)
Create the LMEvalEvaluation callback from this config.
- Parameters:
trainer (Any) – Trainer
data_module (xlstm_jax.trainer.data_module.DataloaderModule | None) – DataloaderModule containing the train/val/test data; not used here.
- Returns:
An LMEvalEvaluation object.
- Return type:
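To illustrate what bootstrap_iters trades off, here is a minimal NumPy sketch of a bootstrap standard error for a mean metric such as accuracy. It shows the general technique only and is not the LM Eval Harness implementation: every iteration resamples the per-example scores with replacement, so runtime grows linearly with the number of iterations, which is why a value below 100000 speeds up evaluation.

```python
import numpy as np


def bootstrap_stderr(values, iters=1000, seed=0):
    """Bootstrap estimate of the standard error of the mean of `values`."""
    rng = np.random.default_rng(seed)
    values = np.asarray(values, dtype=np.float64)
    n = len(values)
    # Draw `iters` resamples of size n with replacement (one row per iteration).
    idx = rng.integers(0, n, size=(iters, n))
    resampled_means = values[idx].mean(axis=1)
    # The spread of the resampled means approximates the standard error.
    return float(resampled_means.std(ddof=1))
```

For 100 binary scores at 50% accuracy, the estimate lands close to the analytic value sqrt(0.25 / 100) = 0.05; more iterations only reduce the Monte Carlo noise of that estimate.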
- class lmeval_extended_evaluation.LMEvalEvaluation(config, trainer, data_module=None)
Bases:
xlstm_jax.trainer.callbacks.extended_evaluation.ExtendedEvaluation
LMEvalEvaluation Callback
- Parameters:
config (LMEvalEvaluationConfig)
trainer (xlstm_jax.trainer.llm.trainer.LLMTrainer)
data_module (xlstm_jax.trainer.data_module.DataloaderModule | None)
- context_length
- batch_size
- lm
- create_modified_exemplary_batch(exmp_batch)
Create an LLMIndexedBatch from an LLMBatch, to be used as an example batch for compilation.
- Parameters:
exmp_batch (xlstm_jax.dataset.batch.LLMBatch)
- Returns:
LLMIndexedBatch
- Return type:
- run_evaluate()
Runs the evaluation with the LM Eval Harness. It uses external datasets rather than the data module, and may be called from callback functions to obtain metrics during training.
- Returns:
Results from LMEval evaluation.
- Return type:
xlstm_jax.common_types.HostMetrics
- static get_metric_postprocess_fn()
Get the function used to post-process metrics on the host.
It is passed to the logger and adds perplexity to the metrics.
- Returns:
The postprocess metric function.
- Return type:
collections.abc.Callable[[xlstm_jax.common_types.HostMetrics], xlstm_jax.common_types.HostMetrics]
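A minimal sketch of such a post-processing function, assuming (hypothetically) that host metrics are a flat dict of floats and that keys ending in "loss" hold the mean per-token negative log-likelihood. Under that assumption perplexity is exp(loss); the real metric keys in xlstm_jax may differ.

```python
import math
from collections.abc import Callable

HostMetrics = dict[str, float]  # simplified stand-in for xlstm_jax.common_types.HostMetrics


def get_metric_postprocess_fn() -> Callable[[HostMetrics], HostMetrics]:
    """Return a host-side function that adds perplexity metrics (sketch)."""

    def postprocess(metrics: HostMetrics) -> HostMetrics:
        out = dict(metrics)  # do not modify the caller's dictionary
        for key, value in metrics.items():
            # Assumption: "*loss" keys hold the mean per-token negative log-likelihood.
            if key.endswith("loss"):
                out[key[: -len("loss")] + "perplexity"] = math.exp(value)
        return out

    return postprocess
```

Returning a closure from a static method keeps the function free of trainer state, so it can be handed to the logger and run purely on host values.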
- create_jitted_functions()
Create a jitted version of the evaluation function.
- init_eval_metrics(batch=None)
Overrides the parent init_eval_metrics, potentially for infinite evaluation. In that case, metrics are partly aggregated one level below (along the sequence) and fully aggregated (across batches) within eval_model.
- Parameters:
batch (xlstm_jax.dataset.batch.LLMIndexedBatch | None) – An input to the model from which the shapes are inferred. If None, the exmp_batch is used.
- Returns:
A dictionary of metrics with the same shape as the eval metrics.
- Return type:
flax.core.FrozenDict | dict
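Inferring metric shapes from an example batch can be sketched with `jax.eval_shape`, which traces a function abstractly and returns `jax.ShapeDtypeStruct` leaves without running the computation. The metric function `toy_metrics_fn` below is hypothetical; it only stands in for the real evaluation metrics.

```python
import jax
import jax.numpy as jnp


def toy_metrics_fn(batch):
    """Hypothetical per-batch metric function: per-sequence sum and token count."""
    return {
        "loss_sum": jnp.sum(batch, axis=-1),
        "count": jnp.full(batch.shape[:1], batch.shape[1], dtype=jnp.int32),
    }


def init_eval_metrics(exmp_batch):
    # Trace the metric function abstractly: only shapes and dtypes are computed.
    shapes = jax.eval_shape(toy_metrics_fn, exmp_batch)
    # Zero-initialize a buffer matching each metric's shape and dtype.
    return jax.tree_util.tree_map(lambda s: jnp.zeros(s.shape, s.dtype), shapes)
```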
- aggregate_metrics(aggregated_metrics, eval_metrics)
Aggregate metrics over multiple batches. This adapts the parent class method to ignore the “step_metrics” entry in the metrics dictionary. The “step_metrics” are the metrics of a single recurrent step, which are donated to a future evaluation step.
- Parameters:
aggregated_metrics (xlstm_jax.common_types.HostMetrics) – Old aggregated metrics
eval_metrics (xlstm_jax.common_types.HostMetrics) – Single batch metrics
- Returns:
aggregated_metrics including the new batch
- Return type:
xlstm_jax.common_types.HostMetrics
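A minimal sketch of this aggregation, assuming (hypothetically) that each metric is a `(sum, count)` pair of host values. The only behavior taken from the description is that the "step_metrics" entry is skipped rather than summed across batches.

```python
def aggregate_metrics(aggregated_metrics, eval_metrics):
    """Accumulate (sum, count) metric pairs across batches (hypothetical sketch)."""
    out = dict(aggregated_metrics)
    for key, value in eval_metrics.items():
        if key == "step_metrics":
            # Per-step recurrent state for a future step; never summed across batches.
            continue
        value_sum, count = value
        prev_sum, prev_count = out.get(key, (0.0, 0))
        out[key] = (prev_sum + value_sum, prev_count + count)
    return out
```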
- create_recurrent_evaluation_step_function(chunk_size, exmp_batch=None, cache_init_fn=None)
Create and return a recurrent function for the evaluation step (see also llm/trainer.py).
Compared to create_evaluation_step_function, this evaluation supports much longer sequences by chunking the input and running the model recurrently over the chunks. This is useful for evaluation on long documents. It is enabled by keeping a cache that is forwarded between evaluation steps.
Note: do not jit the returned function if you want to support arbitrary input shapes. This function jits the recurrent function for a single chunk and wraps it in a Python loop to handle sequences of arbitrary length, so no outer jit is needed.
Note: this function is explicitly meant for recurrent models like xLSTM. Using this function on a non-recurrent model will lead to unexpected, incorrect results.
- Parameters:
chunk_size (int) – Size of the chunks to split the input into. The slices are performed over the sequence length.
exmp_batch (xlstm_jax.dataset.batch.LLMIndexedBatch | None) – An example batch to determine the shape of the cache. Defaults to None, in which case the example batch from the trainer is used.
cache_init_fn (collections.abc.Callable[[xlstm_jax.common_types.PyTree], xlstm_jax.common_types.PyTree] | None) – A function to initialize the cache. If not provided, the cache is initialized with zeros. The function should take the shape dtype struct of the cache as input and return the initialized cache.
- Returns:
The evaluation step function with support for arbitrary length sequences.
- Return type:
collections.abc.Callable[[xlstm_jax.common_types.TrainState, xlstm_jax.dataset.batch.LLMBatch, xlstm_jax.common_types.ImmutableMetrics | None], xlstm_jax.common_types.ImmutableMetrics]
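The chunked pattern described above (jit the single-chunk step, loop over chunks in Python, forward a cache) can be sketched with a toy recurrent "model" that computes a running sum. Everything here is hypothetical and stands in for xLSTM and the real step function; it only shows why chunked evaluation matches a full-sequence pass when state is carried, and how a `cache_init_fn` receives a `jax.ShapeDtypeStruct`.

```python
import jax
import jax.numpy as jnp


def toy_chunk_step(cache, chunk):
    """Toy recurrent "model" (hypothetical): running sum along the sequence.

    cache: (batch,) state carried across chunks; chunk: (batch, chunk_size).
    """
    outputs = cache[:, None] + jnp.cumsum(chunk, axis=1)
    return outputs, outputs[:, -1]  # new cache = last running sum


def make_recurrent_eval_fn(chunk_size, cache_init_fn=None):
    step = jax.jit(toy_chunk_step)  # compiled per chunk shape, not per sequence length

    def eval_fn(inputs):
        batch, seq_len = inputs.shape
        cache_struct = jax.ShapeDtypeStruct((batch,), inputs.dtype)
        if cache_init_fn is None:
            cache = jnp.zeros(cache_struct.shape, cache_struct.dtype)
        else:
            cache = cache_init_fn(cache_struct)
        outputs = []
        # Python loop (not jitted), so any seq_len works. A shorter final chunk
        # triggers one extra compilation; padding would avoid that.
        for start in range(0, seq_len, chunk_size):
            out, cache = step(cache, inputs[:, start : start + chunk_size])
            outputs.append(out)
        return jnp.concatenate(outputs, axis=1)

    return eval_fn
```

For a model without carried state, each chunk would only see its own tokens and the concatenated outputs would differ from a full-sequence pass, which is the failure mode the note on non-recurrent models warns about.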
- finalize_metrics(aggregated_metrics)
Calculate the final metrics from aggregated_metrics (i.e. mean = sum / count).
- Parameters:
aggregated_metrics (xlstm_jax.common_types.HostMetrics) – Aggregated metrics over the whole epoch
- Returns:
Final metrics that are to be reported / logged.
- Return type:
xlstm_jax.common_types.HostMetrics
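A minimal sketch of the mean = sum / count step, again assuming hypothetical `(sum, count)` pairs per metric. Returning NaN for an empty count is a design choice of this sketch, not documented behavior.

```python
def finalize_metrics(aggregated_metrics):
    """Turn accumulated (sum, count) pairs into means (hypothetical sketch)."""
    return {
        # Guard against metrics that never received a counted element.
        key: value_sum / count if count > 0 else float("nan")
        for key, (value_sum, count) in aggregated_metrics.items()
    }
```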
- eval_function(params, apply_fn, batch, rng=None, mutable_variables=None)
Passes the batch through the model and computes extended metrics.
- Parameters:
params (Any) – Model parameters.
apply_fn (Any) – Model apply function.
batch (xlstm_jax.dataset.batch.LLMIndexedBatch) – LLMIndexedBatch that is passed through the model.
rng (jax.Array | None) – RNG for potential dropout.
mutable_variables (dict[str, Any] | None) – Mutable variables for the evaluation step function, e.g. the cache (recurrent state).
- Returns:
Tuple of metrics and mutable variables.
- Return type:
tuple[xlstm_jax.common_types.Metrics, xlstm_jax.common_types.PyTree]
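The kind of metrics that feed fuse_document_results can be sketched from model logits in JAX: a masked sum of per-token log-likelihoods, the number of scored tokens, and a greedy flag that is true only if every scored token is the argmax prediction. The function name and the `mask` convention are hypothetical; this is not the module's eval_function.

```python
import jax
import jax.numpy as jnp


def loglik_and_greedy(logits, targets, mask):
    """Per-sequence log-likelihood and greedy flag from logits (hypothetical sketch).

    logits: (batch, seq, vocab); targets: (batch, seq) int; mask: (batch, seq) bool,
    True where a token should be scored.
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Log-probability the model assigns to each target token.
    token_ll = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    loglik = jnp.sum(jnp.where(mask, token_ll, 0.0), axis=-1)
    # Greedy: every scored token is also the argmax prediction.
    token_correct = jnp.argmax(logits, axis=-1) == targets
    greedy = jnp.all(jnp.where(mask, token_correct, True), axis=-1)
    count = jnp.sum(mask, axis=-1)
    return loglik, greedy, count
```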
- on_filtered_validation_epoch_start(epoch_idx, step_idx)
Runs evaluation on filtered validation epochs / steps.