xlstm_jax.dataset.lmeval_pipeline

Attributes

Classes

ParseLMEval

Parses an LMEval request into a simple dictionary format with prefix and text.

CompleteLLMIndexedBatch

Grain Transform that uses an indexed dataset (with "idx") and fills in all components of an LLMIndexedBatch.

PadBatchDataset

Creates a dataset that has only full batches by adding padding elements.

PadSequenceInBatchDataset

Creates a dataset in which the elements of each batch are padded to a common sequence length.

SortedDataset

Creates a sorted dataset based on a key (applied to all items) and an existing dataset.

MultihostSortedRemapDataset

This implements an index re-shuffling for a SortedDataset.

Functions

empty_llm_indexed_sample()

Generator for an empty llm_indexed sample that is used in paddings.

token_length(item)

Get the token length of a data item for sorting (grouping) the dataset.

_pad_batch_multiple(batch[, multiple_of, axis, ...])

Pads a list of arrays to a common length that is a multiple of multiple_of, using a given pad value.

lmeval_preprocessing_pipeline(dataloading_host_index, ...)

Create a multi-host dataloader for LMEval datasets for loglikelihood and loglikelihood_rolling tasks.

Module Contents

xlstm_jax.dataset.lmeval_pipeline.LOGGER
class xlstm_jax.dataset.lmeval_pipeline.ParseLMEval(request_name='req')

Bases: grain.python.MapTransform

Parses an LMEval request into a simple dictionary format with prefix and text. If there is no prefix, it is simply an empty string.

Parameters:

request_name (str) – The key in the input dictionary which corresponds to the LMEval Request.

request_name = 'req'
map(item)

Maps a single request to a dictionary of prefix and text.

Parameters:

item (dict[str, int | lm_eval.api.instance.Instance]) – LMEval request instance in a dictionary with the index.

Returns:

Resulting item dictionary with the index, prefix and text.

Return type:

dict[str, str | int]

>>> from xlstm_jax.utils.pytree_utils import pytree_diff
>>> pytree_diff(ParseLMEval().map(
...     {"idx": 1,
...      "req": Instance(
...         request_type="loglikelihood_rolling", doc={},
...         idx=0, arguments=("Prefix", "Main"))}),
...     {"idx": 1, "prefix": "Prefix", "text": "Main"})
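The parsing rule can be sketched in plain Python. This is a minimal sketch, not the library code: FakeInstance is a hypothetical stand-in for lm_eval.api.instance.Instance that keeps only the fields used here, and treating a one-element arguments tuple as "text without prefix" is an assumption based on the empty-prefix rule described above.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass
class FakeInstance:
    """Hypothetical stand-in for lm_eval.api.instance.Instance (only the fields used here)."""

    request_type: str
    arguments: tuple


def parse_request(item: dict[str, Any], request_name: str = "req") -> dict[str, Any]:
    """Hypothetical re-implementation of the ParseLMEval.map idea."""
    args = item[request_name].arguments
    # Assumption: a single argument means there is no prefix.
    if len(args) == 1:
        prefix, text = "", args[0]
    else:
        prefix, text = args[0], args[1]
    # Keep every other key (e.g. "idx") and replace the request by prefix/text.
    result = {key: value for key, value in item.items() if key != request_name}
    result["prefix"] = prefix
    result["text"] = text
    return result


print(parse_request({"idx": 1, "req": FakeInstance("loglikelihood_rolling", ("Prefix", "Main"))}))
# {'idx': 1, 'prefix': 'Prefix', 'text': 'Main'}
```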
class xlstm_jax.dataset.lmeval_pipeline.CompleteLLMIndexedBatch

Bases: grain.python.MapTransform

Grain Transform that uses an indexed dataset (with “idx”) and fills in all components of an LLMIndexedBatch.

>>> from xlstm_jax.utils.pytree_utils import pytree_diff
>>> pytree_diff(
...     CompleteLLMIndexedBatch().map(
...         {"inputs": np.array([[1, 2]]), "targets": np.array([[1, 2]]), "idx": np.array(0)}),
...     {"inputs": np.array([[1, 2]]), "targets": np.array([[1, 2]]),
...      "document_idx": np.array([1]), "inputs_position": np.array([[0, 1]]),
...      "targets_position": np.array([[0, 1]]), "sequence_idx": np.array([0]),
...      "_document_borders": np.array([[False, False]])})
static map(item)

Converts an incomplete dict to a dictionary with all components for an LLMIndexedBatch.

Parameters:

item (dict[str, numpy.ndarray])

Return type:

dict[str, numpy.ndarray]
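The doctest above suggests how the missing fields are derived for a single sample. Below is a minimal NumPy sketch, assuming positions simply count from 0 along the sequence, no document borders inside a sample, a sequence_idx of 0, and document_idx = idx + 1. The last rule only matches the single doctest case and is not confirmed by the source; complete_indexed_sample is a hypothetical name.

```python
import numpy as np


def complete_indexed_sample(item: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Hypothetical sketch of filling in LLMIndexedBatch fields for one sample."""
    inputs, targets = item["inputs"], item["targets"]
    # Positions count 0..L-1 along the sequence axis, keeping the leading axis of size 1.
    positions = np.arange(inputs.shape[1])[None, :]
    return {
        "inputs": inputs,
        "targets": targets,
        # Assumption: idx + 1 (matches the doctest's single case only).
        "document_idx": np.array([int(item["idx"]) + 1]),
        "inputs_position": positions,
        "targets_position": positions.copy(),
        "sequence_idx": np.array([0]),
        # One document per sample, so no document borders inside it.
        "_document_borders": np.zeros(inputs.shape, dtype=bool),
    }


out = complete_indexed_sample(
    {"inputs": np.array([[1, 2]]), "targets": np.array([[1, 2]]), "idx": np.array(0)}
)
print(out["inputs_position"])  # [[0 1]]
```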

xlstm_jax.dataset.lmeval_pipeline.empty_llm_indexed_sample()

Generator for an empty llm_indexed sample that is used in paddings. This creates only the data for a single sample, not a full batch object.

Returns:

An empty / padding sample for an LLMIndexedBatch

xlstm_jax.dataset.lmeval_pipeline.token_length(item)

Get the token length of a data item for sorting (grouping) the dataset.

Parameters:

item (dict[str, numpy.ndarray]) – A dataset item that contains an inputs element.

Returns:

Length of the inputs element.

Return type:

int

>>> token_length({"inputs": np.array([[1, 2, 3]])})
3
class xlstm_jax.dataset.lmeval_pipeline.PadBatchDataset(dataset, multiple_of, min_length, pad_elem)

Bases: grain.python.MapDataset

Creates a dataset that has only full batches by adding padding elements.

Parameters:
  • dataset (grain.python.MapDataset) – The existing dataset.

  • multiple_of (int) – Global batch size to be padded towards.

  • min_length (int) – Minimum (padded) length/size of the dataset.

  • pad_elem (Any) – Empty element to be appended.

>>> PadBatchDataset(grain.MapDataset.source([3, 1, 2]), multiple_of=4, min_length=0, pad_elem=0)[3]
0
>>> len(PadBatchDataset(grain.MapDataset.source([3, 1, 2]), multiple_of=4, min_length=0, pad_elem=0))
4
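The padded length and the index lookup can be sketched as follows. This is a minimal stand-in over a plain Python sequence rather than a grain.python.MapDataset; PadToMultiple is a hypothetical name, and how min_length interacts with multiple_of is an assumption here (taken as a lower bound applied after rounding up).

```python
import math
from typing import Any, Sequence


class PadToMultiple:
    """Hypothetical sketch of PadBatchDataset over a plain sequence."""

    def __init__(self, data: Sequence[Any], multiple_of: int, min_length: int, pad_elem: Any):
        self.data = data
        self.pad_elem = pad_elem
        # Round the length up to full global batches, then apply the lower bound (assumption).
        self.length = max(min_length, math.ceil(len(data) / multiple_of) * multiple_of)

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, index: int) -> Any:
        if not 0 <= index < self.length:
            raise IndexError(index)
        # Real items first, padding elements afterwards.
        return self.data[index] if index < len(self.data) else self.pad_elem


print(list(PadToMultiple([3, 1, 2], multiple_of=4, min_length=0, pad_elem=0)))  # [3, 1, 2, 0]
```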
dataset
multiple_of
min_length
pad_elem
xlstm_jax.dataset.lmeval_pipeline._pad_batch_multiple(batch, multiple_of=64, axis=1, pad_value=0, batch_size_pad=None)

Pads a list of arrays to a common length that is a multiple of multiple_of, using pad_value. Then the arrays are concatenated along axis 0.

Parameters:
  • batch (list[numpy.ndarray]) – A list of np.ndarrays to be padded and concatenated

  • multiple_of (int) – A number that the padded size should be a multiple of.

  • axis (int) – The axis to be padded, typically a sequence dimension.

  • pad_value (int | float) – The padding value, typically zero.

  • batch_size_pad (int | None)

Returns:

The concatenated, padded batch.

Return type:

numpy.ndarray

>>> np.allclose(
...     _pad_batch_multiple([np.array([[1, 2]]), np.array([[11, 12, 13]])], multiple_of=4, axis=1),
...     np.array([[1, 2, 0, 0], [11, 12, 13, 0]]))
True
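The rounding-and-padding step can be sketched with NumPy. This hypothetical helper only mirrors the doctest above and leaves out batch_size_pad, whose behavior is not documented here.

```python
import numpy as np


def pad_batch_multiple_sketch(batch, multiple_of=64, axis=1, pad_value=0):
    """Hypothetical sketch: pad arrays along `axis` to a shared multiple, then stack on axis 0."""
    longest = max(arr.shape[axis] for arr in batch)
    # Ceiling division, then scale back up to the next multiple.
    target = -(-longest // multiple_of) * multiple_of
    padded = []
    for arr in batch:
        widths = [(0, 0)] * arr.ndim
        widths[axis] = (0, target - arr.shape[axis])  # pad only at the end of `axis`
        padded.append(np.pad(arr, widths, constant_values=pad_value))
    return np.concatenate(padded, axis=0)


print(pad_batch_multiple_sketch([np.array([[1, 2]]), np.array([[11, 12, 13]])], multiple_of=4))
```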
class xlstm_jax.dataset.lmeval_pipeline.PadSequenceInBatchDataset(dataset, batch_size, multiple_of=64, pad_value=0)

Bases: grain.python.MapDataset

Creates a dataset in which the elements of each batch are padded to a common sequence length.

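This pads single elements (not batches), which allows the batches to be distributed over more devices. Assumes that each item of the dataset is a flat dictionary of arrays. Within each batch of batch_size consecutive elements, every array is padded to the longest sequence of that batch, rounded up to a multiple of multiple_of.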

Parameters:
  • dataset (grain.python.MapDataset) – The existing dataset; items are assumed to be dicts of arrays.

  • batch_size (int) – The batch size to pad towards.

  • multiple_of (int) – A number that the padded size should be a multiple of.

  • pad_value (int) – Value to pad with.

>>> from xlstm_jax.utils.pytree_utils import pytree_diff
>>> pytree_diff(list(PadSequenceInBatchDataset(grain.MapDataset.source(
...     [{"a": np.array([[3, 1, 5]])},
...      {"a": np.array([[2, 4]])},
...      {"a": np.array([[2, 4, 3, 3]])},
...      {"a": np.array([[2, 4]])}]
... ), batch_size=2, multiple_of=3)), [
...     {"a": np.array([[3, 1, 5]])},
...     {"a": np.array([[2, 4, 0]])},
...     {"a": np.array([[2, 4, 3, 3, 0, 0]])},
...     {"a": np.array([[2, 4, 0, 0, 0, 0]])}])
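The per-batch target length implied by the doctest above is the longest sequence in the batch, rounded up to a multiple of multiple_of. A minimal sketch on plain lists; pad_sequences_in_batches is a hypothetical helper, whereas the real class pads every array in each item's dictionary.

```python
import math


def pad_sequences_in_batches(seqs, batch_size, multiple_of=64, pad_value=0):
    """Hypothetical sketch: pad each group of `batch_size` sequences to a shared length."""
    out = []
    for start in range(0, len(seqs), batch_size):
        group = seqs[start:start + batch_size]
        # Longest sequence in this batch, rounded up to the next multiple.
        target = math.ceil(max(len(s) for s in group) / multiple_of) * multiple_of
        out.extend(s + [pad_value] * (target - len(s)) for s in group)
    return out


print(pad_sequences_in_batches([[3, 1, 5], [2, 4], [2, 4, 3, 3], [2, 4]], batch_size=2, multiple_of=3))
# [[3, 1, 5], [2, 4, 0], [2, 4, 3, 3, 0, 0], [2, 4, 0, 0, 0, 0]]
```

Because each batch gets its own target length, short batches are not inflated to the longest sequence of the whole dataset.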
dataset
batch_size
multiple_of = 64
pad_value = 0
class xlstm_jax.dataset.lmeval_pipeline.SortedDataset(dataset, key, reverse=False)

Bases: grain.python.MapDataset

Creates a sorted dataset based on a key (applied to all items) and an existing dataset.

Parameters:
  • dataset (grain.python.MapDataset) – The existing dataset

  • key (collections.abc.Callable) – Key function applied to each item to determine the sort order.

  • reverse (bool) – Whether to sort in descending order (True) instead of ascending order (False, default).

>>> SortedDataset(grain.MapDataset.source([3, 1, 2]), lambda x: x)[0]
1
>>> SortedDataset(grain.MapDataset.source([3, 1, 2]), lambda x: x, reverse=True)[0]
3
>>> SortedDataset(grain.MapDataset.source(
...     [{"inputs": np.array([[1, 2, 3]])},
...      {"inputs": np.array([[1, 2]])}]), token_length)[0]
{'inputs': array([[1, 2]])}
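One way to realize such a sorted view is to sort the indices once by key and then look items up through that permutation. A minimal sketch over a plain sequence; SortedView is a hypothetical name, and whether SortedDataset precomputes its order this way is an assumption.

```python
class SortedView:
    """Hypothetical sketch of a sorted, index-based view over a sequence."""

    def __init__(self, data, key, reverse=False):
        self._data = data
        # Permutation of indices; Python's sort is stable, so ties keep their original order.
        self._order = sorted(range(len(data)), key=lambda i: key(data[i]), reverse=reverse)

    def __len__(self):
        return len(self._order)

    def __getitem__(self, index):
        return self._data[self._order[index]]


print(list(SortedView([3, 1, 2], lambda x: x)))  # [1, 2, 3]
```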
key
dataset
class xlstm_jax.dataset.lmeval_pipeline.MultihostSortedRemapDataset(dataset, global_batch_size, dataloader_host_count)

Bases: grain.python.MapDataset

This implements an index re-shuffling for a SortedDataset. The problem: given a sorted dataset and multi-host dataloaders that shard it with .slice, the sort order is broken across global batches. Example dataset: [1 2 3 4 5 6 7 8], with a global batch size of 4 and 2 hosts:

  • Multi-host (standard slicing, assumed as input): [1 2 3 4] [5 6 7 8]

  • Multi-host batched: [[1 2] [3 4]] [[5 6] [7 8]], so the global batches would be [1 2 5 6] and [3 4 7 8].

  • What we actually want for proper batching: [[1 2] [5 6]] [[3 4] [7 8]], such that the global batch still looks like: [[1 2 3 4] [5 6 7 8]].

Parameters:
  • dataset (grain.python.MapDataset) – Original (sorted) dataset of which the order within batches should be kept.

  • global_batch_size (int) – The global batch size.

  • dataloader_host_count (int) – The number of dataloaders that a global batch is created from.

>>> ds = MultihostSortedRemapDataset(
...     grain.MapDataset.source([1, 2, 3, 4, 5, 6, 7, 8]),
...     global_batch_size=4, dataloader_host_count=2)
>>> host_slices = [slice(0, 4), slice(4, 8)]
>>> [list(ds.slice(host_slices[0]).batch(2)), list(ds.slice(host_slices[1]).batch(2))]
[[array([1, 2]), array([5, 6])], [array([3, 4]), array([7, 8])]]
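The remapping in the example above can be written as a closed-form index formula. A minimal sketch, assuming contiguous host slices as in the doctest and a dataset length divisible by global_batch_size (for example, padded beforehand); remap_index is a hypothetical helper, not part of the module.

```python
def remap_index(index: int, dataset_len: int, global_batch_size: int, host_count: int) -> int:
    """Hypothetical sketch: map a position in the remapped dataset to one in the sorted dataset."""
    per_host_len = dataset_len // host_count  # items in each contiguous host slice
    local_batch_size = global_batch_size // host_count
    host, offset = divmod(index, per_host_len)  # which host slice, position inside it
    local_batch, within = divmod(offset, local_batch_size)  # which local batch, position inside it
    # Local batch b of every host together forms global batch b, ordered by host.
    return local_batch * global_batch_size + host * local_batch_size + within


data = [1, 2, 3, 4, 5, 6, 7, 8]
remapped = [data[remap_index(i, len(data), 4, 2)] for i in range(len(data))]
print(remapped)  # [1, 2, 5, 6, 3, 4, 7, 8]
```

Slicing remapped into [1, 2, 5, 6] and [3, 4, 7, 8] and batching each by 2 reproduces the per-host batches shown in the doctest.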
dataset
global_batch_size
dataloader_host_count
xlstm_jax.dataset.lmeval_pipeline.lmeval_preprocessing_pipeline(dataloading_host_index, dataloading_host_count, global_mesh, dataset, global_batch_size, tokenizer_path, hf_access_token=None, tokenizer_cache_dir=None, eos_token_id=None, bos_token_id=None, worker_count=1, worker_buffer_size=1, padding_multiple=128, use_thread_prefetch=False)

Create a multi-host dataloader for LMEval datasets for loglikelihood and loglikelihood_rolling tasks. Generation tasks are currently not supported. Also, it only supports recurrent models that can process sequences of arbitrary length. For models with a limited sequence length, use HFTokenizeLogLikelihoodRolling from lmeval_dataset.py.

Parameters:
  • dataloading_host_index (int) – The index of the dataloading host. Will be used to select the correct shard of the dataset. In JAX, this is equivalent to jax.process_index().

  • dataloading_host_count (int) – The number of dataloading hosts. Will be used to determine the shard size. In JAX, this is equivalent to jax.process_count().

  • global_mesh (jax.sharding.Mesh) – The global mesh to shard the data over.

  • dataset (list[lm_eval.api.instance.Instance]) – The dataset to load. Should provide a __getitem__ method to access elements.

  • global_batch_size (int) – The global batch size.

  • tokenizer_path (str) – Path to the tokenizer.

  • hf_access_token (str | None) – The access token for HuggingFace.

  • tokenizer_cache_dir (str | None) – The cache directory for the tokenizer.

  • eos_token_id (int | None) – The token ID to use for the end-of-sequence token. If tokenizer_path is provided, the tokenizer’s EOS token ID is used.

  • bos_token_id (int | None) – The token ID to use for the beginning-of-sequence token. If tokenizer_path is provided, the tokenizer’s BOS token ID is used.

  • worker_count (int) – The number of workers to use. In grain, a single worker is usually sufficient, as the data loading is done in parallel across hosts.

  • worker_buffer_size (int) – The buffer size for the workers.

  • padding_multiple (int) – Pad to a size that is a multiple of this value.

  • use_thread_prefetch (bool) – Use thread prefetching instead of multiprocessing.

Returns:

MultiHostDataLoadIterator for the lmeval dataset.

Return type:

xlstm_jax.dataset.multihost_dataloading.MultiHostDataLoadIterator