xlstm_jax.dataset.grain_data_processing#
Attributes#
LOGGER |
Classes#
PaddedDataset | Dataset wrapper to pad the dataset to be a multiple of the global batch size.
Functions#
preprocess_dataset | Pipeline for preprocessing an array_records or huggingface dataset.
make_grain_iterator | Load a dataset, create the preprocessing pipeline and return a multihost data-loading iterator.
pad_dataset | Pads the dataset to match a multiple of the global batch size.
load_array_record_dataset | Take all files located at dataset_path and load them as grain.ArrayRecordDataSource.
load_huggingface_dataset | Load a dataset from HuggingFace.
Module Contents#
- xlstm_jax.dataset.grain_data_processing.LOGGER#
- xlstm_jax.dataset.grain_data_processing.preprocess_dataset(dataloading_host_index, dataloading_host_count, dataset, data_column_name, tokenize, global_batch_size, max_target_length, shuffle, data_shuffle_seed, tokenizer_path=None, hf_access_token=None, add_bos=True, add_eos=True, add_eod=True, grain_packing=False, grain_packing_bin_count=None, shift=True, drop_remainder=True, num_epochs=None, tokenizer_cache_dir=None, max_steps_per_epoch=None, eod_token_id=None)#
Pipeline for preprocessing an array_records or huggingface dataset.
- Parameters:
dataloading_host_index (int) – The index of the data loading host. Will be used to select the correct shard of the dataset. In JAX, this is equivalent to jax.process_index().
dataloading_host_count (int) – The number of data loading hosts. Will be used to determine the shard size. In JAX, this is equivalent to jax.process_count().
dataset (Any) – The dataset to load. Should provide a __getitem__ method to access elements.
data_column_name (str) – The column name for the data in the dataset.
tokenize (bool) – Whether to tokenize the data.
global_batch_size (int) – The global batch size.
max_target_length (int) – The maximum target length.
shuffle (bool) – Whether to shuffle the dataset.
data_shuffle_seed (int) – The shuffle seed.
tokenizer_path (str | None) – The path to the tokenizer.
hf_access_token (str | None) – The access token for HuggingFace.
add_bos (bool) – Whether to add the beginning of sequence token.
add_eos (bool) – Whether to add the end of sequence token.
add_eod (bool) – Whether to add an end of document token.
grain_packing (bool) – Whether to perform packing of the data. This is useful for datasets with a lot of padding, as batch elements will be packed together in a sequence to reduce the amount of padding. This can improve throughput efficiency. Note: if packing is enabled, the length of the iterator cannot be determined in advance. The length reported by the iterator is set to the maximum number of batches and is therefore likely incorrect.
grain_packing_bin_count (int | None) – The number of packing bins to use. If not provided, the bin count will be set to the batch size. It can be beneficial to increase the packing bins to reduce padding.
shift (bool) – Whether to shift the input data to create the target data.
drop_remainder (bool) – Whether to drop the remainder of the dataset. Note that if a number of epochs is provided, the last batch of all epochs together will be dropped if this is set to True. If set to False, the last batch of all epochs together will be included in the iterator.
num_epochs (int | None) – The number of epochs to train for. The dataset will be repeated for that many epochs, and the shuffle order will be different for each epoch. If None, the dataset will be repeated infinitely. Note that batches of an epoch can spill over into the first batch of the next epoch, to avoid dropping data. The argument drop_remainder controls whether the very last batch of all epochs together is dropped. By default, use None (infinite epochs) for training and validation.
tokenizer_cache_dir (str | None) – The cache directory for the tokenizer.
max_steps_per_epoch (int | None) – The maximum number of steps per epoch. If provided, the iterator will stop after this many steps with a StopIteration exception. Otherwise, will continue over the iterator until all batches are consumed.
eod_token_id (int | None) – The token ID to use for the end-of-document token. If tokenizer_path is provided, the tokenizer’s EOD token ID is used.
- Returns:
The preprocessed grain dataset and the original data source.
- Return type:
tuple[grain.python.IterDataset, grain.python.RandomAccessDataSource]
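The interaction between num_epochs, drop_remainder and the epoch spill-over described above can be illustrated with a small, library-independent sketch. It is not the xlstm_jax implementation: the helper names count_batches and batch_epoch_ids are hypothetical, and shuffling, host sharding and packing are ignored. It only shows the arithmetic implied by the parameter descriptions.

```python
import math


def count_batches(num_examples, global_batch_size, num_epochs, drop_remainder=True):
    """Number of batches when all epochs are batched as one continuous stream."""
    total = num_examples * num_epochs
    if drop_remainder:
        # Only the incomplete batch at the very end of all epochs is dropped.
        return total // global_batch_size
    return math.ceil(total / global_batch_size)


def batch_epoch_ids(num_examples, global_batch_size, num_epochs, drop_remainder=True):
    """For every batch, list the epoch that each of its elements came from."""
    # One entry per example, labelled with its epoch, concatenated over epochs.
    stream = [epoch for epoch in range(num_epochs) for _ in range(num_examples)]
    batches = [
        stream[start:start + global_batch_size]
        for start in range(0, len(stream), global_batch_size)
    ]
    if drop_remainder and batches and len(batches[-1]) < global_batch_size:
        batches.pop()
    return batches


if __name__ == "__main__":
    # 10 examples, batch size 4, 3 epochs: 30 examples in total.
    for batch in batch_epoch_ids(10, 4, 3, drop_remainder=False):
        print(batch)
```

In this toy setting the third batch mixes elements of epoch 0 and epoch 1, which is the spill-over behaviour, and only the final two-element batch is affected by drop_remainder.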
- xlstm_jax.dataset.grain_data_processing.make_grain_iterator(configs, global_mesh, process_indices, dataset_weights=None)#
Load a dataset, create the preprocessing pipeline and return a multihost data-loading iterator.
- Parameters:
configs (xlstm_jax.dataset.configs.GrainArrayRecordsDataConfig | xlstm_jax.dataset.configs.HFHubDataConfig | list[xlstm_jax.dataset.configs.GrainArrayRecordsDataConfig | xlstm_jax.dataset.configs.HFHubDataConfig]) – Dataset configuration object for a HuggingFace or ArrayRecords dataset. If multiple configs are provided, the datasets will be loaded in parallel and the data will be interleaved in a mixing style. NOTE: the global batch size, worker count, worker buffer size, drop remainder, and batch rampup are taken only from the first config. The other configs are assumed to have the same values; if they differ, warnings will be raised.
global_mesh (jax.sharding.Mesh) – The global mesh to shard the data over.
process_indices (list[int]) – List of process indices that should load the real data. This is used to determine the data loading host index and host count.
dataset_weights (list[float] | None) – The weights for the datasets. If provided, the datasets will be mixed according to the weights. Otherwise, a uniform mixing is used. If a single dataset is provided, the weights are ignored.
- Returns:
data-loading iterator (for training or evaluation).
- Return type:
xlstm_jax.dataset.multihost_dataloading.MultiHostDataLoadIterator
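As a rough illustration of the dataset_weights semantics described above (uniform mixing when no weights are given, weights ignored for a single dataset), the following sketch resolves the weights into mixing proportions. The helper name resolve_mixing_weights is hypothetical, and normalizing the weights to sum to one is an assumption made for this example, not a statement about what make_grain_iterator does internally.

```python
def resolve_mixing_weights(num_datasets, dataset_weights=None):
    """Return one mixing proportion per dataset, summing to 1."""
    if num_datasets < 1:
        raise ValueError("At least one dataset is required.")
    if num_datasets == 1:
        # With a single dataset, any provided weights are ignored.
        return [1.0]
    if dataset_weights is None:
        # No weights given: mix uniformly.
        return [1.0 / num_datasets] * num_datasets
    if len(dataset_weights) != num_datasets:
        raise ValueError("Need exactly one weight per dataset.")
    if any(w < 0 for w in dataset_weights) or sum(dataset_weights) <= 0:
        raise ValueError("Weights must be non-negative with a positive sum.")
    total = sum(dataset_weights)
    # Assumption of this sketch: weights are rescaled into proportions.
    return [w / total for w in dataset_weights]


if __name__ == "__main__":
    print(resolve_mixing_weights(2))
    print(resolve_mixing_weights(2, [1.0, 3.0]))
```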
- class xlstm_jax.dataset.grain_data_processing.PaddedDataset(dataset, full_dataset_length, column_name)#
Bases: grain.python.RandomAccessDataSource
Dataset wrapper to pad the dataset to be a multiple of the global batch size.
- Parameters:
- dataset#
- full_dataset_length#
- column_name#
- property empty_sequence#
Returns an empty sequence for padding, depending on the type of dataset.
- xlstm_jax.dataset.grain_data_processing.pad_dataset(dataset, global_batch_size, column_name)#
Pads the dataset to match a multiple of the global batch size.
- Parameters:
- Returns:
The padded dataset.
- Return type:
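The idea of padding a dataset up to a multiple of the global batch size can be sketched without grain as follows. This is a minimal stand-in, not the PaddedDataset class itself: padded_length and PaddedView are hypothetical names, and returning a fixed empty record for the padded indices is an assumption loosely modelled on the empty_sequence property.

```python
def padded_length(dataset_length, global_batch_size):
    """Smallest multiple of global_batch_size that is >= dataset_length."""
    remainder = dataset_length % global_batch_size
    if remainder == 0:
        return dataset_length
    return dataset_length + global_batch_size - remainder


class PaddedView:
    """Hypothetical random-access wrapper that serves empty records past the end."""

    def __init__(self, records, full_length, empty_record):
        self.records = records
        self.full_length = full_length
        self.empty_record = empty_record

    def __len__(self):
        return self.full_length

    def __getitem__(self, idx):
        if idx < 0 or idx >= self.full_length:
            raise IndexError(idx)
        if idx < len(self.records):
            return self.records[idx]
        # Indices beyond the real data return the padding element.
        return self.empty_record


if __name__ == "__main__":
    data = ["a", "b", "c"]
    view = PaddedView(data, padded_length(len(data), 2), empty_record="")
    print(len(view), [view[i] for i in range(len(view))])
```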
- xlstm_jax.dataset.grain_data_processing.load_array_record_dataset(dataset_path, file_extension='.arecord')#
Take all files located at dataset_path and load them as grain.ArrayRecordDataSource.
Assumes that the filenames are multiple shards where the shard idx is in the filename, e.g. ‘train_000001.arecord’. We load the files in the order of the shard idx.
- Parameters:
dataset_path (pathlib.Path | str) – Path to the dataset folder, which contains .arecord files.
file_extension (str) – The file extension of the dataset files. Default is ‘.arecord’.
- Returns:
The dataset as grain.ArrayRecordDataSource.
- Return type:
grain.ArrayRecordDataSource
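Ordering shard files by their shard index, as described above, might look like the sketch below. It assumes that the shard index is the trailing run of digits in the file stem (as in ‘train_000001.arecord’). The helper names shard_index and sort_shard_names are hypothetical, and the actual loader may parse filenames differently.

```python
import re
from pathlib import PurePath


def shard_index(name):
    """Extract the trailing integer of the file stem, e.g. 1 for 'train_000001.arecord'."""
    match = re.search(r"(\d+)$", PurePath(name).stem)
    if match is None:
        raise ValueError(f"No shard index found in filename: {name}")
    return int(match.group(1))


def sort_shard_names(names, file_extension=".arecord"):
    """Keep only files with the given extension and order them by shard index."""
    selected = [n for n in names if PurePath(n).suffix == file_extension]
    # Sorting on the integer value also handles indices without zero padding.
    return sorted(selected, key=shard_index)


if __name__ == "__main__":
    files = ["train_000010.arecord", "train_000002.arecord", "notes.txt", "train_000001.arecord"]
    print(sort_shard_names(files))
```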
- xlstm_jax.dataset.grain_data_processing.load_huggingface_dataset(config)#
Load a dataset from HuggingFace.
- Parameters:
config (xlstm_jax.dataset.configs.HFHubDataConfig) – The HFHubDataConfig object.
- Returns:
The loaded dataset.
- Return type:
datasets.Dataset