lmeval_extended_evaluation

Attributes

LOGGER

Classes

LMEvalEvaluationConfig

LMEvalEvaluation

Callback that runs LM Evaluation Harness evaluations.

Functions

log_info(msg)

Logs an info message on the host device.

fuse_document_results(results_dict)

Fuses log-likelihood results of capped sequences (max_length) into document log-likelihoods.

Module Contents

lmeval_extended_evaluation.LOGGER

Module-level logger.

lmeval_extended_evaluation.log_info(msg)

Logs an info message on the host device.

Parameters:

msg (str) – Message to be logged.
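In multi-process JAX training, "log on the host" usually means emitting the message from a single process only, so that it does not appear once per host. The sketch below assumes that process 0 is the logging host; the `process_index` argument is a hypothetical stand-in for `jax.process_index()`, so the sketch runs without JAX. It is not the module's actual implementation.

```python
import logging

LOGGER = logging.getLogger(__name__)


def log_info(msg: str, process_index: int = 0) -> bool:
    """Log `msg` at INFO level only on the main process (sketch).

    `process_index` is a hypothetical stand-in for `jax.process_index()`.
    Returns True if the message was emitted on this process.
    """
    if process_index == 0:
        LOGGER.info(msg)
        return True
    # All other processes stay silent to avoid duplicated log lines.
    return False
```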

lmeval_extended_evaluation.fuse_document_results(results_dict)

Fuses log-likelihood results of capped sequences (max_length) into document log-likelihoods.

Aggregates results from potentially many batches of potentially many different tasks. All results must carry matching document indices. The log-likelihoods of all sequence indices that belong to the same document are aggregated, weighted by their counts.

If the exact sequence is the result of greedy decoding (i.e., if all single-token accuracies of the non-masked parts are 1), a "greedy" accuracy is aggregated as well.

Returns:

Log-likelihood and greedy (boolean) accuracy for each document, ordered by document index.

Parameters:

results_dict (dict[str, numpy.ndarray | tuple[numpy.ndarray, numpy.ndarray]]) – Dictionary of all results in concatenated form.

Return type:

list[tuple[float, bool]]
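The fusion step can be illustrated with a small pure-Python sketch. The input layout is a hypothetical assumption, not the real `results_dict` format: entry i describes one capped sequence of document `document_idx[i]` and holds its mean token log-likelihood, its number of non-masked tokens, and its mean single-token accuracy. "Weighted by their counts" is interpreted here as multiplying each per-sequence mean by its token count, and a document counts as greedy only if every one of its sequences has accuracy 1.

```python
import math
from collections import defaultdict


def fuse_document_results_sketch(
    document_idx: list[int],
    loglikelihood_mean: list[float],
    token_count: list[int],
    accuracy_mean: list[float],
) -> list[tuple[float, bool]]:
    """Fuse per-sequence results into per-document (log-likelihood, greedy) pairs.

    Hypothetical input layout; see the lead-in for the assumptions.
    """
    total_ll: dict[int, float] = defaultdict(float)
    greedy: dict[int, bool] = defaultdict(lambda: True)
    for doc, ll, count, acc in zip(document_idx, loglikelihood_mean, token_count, accuracy_mean):
        # Undo the per-sequence mean: mean * count is the summed log-likelihood.
        total_ll[doc] += ll * count
        # The document is greedy only if every non-masked token was predicted correctly.
        greedy[doc] = greedy[doc] and math.isclose(acc, 1.0)
    # One entry per document, ordered by document index.
    return [(total_ll[doc], greedy[doc]) for doc in sorted(total_ll)]
```

For example, document 1 split into two capped sequences with (mean, count) of (-0.5, 4) and (-1.0, 2) fuses to a log-likelihood of -4.0.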

class lmeval_extended_evaluation.LMEvalEvaluationConfig

Bases: xlstm_jax.trainer.callbacks.extended_evaluation.ExtendedEvaluationConfig

tokenizer_path: str

Path to the tokenizer.

evaluation_tasks: list[str]

List of evaluation tasks from the LM Evaluation Harness.

cache_requests: bool = True

Whether to cache requests.

limit_requests: int | None = None

Maximum number of requests to evaluate, for debugging purposes. If None, requests are not limited.

write_out: bool = False

Whether to write out results.

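use_infinite_eval: bool = True

Whether to use infinite evaluation, i.e., running the model recurrently over chunks of the sequence (see create_recurrent_evaluation_step_function).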

infinite_eval_chunksize: int = 64

Chunk size (along the sequence dimension) used by infinite evaluation.

context_length: int | None = None

Overrides the context_length of the model.

batch_size: int | None = None

Overrides the batch_size of the trainer.

worker_buffer_size: int = 1

Worker buffer size for the Grain data loader.

worker_count: int = 0

Number of workers for Grain data loading.

debug: bool = False

Scale outputs such that metrics can be computed for a random test model.

system_instruction: str | None = None

Additional system instruction.

num_fewshot: int | None = None

Number of in-context examples for few-shot evaluation.

bootstrap_iters: int = 1000

Number of bootstrap iterations for computing the standard errors of metrics. The LM Evaluation Harness default is 100000; a lower value is used here for speed.

apply_chat_template: bool | str = False

Whether to apply the LM Evaluation Harness chat template; a string selects a custom template.
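The cost of bootstrap_iters grows linearly with the number of iterations, which is why the default above is lowered. The following generic sketch shows what a bootstrap standard error of a mean metric computes; it is not the LM Evaluation Harness implementation, and the function name is hypothetical.

```python
import random
import statistics


def bootstrap_stderr(values: list[float], bootstrap_iters: int = 1000, seed: int = 0) -> float:
    """Estimate the standard error of the mean of `values` by bootstrap resampling.

    Generic sketch; requires bootstrap_iters >= 2.
    """
    rng = random.Random(seed)
    n = len(values)
    means = []
    for _ in range(bootstrap_iters):
        # Resample n values with replacement and record the resampled mean.
        sample = rng.choices(values, k=n)
        means.append(sum(sample) / n)
    # The spread of the resampled means estimates the standard error of the mean.
    return statistics.stdev(means)
```

Fewer iterations make the estimate noisier but cheaper; for per-example accuracies of 0 and 1 with n = 100, the estimate lands close to sqrt(0.25 / 100) = 0.05 already with 1000 iterations.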

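create(trainer, data_module=None)

Creates the LMEvalEvaluation callback from this config.

Parameters:
  • trainer – Trainer that the callback is attached to.

  • data_module – Optional data module. Defaults to None.

Returns: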

LMEvalEvaluation object

Return type:

LMEvalEvaluation

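class lmeval_extended_evaluation.LMEvalEvaluation(config, trainer, data_module=None)

Bases: xlstm_jax.trainer.callbacks.extended_evaluation.ExtendedEvaluation

Callback that runs LM Evaluation Harness evaluations.

Parameters:
  • config (LMEvalEvaluationConfig) – Configuration of the evaluation.

  • trainer – Trainer that the callback is attached to.

  • data_module – Optional data module. Defaults to None.

context_length

batch_size

lm

create_modified_exemplary_batch(exmp_batch)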

Creates an example LLMIndexedBatch from an LLMBatch, used for compiling the evaluation functions.

Parameters:

exmp_batch (xlstm_jax.dataset.batch.LLMBatch)

Returns:

LLMIndexedBatch

Return type:

xlstm_jax.dataset.batch.LLMIndexedBatch

run_evaluate()

Runs the evaluation with the LM Evaluation Harness, which uses its own external datasets. It may be called from callback functions to obtain metrics during training.

Returns:

Results of the LM Evaluation Harness evaluation.

Return type:

xlstm_jax.common_types.HostMetrics

static get_metric_postprocess_fn()

Returns a function that post-processes metrics on the host.

The function is passed to the logger and adds perplexity to the metrics.

Returns:

The postprocess metric function.

Return type:

collections.abc.Callable[[xlstm_jax.common_types.HostMetrics], xlstm_jax.common_types.HostMetrics]
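Perplexity is the exponential of the mean negative log-likelihood, so a post-processing function only needs the loss. The sketch below assumes that host metrics are a flat mapping from name to float and that loss metrics end in "loss"; both the naming convention and the derived key names are hypothetical, not the callback's actual keys.

```python
import math
from collections.abc import Callable

HostMetrics = dict[str, float]  # assumption: flat mapping from metric name to value


def get_metric_postprocess_fn() -> Callable[[HostMetrics], HostMetrics]:
    """Return a function that adds a perplexity entry for every loss-like metric."""

    def _postprocess(metrics: HostMetrics) -> HostMetrics:
        out = dict(metrics)  # copy so the caller's dict is not mutated
        for key, value in metrics.items():
            if key.endswith("loss"):
                # Perplexity = exp(mean negative log-likelihood).
                out[key[: -len("loss")] + "perplexity"] = math.exp(value)
        return out

    return _postprocess
```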

create_jitted_functions()

Creates a jitted version of the evaluation function.

init_eval_metrics(batch=None)

Overrides the parent init_eval_metrics to support infinite evaluation. In that case, metrics are partially aggregated one level below (along the sequence) and fully aggregated (across batches) within eval_model.

Parameters:

batch (xlstm_jax.dataset.batch.LLMIndexedBatch | None) – An input to the model with which the shapes are inferred. If None, the exmp_batch is used.

Returns:

A dictionary of metrics with the same shape as the eval metrics.

Return type:

flax.core.FrozenDict | dict

aggregate_metrics(aggregated_metrics, eval_metrics)

Aggregates metrics over multiple batches. This adapts the parent class method to ignore the "step_metrics" entry in the metrics dictionary. The "step_metrics" are the metrics of a single recurrent step, which are donated to a future evaluation step.

Parameters:
  • aggregated_metrics (xlstm_jax.common_types.HostMetrics) – Old aggregated metrics

  • eval_metrics (xlstm_jax.common_types.HostMetrics) – Single batch metrics

Returns:

aggregated_metrics including the new batch

Return type:

xlstm_jax.common_types.HostMetrics
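A minimal sketch of this aggregation rule, assuming each metric is stored as a (summed value, count) pair; the real HostMetrics layout may differ. Entries under "step_metrics" are skipped rather than summed.

```python
from typing import Any

Metric = tuple[float, int]  # assumption: (summed value, count)


def aggregate_metrics(
    aggregated_metrics: dict[str, Metric], eval_metrics: dict[str, Any]
) -> dict[str, Metric]:
    """Add one batch's metrics to the running aggregate, skipping "step_metrics"."""
    result = dict(aggregated_metrics)
    for key, value in eval_metrics.items():
        if key == "step_metrics":
            # Per-step recurrent metrics are carried to the next step, not aggregated.
            continue
        value_sum, count = value
        old_sum, old_count = result.get(key, (0.0, 0))
        result[key] = (old_sum + value_sum, old_count + count)
    return result
```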

create_recurrent_evaluation_step_function(chunk_size, exmp_batch=None, cache_init_fn=None)

Creates and returns a recurrent function for the evaluation step (see also llm/trainer.py).

Compared to create_evaluation_step_function, this function supports much longer sequences by splitting the input into chunks and running the model recurrently over them, which is useful for evaluating long documents. This is enabled by a cache that is forwarded between evaluation steps.

Note: do not jit this function if you want to support arbitrary input shapes. This function jits the recurrent function for a single chunk and wraps it in a Python loop to handle sequences of arbitrary length. Thus, no outer jit is needed.

Note: this function is explicitly meant for recurrent models like xLSTM. Using this function on a non-recurrent model will lead to unexpected, incorrect results.

Parameters:
  • chunk_size (int) – Size of the chunks that the input is split into. Slicing is performed along the sequence dimension.

  • exmp_batch (xlstm_jax.dataset.batch.LLMIndexedBatch | None) – An example batch to determine the shape of the cache. Defaults to None, in which case the example batch from the trainer is used.

  • cache_init_fn (collections.abc.Callable[[xlstm_jax.common_types.PyTree], xlstm_jax.common_types.PyTree] | None) – A function to initialize the cache. If not provided, the cache is initialized with zeros. The function should take the shape dtype struct of the cache as input and return the initialized cache.

Returns:

The evaluation step function with support for arbitrary length sequences.

Return type:

collections.abc.Callable[[xlstm_jax.common_types.TrainState, xlstm_jax.dataset.batch.LLMBatch, xlstm_jax.common_types.ImmutableMetrics | None], xlstm_jax.common_types.ImmutableMetrics]
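The chunking scheme can be illustrated with a pure-Python toy, assuming a linear recurrence h_t = decay * h_{t-1} + x_t in place of an xLSTM and a float in place of the cache PyTree. In the real function the single-chunk step is jitted with JAX; here it is a plain function, and `cache_init_fn` takes no arguments, which is a simplification. The point is that a Python loop carrying the cache across chunks reproduces a full-sequence pass exactly, which only holds for recurrent models.

```python
from collections.abc import Callable, Sequence

Cache = float  # toy recurrent state; a real model uses a PyTree of arrays


def toy_chunk_step(cache: Cache, chunk: Sequence[float], decay: float = 0.5) -> tuple[list[float], Cache]:
    """Run the toy recurrence h_t = decay * h_{t-1} + x_t over one chunk."""
    outputs = []
    for x in chunk:
        cache = decay * cache + x
        outputs.append(cache)
    return outputs, cache


def make_recurrent_eval(
    chunk_step: Callable[[Cache, Sequence[float]], tuple[list[float], Cache]],
    chunk_size: int,
    cache_init_fn: Callable[[], Cache] = lambda: 0.0,
) -> Callable[[Sequence[float]], list[float]]:
    """Wrap a single-chunk step in a Python loop to handle arbitrary lengths."""

    def recurrent_eval(sequence: Sequence[float]) -> list[float]:
        cache = cache_init_fn()
        outputs: list[float] = []
        for start in range(0, len(sequence), chunk_size):
            # Only this per-chunk call would be jitted; the loop stays in Python.
            chunk_out, cache = chunk_step(cache, sequence[start : start + chunk_size])
            outputs.extend(chunk_out)
        return outputs

    return recurrent_eval
```

With a jitted chunk step, a final chunk shorter than chunk_size has a new shape and would trigger a recompilation unless it is padded; how the library handles this case is not stated here.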

finalize_metrics(aggregated_metrics)

Calculates the final metrics from aggregated_metrics (i.e., mean = sum / count).

Parameters:

aggregated_metrics (xlstm_jax.common_types.HostMetrics) – Aggregated metrics over the whole epoch

Returns:

Final metrics that are to be reported / logged.

Return type:

xlstm_jax.common_types.HostMetrics
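Continuing the (summed value, count) assumption used for aggregation, finalization reduces to one division per metric. Returning NaN for metrics with a count of zero is a choice made for this sketch, not documented behavior.

```python
def finalize_metrics(aggregated_metrics: dict[str, tuple[float, int]]) -> dict[str, float]:
    """Turn (sum, count) aggregates into means; metrics with count 0 become NaN."""
    return {
        key: (value_sum / count if count > 0 else float("nan"))
        for key, (value_sum, count) in aggregated_metrics.items()
    }
```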

eval_function(params, apply_fn, batch, rng=None, mutable_variables=None)

Passes the batch through the model and computes extended metrics.

Parameters:
  • params (Any) – Model parameters.

  • apply_fn (Any) – Model apply function.

  • batch (xlstm_jax.dataset.batch.LLMIndexedBatch) – LLMIndexedBatch that is passed through the model.

  • rng (jax.Array | None) – RNG for potential dropout.

  • mutable_variables (dict[str, Any] | None) – Mutable variables for the evaluation step function, e.g. the cache (recurrent state).

Returns:

Tuple of metrics and mutable variables.

Return type:

tuple[xlstm_jax.common_types.Metrics, xlstm_jax.common_types.PyTree]
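LM Evaluation Harness log-likelihood requests are answered with (log-likelihood, is_greedy) pairs, which matches the return type of fuse_document_results. The sketch below shows, for a single sequence, how such a pair can be derived from per-position logits: a stable log-softmax at the target tokens, summed over non-masked positions, plus a check that each target is the argmax. It uses plain lists instead of JAX arrays, and the function name is hypothetical; it is not the module's eval_function.

```python
import math


def sequence_loglikelihood(
    logits: list[list[float]], targets: list[int], mask: list[bool]
) -> tuple[float, bool]:
    """Sum target log-probabilities over non-masked positions and check greedy match."""
    total = 0.0
    greedy = True
    for row, target, keep in zip(logits, targets, mask):
        if not keep:
            continue  # masked positions (e.g. the context) do not count
        # Numerically stable log-softmax: x_target - logsumexp(x).
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += row[target] - log_z
        # Greedy decoding picks the argmax token (first one on ties) at this position.
        greedy = greedy and max(range(len(row)), key=row.__getitem__) == target
    return total, greedy
```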

on_filtered_validation_epoch_start(epoch_idx, step_idx)

Runs evaluation on filtered validation epochs / steps.

Parameters:
  • epoch_idx (int) – Epoch index

  • step_idx (int) – Step index

on_test_epoch_start(epoch_idx)

Runs evaluation at the start of the test epoch.

Parameters:

epoch_idx (int) – Epoch index