xlstm_jax.dataset.lmeval_pipeline#
Attributes#
Classes#
ParseLMEval – Parses an LMEval request into a simple dictionary format with prefix and text.
CompleteLLMIndexedBatch – Grain Transform that uses an indexed dataset (with "idx") and fills in all components of an LLMIndexedBatch.
PadBatchDataset – Creates a dataset that has only full batches by adding padding elements.
PadSequenceInBatchDataset – Creates a dataset that has only full batches by padding elements.
SortedDataset – Creates a sorted dataset based on a key (applied to all items) and an existing dataset.
MultihostSortedRemapDataset – This implements an index re-shuffling for a SortedDataset.
Functions#
empty_llm_indexed_sample – Generator for an empty llm_indexed sample that is used in paddings.
token_length – Get the token length of a data item for sorting (grouping) the dataset.
_pad_batch_multiple – Pads a list of arrays to a common length that is a multiple of multiple_of, using a given padding value.
lmeval_preprocessing_pipeline – Create a multi-host dataloader for LMEval datasets for loglikelihood and loglikelihood_rolling tasks.
Module Contents#
- xlstm_jax.dataset.lmeval_pipeline.LOGGER#
- class xlstm_jax.dataset.lmeval_pipeline.ParseLMEval(request_name='req')#
Bases:
grain.python.MapTransform
Parses an LMEval request into a simple dictionary format with prefix and text. If there is no prefix, it is simply an empty string.
- Parameters:
request_name (str) – The key in the input dictionary which corresponds to the LMEval Request.
- request_name = 'req'#
- map(item)#
Maps a single request to a dictionary of prefix and text.
- Parameters:
- Returns:
Resulting item dictionary
- Return type:
>>> from xlstm_jax.utils.pytree_utils import pytree_diff
>>> pytree_diff(ParseLMEval().map(
...     {"idx": 1,
...      "req": Instance(
...          request_type="loglikelihood_rolling", doc={},
...          idx=0, arguments=("Prefix", "Main"))}),
...     {"idx": 1, "prefix": "Prefix", "text": "Main"})
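The prefix/text split described above can be illustrated without lm_eval installed. The following is only a sketch: FakeInstance is a hypothetical stand-in for lm_eval.api.instance.Instance, parse_request is not part of xlstm_jax, and the handling of a single-argument request (empty prefix) is an assumption based on the class description.

```python
from dataclasses import dataclass


@dataclass
class FakeInstance:
    """Hypothetical stand-in for lm_eval's Instance; only `arguments` is modelled."""

    arguments: tuple


def parse_request(item, request_name="req"):
    """Return a copy of `item` with the request replaced by `prefix` and `text`."""
    args = item[request_name].arguments
    # Assumption: a request with a single argument has no prefix.
    if len(args) > 1:
        prefix, text = args[0], args[1]
    else:
        prefix, text = "", args[0]
    # Keep every other key (for example "idx") untouched.
    out = {k: v for k, v in item.items() if k != request_name}
    out["prefix"] = prefix
    out["text"] = text
    return out


with_prefix = parse_request({"idx": 1, "req": FakeInstance(("Prefix", "Main"))})
without_prefix = parse_request({"idx": 2, "req": FakeInstance(("Main",))})
```

Here with_prefix matches the dictionary expected by the doctest above, and without_prefix carries an empty string as its prefix.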
- class xlstm_jax.dataset.lmeval_pipeline.CompleteLLMIndexedBatch#
Bases:
grain.python.MapTransform
Grain Transform that uses an indexed dataset (with “idx”) and fills in all components of an LLMIndexedBatch.
>>> from xlstm_jax.utils.pytree_utils import pytree_diff
>>> pytree_diff(
...     CompleteLLMIndexedBatch().map(
...         {"inputs": np.array([[1, 2]]), "targets": np.array([[1, 2]]), "idx": np.array(0)}),
...     {"inputs": np.array([[1, 2]]), "targets": np.array([[1, 2]]),
...      "document_idx": np.array([1]), "inputs_position": np.array([[0, 1]]),
...      "targets_position": np.array([[0, 1]]), "sequence_idx": np.array([0]),
...      "_document_borders": np.array([[False, False]])})
- static map(item)#
Converts an incomplete dictionary to a dictionary with all components of an LLMIndexedBatch.
- Parameters:
item (dict[str, numpy.ndarray])
- Return type:
- xlstm_jax.dataset.lmeval_pipeline.empty_llm_indexed_sample()#
Generator for an empty llm_indexed sample that is used in paddings. This creates only the data for a single sample, not a full batch object.
- Returns:
An empty / padding sample for an LLMIndexedBatch
- xlstm_jax.dataset.lmeval_pipeline.token_length(item)#
Get the token length of a data item for sorting (grouping) the dataset.
- Parameters:
item (dict[str, numpy.ndarray]) – A dataset item that contains an inputs element.
- Returns:
Length of the inputs element.
- Return type:
>>> token_length({"inputs": np.array([[1, 2, 3]])})
3
- class xlstm_jax.dataset.lmeval_pipeline.PadBatchDataset(dataset, multiple_of, min_length, pad_elem)#
Bases:
grain.python.MapDataset
Creates a dataset that has only full batches by adding padding elements.
- Parameters:
>>> PadBatchDataset(grain.MapDataset.source([3, 1, 2]), multiple_of=4, min_length=0, pad_elem=0)[3]
0
>>> len(PadBatchDataset(grain.MapDataset.source([3, 1, 2]), multiple_of=4, min_length=0, pad_elem=0))
4
- dataset#
- multiple_of#
- min_length#
- pad_elem#
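The length arithmetic behind PadBatchDataset can be sketched in plain Python. This is an illustration consistent with the doctest above, not the library's code: padded_length and get_padded are hypothetical helpers, and treating min_length as a lower bound that is itself rounded up to a multiple of multiple_of is an assumption.

```python
def padded_length(length, multiple_of, min_length=0):
    """Round max(length, min_length) up to the next multiple of `multiple_of`."""
    # Assumption: min_length acts as a lower bound on the padded dataset length.
    target = max(length, min_length)
    return -(-target // multiple_of) * multiple_of  # ceiling division, then scale


def get_padded(items, index, multiple_of, min_length, pad_elem):
    """Return the real item at `index`, or `pad_elem` for positions in the padded tail."""
    total = padded_length(len(items), multiple_of, min_length)
    if not 0 <= index < total:
        raise IndexError(index)
    return items[index] if index < len(items) else pad_elem


source = [3, 1, 2]
total_len = padded_length(len(source), multiple_of=4)
last_item = get_padded(source, 3, multiple_of=4, min_length=0, pad_elem=0)
```

With three source items and multiple_of=4, total_len is 4 and the item at index 3 is the padding element, as in the doctest.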
- xlstm_jax.dataset.lmeval_pipeline._pad_batch_multiple(batch, multiple_of=64, axis=1, pad_value=0, batch_size_pad=None)#
Pads a list of arrays to a common length that is a multiple of multiple_of, using pad_value as the padding value. The arrays are then concatenated along axis 0.
- Parameters:
batch (list[numpy.ndarray]) – A list of np.ndarrays to be padded and concatenated
multiple_of (int) – A number that the padded size should be a multiple of.
axis (int) – The axis to be padded, typically a sequence dimension.
pad_value (int | float) – The padding value, typically zero.
batch_size_pad (int | None)
- Returns:
The concatenated, padded batch.
- Return type:
>>> np.allclose(
...     _pad_batch_multiple([np.array([[1, 2]]), np.array([[11, 12, 13,]])], multiple_of=4, axis=1),
...     np.array([[1, 2, 0, 0], [11, 12, 13, 0]]))
True
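The padding rule of _pad_batch_multiple can be shown with plain lists. This is a simplified sketch under assumptions: pad_rows_to_multiple is a hypothetical helper, each row is a flat list standing in for an array of shape (1, length), and the batch_size_pad argument is not modelled because its behavior is not documented here.

```python
def pad_rows_to_multiple(rows, multiple_of=64, pad_value=0):
    """Right-pad every row to the same length, a multiple of `multiple_of`."""
    longest = max(len(row) for row in rows)
    # Smallest multiple of `multiple_of` that fits the longest row.
    target = -(-longest // multiple_of) * multiple_of
    return [list(row) + [pad_value] * (target - len(row)) for row in rows]


padded = pad_rows_to_multiple([[1, 2], [11, 12, 13]], multiple_of=4)
```

The result is two rows of length 4, mirroring the array compared against in the doctest above.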
- class xlstm_jax.dataset.lmeval_pipeline.PadSequenceInBatchDataset(dataset, batch_size, multiple_of=64, pad_value=0)#
Bases:
grain.python.MapDataset
Creates a dataset that has only full batches by padding elements.
This pads individual elements (not batches), which allows batches to be distributed over more devices. Assumes a dataset that consists of a flat dictionary of arrays.
- Parameters:
>>> from xlstm_jax.utils.pytree_utils import pytree_diff
>>> pytree_diff(list(PadSequenceInBatchDataset(grain.MapDataset.source(
...     [{"a": np.array([[3, 1, 5]])},
...      {"a": np.array([[2, 4]])},
...      {"a": np.array([[2, 4, 3, 3]])},
...      {"a": np.array([[2, 4]])}]
...     ), batch_size=2, multiple_of=3 )), [
...     {"a": np.array([[3, 1, 5]])},
...     {"a": np.array([[2, 4, 0]])},
...     {"a": np.array([[2, 4, 3, 3, 0, 0]])},
...     {"a": np.array([[2, 4, 0, 0, 0, 0]])}])
- dataset#
- batch_size#
- multiple_of = 64#
- pad_value = 0#
- class xlstm_jax.dataset.lmeval_pipeline.SortedDataset(dataset, key, reverse=False)#
Bases:
grain.python.MapDataset
Creates a sorted dataset based on a key (applied to all items) and an existing dataset.
- Parameters:
dataset (grain.python.MapDataset) – The existing dataset
key (collections.abc.Callable) – Key Function to be applied for sorting
reverse (bool) – Whether to sort in descending order (True) or ascending order (False, default).
>>> SortedDataset(grain.MapDataset.source([3, 1, 2]), lambda x: x)[0]
1
>>> SortedDataset(grain.MapDataset.source([3, 1, 2]), lambda x: x, reverse=True)[0]
3
>>> SortedDataset(grain.MapDataset.source(
...     [{"inputs": np.array([[1, 2, 3]])},
...      {"inputs": np.array([[1, 2]])}]), token_length)[0]
{'inputs': array([[1, 2]])}
- key#
- dataset#
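One way to obtain a sorted view without copying items is to sort the indices once and redirect lookups through them. The sketch below assumes that approach; SortedView is a hypothetical plain-Python class, not the library's SortedDataset, and nested lists stand in for the arrays used in the doctest.

```python
class SortedView:
    """Read-only view of `data` ordered by `key`, built from a sorted index list."""

    def __init__(self, data, key, reverse=False):
        self.data = data
        # Evaluate the key on every item once and keep only the resulting order.
        self.order = sorted(range(len(data)), key=lambda i: key(data[i]), reverse=reverse)

    def __len__(self):
        return len(self.order)

    def __getitem__(self, index):
        return self.data[self.order[index]]


def row_length(item):
    """Length of the first row of `inputs`, analogous to token_length above."""
    return len(item["inputs"][0])


ascending = SortedView([3, 1, 2], key=lambda x: x)
descending = SortedView([3, 1, 2], key=lambda x: x, reverse=True)
by_length = SortedView([{"inputs": [[1, 2, 3]]}, {"inputs": [[1, 2]]}], key=row_length)
```

Sorting by length groups items of similar size, which keeps the padding added per batch small.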
- class xlstm_jax.dataset.lmeval_pipeline.MultihostSortedRemapDataset(dataset, global_batch_size, dataloader_host_count)#
Bases:
grain.python.MapDataset
This implements an index re-shuffling for a SortedDataset. The problem: given a sorted dataset and multi-host dataloaders that use .slice, the sorting is broken.
Example dataset: [1 2 3 4 5 6 7 8]
Multi-host (standard slicing - assumed as input): [1 2 3 4] [5 6 7 8]
Multi-host batched: [[1 2] [3 4]] [[5 6] [7 8]]
What we actually want for proper batching: [[1 2] [5 6]] [[3 4] [7 8]]
such that the global batch still looks like: [[1 2 3 4] [5 6 7 8]]
- Parameters:
>>> ds = MultihostSortedRemapDataset(
...     grain.MapDataset.source([1, 2, 3, 4, 5, 6, 7, 8]),
...     global_batch_size=4, dataloader_host_count=2)
>>> host_slices = [slice(0, 4), slice(4, 8)]
>>> [list(ds.slice(host_slices[0]).batch(2)), list(ds.slice(host_slices[1]).batch(2))]
[[array([1, 2]), array([5, 6])], [array([3, 4]), array([7, 8])]]
- dataset#
- global_batch_size#
- dataloader_host_count#
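The re-shuffling can be expressed as a pure index computation. The following sketch reproduces the ordering of the example above; remap_index is a hypothetical helper rather than the library's implementation, and it assumes the dataset length is a multiple of global_batch_size and that global_batch_size is a multiple of the host count.

```python
def remap_index(position, dataset_len, global_batch_size, host_count):
    """Map a position in the remapped dataset to an index in the sorted source dataset."""
    per_host_len = dataset_len // host_count        # contiguous slice each host reads
    local_batch = global_batch_size // host_count   # per-host share of a global batch
    host, local_pos = divmod(position, per_host_len)
    batch_idx, offset = divmod(local_pos, local_batch)
    # Global batch `batch_idx` starts at batch_idx * global_batch_size in the sorted
    # data; host `host` owns the chunk starting `host * local_batch` into that batch.
    return batch_idx * global_batch_size + host * local_batch + offset


data = [1, 2, 3, 4, 5, 6, 7, 8]
remapped = [data[remap_index(p, len(data), 4, 2)] for p in range(len(data))]
host0, host1 = remapped[:4], remapped[4:]
host0_batches = [host0[i:i + 2] for i in range(0, 4, 2)]
host1_batches = [host1[i:i + 2] for i in range(0, 4, 2)]
```

Concatenating the first batch of each host gives [1, 2, 3, 4] and the second gives [5, 6, 7, 8], so every global batch is a contiguous run of the sorted data.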
- xlstm_jax.dataset.lmeval_pipeline.lmeval_preprocessing_pipeline(dataloading_host_index, dataloading_host_count, global_mesh, dataset, global_batch_size, tokenizer_path, hf_access_token=None, tokenizer_cache_dir=None, eos_token_id=None, bos_token_id=None, worker_count=1, worker_buffer_size=1, padding_multiple=128, use_thread_prefetch=False)#
Create a multi-host dataloader for LMEval datasets for loglikelihood and loglikelihood_rolling tasks. Generation tasks are currently not supported. It also supports only recurrent models that can process sequences of arbitrary length. For models with a limited sequence length, use HFTokenizeLogLikelihoodRolling from lmeval_dataset.py.
- Parameters:
dataloading_host_index (int) – The index of the dataloading host. Will be used to select the correct shard of the dataset. In JAX, this is equivalent to jax.process_index().
dataloading_host_count (int) – The number of dataloading hosts. Will be used to determine the shard size. In JAX, this is equivalent to jax.process_count().
global_mesh (jax.sharding.Mesh) – The global mesh to shard the data over.
dataset (list[lm_eval.api.instance.Instance]) – The dataset to load. Should provide a __getitem__ method to access elements.
global_batch_size (int) – The global batch size.
tokenizer_path (str) – Path to the tokenizer.
hf_access_token (str | None) – The access token for HuggingFace.
tokenizer_cache_dir (str | None) – The cache directory for the tokenizer.
eos_token_id (int | None) – The token ID to use for the end-of-sequence token. If tokenizer_path is provided, the tokenizer’s EOS token ID is used.
bos_token_id (int | None) – The token ID to use for the beginning-of-sequence token. If tokenizer_path is provided, the tokenizer’s BOS token ID is used.
worker_count (int) – The number of workers to use. In grain, a single worker is usually sufficient, as the data loading is done in parallel across hosts.
worker_buffer_size (int) – The buffer size for the workers.
padding_multiple (int) – Sequences are padded to a length that is a multiple of this value.
use_thread_prefetch (bool) – Use thread prefetching instead of multiprocessing.
- Returns:
MultiHostDataLoadIterator for the lmeval dataset.
- Return type:
xlstm_jax.dataset.multihost_dataloading.MultiHostDataLoadIterator