xlstm_jax.dataset.grain_transforms
Copyright 2023 Google LLC.
Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
—
This file contains code from the maxtext project (AI-Hypercomputer/maxtext), in particular from the file AI-Hypercomputer/maxtext.
Operations used by Grain
Classes
| Class | Description |
| HFNormalizeFeatures | Normalize feature keys for HuggingFace input. |
| HFPrefixTokenize | Merge prefix and predicted text. |
| ParseFeatures | Parse serialized example. |
| ParseArrayRecords | Parse serialized example from array_records dataset. |
| ParseTokenizedArrayRecords | Parse serialized example from array_records dataset. |
| NormalizeFeatures | Normalize text feature keys. |
| ReformatPacking | Reformat packing outputs. |
| ReformatLazyPacking | Reformat packing outputs for the lazy API. |
| PadToMaxLength | Pads each input to the specified length. |
| ShiftData | Shift inputs/targets and refine annotations. |
| InferSegmentations | Infer the segmentation, i.e. document numbers, from the inputs. |
| CollateToBatch | Collate data to batch. |
| HFTokenize | Tokenize text feature keys. |
|  | Add an end-of-document token to the inputs and targets. |
Functions
| Function | Description |
| shift_right | Shift the input to the right by padding and slicing along axis. |
| shift_left | Shift the input to the left by padding and slicing along axis. |
| shift_and_refine | Shift inputs or targets, and adjust segmentation. |
Module Contents
- class xlstm_jax.dataset.grain_transforms.HFNormalizeFeatures(column_name)
Bases: grain.python.MapTransform
Normalize feature keys for HuggingFace input.
- Parameters:
column_name (str)
- column_name
- class xlstm_jax.dataset.grain_transforms.HFPrefixTokenize(tokenizer, prefix_tokenizer, prefix_column_name='prefix', text_column_name='text', add_bos_token=True, add_eos_token=False, bos_token_id=None, eos_token_id=None, max_length=None, max_length_prefix=None)
Bases: grain.python.MapTransform
Merge prefix and predicted text.
- tokenizer
- prefix_tokenizer
- prefix_column_name = 'prefix'
- text_column_name = 'text'
- max_length = None
- max_length_prefix = None
- add_bos_token = True
- add_eos_token = False
- eos_token_id = None
- bos_token_id = None
- class xlstm_jax.dataset.grain_transforms.ParseFeatures(data_column, tokenize)
Bases: grain.python.MapTransform
Parse a serialized example.
- data_column
- map(features)
- class xlstm_jax.dataset.grain_transforms.ParseArrayRecords(column_name)
Bases: grain.python.MapTransform
Parse a serialized example from an array_records dataset.
- Parameters:
column_name (str)
- column_name
- class xlstm_jax.dataset.grain_transforms.ParseTokenizedArrayRecords(column_name)
Bases: grain.python.MapTransform
Parse a serialized example from an array_records dataset.
- Parameters:
column_name (str)
- column_name
- map(data)
Map to parse array records.
- static sequence_to_bytestring(sequence)
Convert a token sequence to a numpy bytestring.
- Parameters:
sequence (list[int] | numpy.ndarray) – The sequence of tokens. If a numpy array is provided, it must be one-dimensional.
- Returns:
The bytestring.
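To make the storage format concrete, here is a minimal NumPy sketch of encoding a token sequence as a bytestring and decoding it again. It assumes tokens are stored as raw `int32` bytes (the dtype actually used by `sequence_to_bytestring` is not stated in this reference), and `bytestring_to_sequence` is a hypothetical helper added only to show the round trip.

```python
import numpy as np


def sequence_to_bytestring(sequence):
    """Encode a 1-D token sequence as raw bytes (sketch: assumes int32 storage)."""
    arr = np.asarray(sequence, dtype=np.int32)
    if arr.ndim != 1:
        raise ValueError("sequence must be one-dimensional")
    return arr.tobytes()


def bytestring_to_sequence(data):
    """Hypothetical inverse: decode raw bytes back into an int32 token array."""
    return np.frombuffer(data, dtype=np.int32)


tokens = [101, 2023, 2003, 102]
encoded = sequence_to_bytestring(tokens)
decoded = bytestring_to_sequence(encoded)
print(len(encoded))      # 16: four tokens x 4 bytes each
print(decoded.tolist())  # [101, 2023, 2003, 102]
```

Storing fixed-width integers keeps decoding a zero-copy `np.frombuffer` call, which is one reason raw bytes are a convenient record payload.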
- class xlstm_jax.dataset.grain_transforms.NormalizeFeatures(column_name, tokenize)
Bases: grain.python.MapTransform
Normalize text feature keys.
- column_name
- tokenize
- map(features)
- class xlstm_jax.dataset.grain_transforms.ReformatPacking
Bases: grain.python.MapTransform
Reformat packing outputs.
- static map(data)
- Parameters:
data (tuple[dict[str, numpy.ndarray]])
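A minimal sketch of what reformatting packing outputs can look like, assuming the packed element is a tuple of three dictionaries (token IDs, segment IDs, positions), each keyed by `inputs` and `targets`, in the style of the maxtext operations this file is based on. The function name `reformat_packing` and the output keys are assumptions for illustration, not a copy of `ReformatPacking.map`.

```python
import numpy as np


def reformat_packing(data):
    """Flatten a (tokens, segmentations, positions) tuple into one dict.

    Assumed layout (hypothetical): data[0] holds token IDs, data[1] segment
    IDs and data[2] positions, each keyed by "inputs" and "targets".
    """
    tokens, segmentations, positions = data
    return {
        "inputs": tokens["inputs"],
        "targets": tokens["targets"],
        "inputs_segmentation": segmentations["inputs"],
        "targets_segmentation": segmentations["targets"],
        "inputs_position": positions["inputs"],
        "targets_position": positions["targets"],
    }


# Two documents ([5, 6] and [7]) packed into one row of length 4; 0 is padding.
packed = (
    {"inputs": np.array([5, 6, 7, 0]), "targets": np.array([5, 6, 7, 0])},
    {"inputs": np.array([1, 1, 2, 0]), "targets": np.array([1, 1, 2, 0])},
    {"inputs": np.array([0, 1, 0, 0]), "targets": np.array([0, 1, 0, 0])},
)
flat = reformat_packing(packed)
print(sorted(flat))
```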
- class xlstm_jax.dataset.grain_transforms.ReformatLazyPacking
Bases: grain.python.MapTransform
Reformat packing outputs for the lazy API.
- static map(data)
- Parameters:
data (dict[str, numpy.ndarray])
- class xlstm_jax.dataset.grain_transforms.PadToMaxLength(max_length)
Bases: grain.python.MapTransform
Pads each input to the specified length.
- Parameters:
max_length (int)
- max_length
- map(data)
Map function applied to each element.
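The following NumPy sketch shows one way to pad every array of an element to `max_length`. It assumes right-padding with zeros along the first axis and no truncation; the hypothetical `pad_to_max_length` helper is not the library's implementation, and `PadToMaxLength` may also add or handle other keys.

```python
import numpy as np


def pad_to_max_length(data, max_length, pad_value=0):
    """Right-pad every array in `data` along its first axis to `max_length`.

    Arrays already at least `max_length` long are returned unchanged
    (this sketch does not truncate).
    """
    padded = {}
    for key, x in data.items():
        pad_amount = max(max_length - x.shape[0], 0)
        # Pad only the end of the first axis; leave any other axes untouched.
        pad_width = [(0, pad_amount)] + [(0, 0)] * (x.ndim - 1)
        padded[key] = np.pad(x, pad_width, constant_values=pad_value)
    return padded


example = {"inputs": np.array([3, 4, 5]), "targets": np.array([4, 5, 6])}
out = pad_to_max_length(example, max_length=5)
print(out["inputs"])  # [3 4 5 0 0]
```

Padding every element to the same length is what lets a later batching step stack elements into fixed-shape arrays.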
- xlstm_jax.dataset.grain_transforms.shift_right(x, axis=1, padding_value=0, pad_by_first_element=False)
Shift the input to the right by padding and slicing along axis.
- Parameters:
x (numpy.ndarray) – Input array to shift.
axis (int) – Axis to shift along.
padding_value (int) – Value to use for padding.
pad_by_first_element (bool) – If True, pad with the first element of the array along axis instead of padding_value.
- Returns:
Shifted array.
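A NumPy sketch of the pad-then-slice idea behind `shift_right`, assuming a shift of exactly one position. Using `np.pad(..., mode="edge")` to implement `pad_by_first_element` is an assumption of this sketch, not necessarily how the library does it.

```python
import numpy as np


def shift_right(x, axis=1, padding_value=0, pad_by_first_element=False):
    """Shift `x` one step to the right along `axis` (sketch).

    One slot is padded at the start of `axis` and the last element is
    sliced off, so the output has the same shape as the input.
    """
    pad_width = [(0, 0)] * x.ndim
    pad_width[axis] = (1, 0)
    if pad_by_first_element:
        # "edge" mode repeats the first element along `axis`.
        padded = np.pad(x, pad_width, mode="edge")
    else:
        padded = np.pad(x, pad_width, constant_values=padding_value)
    # Drop the last element along `axis` to restore the original length.
    slices = [slice(None)] * x.ndim
    slices[axis] = slice(0, -1)
    return padded[tuple(slices)]


batch = np.array([[1, 2, 3, 4]])
print(shift_right(batch))                             # [[0 1 2 3]]
print(shift_right(batch, pad_by_first_element=True))  # [[1 1 2 3]]
```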
- xlstm_jax.dataset.grain_transforms.shift_left(x, axis=1, padding_value=0)
Shift the input to the left by padding and slicing along axis.
- Parameters:
x (numpy.ndarray) – Input array to shift.
axis (int) – Axis to shift along.
padding_value (int) – Value to use for padding.
- Returns:
Shifted array.
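The same pad-then-slice idea, mirrored, gives a left shift. This is a sketch assuming a one-position shift; in next-token prediction, shifting the inputs left yields the token each position should predict.

```python
import numpy as np


def shift_left(x, axis=1, padding_value=0):
    """Shift `x` one step to the left along `axis`, padding at the end (sketch)."""
    pad_width = [(0, 0)] * x.ndim
    pad_width[axis] = (0, 1)
    padded = np.pad(x, pad_width, constant_values=padding_value)
    # Drop the first element along `axis` so the shape is unchanged.
    slices = [slice(None)] * x.ndim
    slices[axis] = slice(1, None)
    return padded[tuple(slices)]


inputs = np.array([[1, 2, 3, 4]])
targets = shift_left(inputs)
print(targets)  # [[2 3 4 0]]
```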
- xlstm_jax.dataset.grain_transforms.shift_and_refine(x, shift_target=True, axis=1, padding_value=0)
Shift inputs or targets, and adjust segmentation.
- class xlstm_jax.dataset.grain_transforms.ShiftData(shift_target=True, eod_token_id=0, pad_token_id=0, axis=1)
Bases: grain.python.MapTransform
Shift inputs/targets and refine annotations.
- shift_target = True
- eod_token_id = 0
- pad_token_id = 0
- axis = 1
- map(data)
- Parameters:
data (dict[str, numpy.ndarray])
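To illustrate what shifting targets and refining annotations might involve for `shift_target=True`, here is a hypothetical sketch: targets are the inputs shifted left, and the freed last slot gets segment ID 0 so that a loss masked on the segmentation ignores it. The keys `inputs`, `targets`, `inputs_segmentation` and `targets_segmentation` and the helper `shift_targets` are assumptions; the real `ShiftData.map` may also use `eod_token_id` and handle document boundaries differently.

```python
import numpy as np


def shift_left(x, axis=1, padding_value=0):
    """Shift one step left along `axis`, padding the freed slot at the end."""
    pad_width = [(0, 0)] * x.ndim
    pad_width[axis] = (0, 1)
    padded = np.pad(x, pad_width, constant_values=padding_value)
    slices = [slice(None)] * x.ndim
    slices[axis] = slice(1, None)
    return padded[tuple(slices)]


def shift_targets(data, pad_token_id=0, axis=1):
    """Hypothetical shift_target=True step: each target is the next input token.

    The freed last slot gets `pad_token_id` as its target and segment ID 0,
    so a loss masked on targets_segmentation > 0 ignores it.
    """
    out = dict(data)
    out["targets"] = shift_left(data["inputs"], axis, pad_token_id)
    out["targets_segmentation"] = shift_left(data["inputs_segmentation"], axis, 0)
    return out


batch = {
    "inputs": np.array([[5, 6, 7, 8]]),
    "inputs_segmentation": np.array([[1, 1, 1, 1]]),
}
shifted = shift_targets(batch)
print(shifted["targets"])               # [[6 7 8 0]]
print(shifted["targets_segmentation"])  # [[1 1 1 0]]
```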
- class xlstm_jax.dataset.grain_transforms.InferSegmentations(eod_token_id)
Bases: grain.python.MapTransform
Infer the segmentation, i.e. document numbers, from the inputs.
Uses the end-of-document token to infer breaks between documents. This is not needed when performing packing, where the segmentations are already set correctly, but is useful for datasets preprocessed by grouping text, which do not have segmentations set.
- Parameters:
eod_token_id (int) – The token ID of the end-of-document token.
- eod_token_id
- map(data)
Map to infer segmentations.
- Parameters:
data (dict[str, numpy.ndarray])
- static _get_positions(eod_mask)
Infer positions from an end-of-document mask.
- Parameters:
eod_mask (numpy.ndarray)
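A vectorized NumPy sketch of inferring segmentations and positions from end-of-document tokens, for a batch of shape `(batch, seq_len)`. It assumes each EOD token is the last token of its document and that document IDs start at 1 (leaving 0 free for padding); padding tokens are not treated specially here. The function `infer_segmentations` is hypothetical.

```python
import numpy as np


def infer_segmentations(inputs, eod_token_id):
    """Derive 1-based document IDs and in-document positions from EOD tokens.

    Assumes `inputs` has shape (batch, seq_len) and that each EOD token is
    the last token of its document, so the next token starts a new document.
    """
    eod_mask = inputs == eod_token_id
    # A new document starts right after an EOD: shift the mask right by one.
    starts_new = np.zeros_like(eod_mask)
    starts_new[:, 1:] = eod_mask[:, :-1]
    segmentation = np.cumsum(starts_new, axis=1) + 1
    # Position = index minus the index at which the current document started.
    idx = np.broadcast_to(np.arange(inputs.shape[1]), inputs.shape)
    doc_start = np.maximum.accumulate(np.where(starts_new, idx, 0), axis=1)
    positions = idx - doc_start
    return segmentation, positions


tokens = np.array([[4, 9, 0, 7, 8, 0, 3]])  # EOD token ID 0 ends two documents
seg, pos = infer_segmentations(tokens, eod_token_id=0)
print(seg)  # [[1 1 1 2 2 2 3]]
print(pos)  # [[0 1 2 0 1 2 0]]
```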
- class xlstm_jax.dataset.grain_transforms.CollateToBatch(batch_class, key_map=None)
Bases: grain.python.MapTransform
Collate data into a batch.
- batch_class
- key_map = None
- map(data)
Map to collate data into a batch.
- Parameters:
data (dict[str, numpy.ndarray | jax.Array])
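A sketch of collating a dictionary of arrays into a batch object. `ExampleBatch` is a hypothetical stand-in for `batch_class`, and the sketch assumes `key_map` maps input keys to batch field names, with unmapped keys passed through unchanged; the real semantics of `key_map` are not documented here.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class ExampleBatch:
    """Hypothetical stand-in for `batch_class`."""

    inputs: np.ndarray
    targets: np.ndarray


def collate_to_batch(data, batch_class, key_map=None):
    """Build a `batch_class` instance from a dict of arrays (sketch).

    `key_map` is assumed to rename dict keys to batch field names; keys it
    does not mention keep their original name.
    """
    key_map = key_map or {}
    fields = {key_map.get(key, key): value for key, value in data.items()}
    return batch_class(**fields)


element = {"input_ids": np.array([[1, 2, 3]]), "targets": np.array([[2, 3, 0]])}
batch = collate_to_batch(element, ExampleBatch, key_map={"input_ids": "inputs"})
print(type(batch).__name__, batch.inputs.shape)  # ExampleBatch (1, 3)
```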
- class xlstm_jax.dataset.grain_transforms.HFTokenize(create_tokenizer_fn, column_name='text', max_length=None, add_eod=True, eod_token_id=None)
Bases: grain.python.MapTransform
Tokenize text feature keys.
- Parameters:
create_tokenizer_fn (collections.abc.Callable[..., transformers.AutoTokenizer])
column_name (str)
max_length (int | None)
add_eod (bool)
eod_token_id (int | None)
- create_tokenizer_fn
- tokenizer = None
- column_name = 'text'
- max_length = None
- add_eod = True
- eod_token_id = None
- _lazy_init()
- _tokenize(example)
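The `tokenizer = None` attribute and the `_lazy_init()` method suggest that the tokenizer is created on first use rather than in the constructor, a common way to avoid pickling a tokenizer into Grain worker processes. The sketch below illustrates that pattern with a toy tokenizer instead of `transformers`. The fallback from `eod_token_id` to the tokenizer's `eos_token_id`, and appending the EOD token before truncating to `max_length`, are assumptions; `LazyTokenize` and `ToyTokenizer` are hypothetical names.

```python
from typing import Callable, Optional


class ToyTokenizer:
    """Hypothetical stand-in for a HuggingFace tokenizer: one ID per word."""

    eos_token_id = 1

    def __call__(self, text):
        return {"input_ids": [len(word) + 10 for word in text.split()]}


class LazyTokenize:
    """Sketch of lazy tokenizer creation, in the spirit of HFTokenize."""

    def __init__(
        self,
        create_tokenizer_fn: Callable[[], ToyTokenizer],
        column_name: str = "text",
        max_length: Optional[int] = None,
        add_eod: bool = True,
        eod_token_id: Optional[int] = None,
    ):
        self.create_tokenizer_fn = create_tokenizer_fn
        self.tokenizer = None  # built on first use, e.g. inside a worker process
        self.column_name = column_name
        self.max_length = max_length
        self.add_eod = add_eod
        self.eod_token_id = eod_token_id

    def _lazy_init(self):
        if self.tokenizer is None:
            self.tokenizer = self.create_tokenizer_fn()
            if self.eod_token_id is None:
                # Assumed fallback: reuse the tokenizer's EOS token as EOD.
                self.eod_token_id = self.tokenizer.eos_token_id

    def map(self, example):
        self._lazy_init()
        ids = self.tokenizer(example[self.column_name])["input_ids"]
        if self.add_eod:
            ids = ids + [self.eod_token_id]
        if self.max_length is not None:
            ids = ids[: self.max_length]
        return {**example, self.column_name: ids}


tok = LazyTokenize(ToyTokenizer, max_length=4)
print(tok.tokenizer)                  # None (not created yet)
print(tok.map({"text": "a bb ccc"}))  # {'text': [11, 12, 13, 1]}
```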