xlstm_jax.dataset.grain_transforms


Copyright 2023 Google LLC.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

This file contains code from the maxtext project (AI-Hypercomputer/maxtext), in particular from the file AI-Hypercomputer/maxtext.

Operations used by Grain

Classes

HFNormalizeFeatures

Normalize feature keys for HuggingFace input.

HFPrefixTokenize

Merge prefix and predicted text.

ParseFeatures

Parse serialized example.

ParseArrayRecords

Parse serialized example from array_records dataset.

ParseTokenizedArrayRecords

Parse serialized, tokenized example from array_records dataset.

NormalizeFeatures

Normalize text feature keys.

ReformatPacking

Reformat packing outputs.

ReformatLazyPacking

Reformat packing outputs for the lazy API.

PadToMaxLength

Pads each input to the specified length.

ShiftData

Shift inputs/targets and refine annotations.

InferSegmentations

Infer the segmentation, i.e. document numbers, from the inputs.

CollateToBatch

Collate data to batch.

HFTokenize

Tokenize text feature keys.

AddEODToken

Add an end-of-document token to the inputs and targets.

Functions

shift_right(x[, axis, padding_value, pad_by_first_element])

Shift the input to the right by padding and slicing on axis.

shift_left(x[, axis, padding_value])

Shift the input to the left by padding and slicing on axis.

shift_and_refine(x[, shift_target, axis, padding_value])

Shift inputs or targets, and adjust segmentation.

Module Contents

class xlstm_jax.dataset.grain_transforms.HFNormalizeFeatures(column_name)

Bases: grain.python.MapTransform

Normalize feature keys for HuggingFace input.

Parameters:

column_name (str)

column_name
map(features)
Parameters:

features (dict[str, list[int]])

Return type:

dict[str, numpy.ndarray]
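A minimal sketch of what this normalization could look like, modelled on the maxtext transform of the same name that this module credits. The output keys inputs and targets, the int32 dtype, and the class name HFNormalizeFeaturesSketch are assumptions for illustration; it is a plain class rather than a grain.python.MapTransform subclass so that it runs without Grain installed.

```python
import numpy as np


class HFNormalizeFeaturesSketch:
    """Hypothetical stand-in: copy one tokenized column into inputs and targets."""

    def __init__(self, column_name: str):
        self.column_name = column_name

    def map(self, features: dict[str, list[int]]) -> dict[str, np.ndarray]:
        # Convert the token list of the configured column to a NumPy array.
        tokens = np.asarray(features[self.column_name], dtype=np.int32)
        # Inputs and targets start identical (assumed output layout).
        return {"inputs": tokens, "targets": tokens.copy()}


out = HFNormalizeFeaturesSketch("input_ids").map({"input_ids": [5, 8, 13]})
```

Turning the identical copies into next-token prediction pairs would presumably be left to a later shifting step such as ShiftData.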

class xlstm_jax.dataset.grain_transforms.HFPrefixTokenize(tokenizer, prefix_tokenizer, prefix_column_name='prefix', text_column_name='text', add_bos_token=True, add_eos_token=False, bos_token_id=None, eos_token_id=None, max_length=None, max_length_prefix=None)

Bases: grain.python.MapTransform

Merge prefix and predicted text.

Parameters:
  • prefix_column_name (str)

  • text_column_name (str)

  • add_bos_token (bool)

  • add_eos_token (bool)

  • bos_token_id (int | None)

  • eos_token_id (int | None)

  • max_length (int | None)

  • max_length_prefix (int | None)

tokenizer
prefix_tokenizer
prefix_column_name = 'prefix'
text_column_name = 'text'
max_length = None
max_length_prefix = None
add_bos_token = True
add_eos_token = False
eos_token_id = None
bos_token_id = None
map(features)

Map the prefix and text strings to a single fully padded and tokenized sequence. Prefixes are aligned within the array.

Parameters:

features (dict[str, str]) – Dictionary of inputs

Returns:

Dictionary of the outputs

Return type:

dict[str, numpy.ndarray]

class xlstm_jax.dataset.grain_transforms.ParseFeatures(data_column, tokenize)

Bases: grain.python.MapTransform

Parse serialized example.

Parameters:
  • data_column (str)

  • tokenize (bool)

data_column
map(features)
class xlstm_jax.dataset.grain_transforms.ParseArrayRecords(column_name)

Bases: grain.python.MapTransform

Parse serialized example from array_records dataset.

Parameters:

column_name (str)

column_name
map(data)

Map to parse array records.

Parameters:

data (bytes) – The bytestring-serialized example, e.g. b'Some Text'.

Returns:

Parsed data, a dictionary mapping the column_name to the deserialized string (text).

Return type:

dict[str, str]
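A sketch of the parsing step, assuming the record holds UTF-8 encoded text (the encoding is not stated on this page). ParseArrayRecordsSketch is a hypothetical stand-in, not the module's class.

```python
class ParseArrayRecordsSketch:
    """Hypothetical stand-in: decode one serialized record into a text column."""

    def __init__(self, column_name: str):
        self.column_name = column_name

    def map(self, data: bytes) -> dict[str, str]:
        # Assumes the record stores UTF-8 encoded text.
        return {self.column_name: data.decode("utf-8")}


parsed = ParseArrayRecordsSketch("text").map(b"Some Text")
```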

class xlstm_jax.dataset.grain_transforms.ParseTokenizedArrayRecords(column_name)

Bases: grain.python.MapTransform

Parse serialized, tokenized example from array_records dataset.

Parameters:

column_name (str)

column_name
map(data)

Map to parse array records.

Parameters:

data (bytes) – The bytestring-serialized data that has been tokenized, e.g. b'[0, 9392, 1823]'.

Returns:

Parsed data, a dictionary mapping the column_name to the deserialized token sequence.

Return type:

dict[str, int]

static sequence_to_bytestring(sequence)

Convert a token sequence to a numpy bytestring.

Parameters:

sequence (list[int] | numpy.ndarray) – The sequence of tokens. If a numpy array is provided, it must be one-dimensional.

Returns:

The bytestring.

Return type:

bytes

static bytestring_to_sequence(bytestring)

Convert a numpy bytestring to a token sequence.

Parameters:

bytestring (bytes) – The bytestring.

Returns:

The token sequence.

Return type:

list[int]
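The two static methods form an encode/decode pair. One possible implementation goes through NumPy's raw buffer; the int32 dtype and the function names tokens_to_bytes and bytes_to_tokens are assumptions, and whatever dtype is chosen must be the same on both sides.

```python
import numpy as np

TOKEN_DTYPE = np.int32  # assumption: the dtype used by the module is not documented here


def tokens_to_bytes(sequence) -> bytes:
    """Hypothetical encoder: 1-D token sequence -> raw bytes."""
    arr = np.asarray(sequence, dtype=TOKEN_DTYPE)
    if arr.ndim != 1:
        raise ValueError("expected a one-dimensional token sequence")
    # Raw bytes in the platform's native byte order.
    return arr.tobytes()


def bytes_to_tokens(bytestring: bytes) -> list[int]:
    """Hypothetical decoder: reinterpret the buffer with the same dtype."""
    return np.frombuffer(bytestring, dtype=TOKEN_DTYPE).tolist()


encoded = tokens_to_bytes([0, 9392, 1823])
decoded = bytes_to_tokens(encoded)
```

A fixed-width binary layout lets decoding skip any text parsing, at the cost of tying the stored records to the chosen dtype and byte order.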

class xlstm_jax.dataset.grain_transforms.NormalizeFeatures(column_name, tokenize)

Bases: grain.python.MapTransform

Normalize text feature keys.

Parameters:
  • column_name (str)

  • tokenize (bool)

column_name
tokenize
map(features)
class xlstm_jax.dataset.grain_transforms.ReformatPacking

Bases: grain.python.MapTransform

Reformat packing outputs.

static map(data)
Parameters:

data (tuple[dict[str, numpy.ndarray]])

Return type:

dict[str, numpy.ndarray]
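Per its signature, map receives a tuple of dictionaries and returns a single one. The sketch below, modelled on the maxtext transform this module credits, assumes a (tokens, segmentations, positions) tuple and flattens it; the tuple layout and the output key names are assumptions.

```python
import numpy as np


def reformat_packing_sketch(data):
    """Hypothetical: flatten (tokens, segmentations, positions) into one dict."""
    tokens, segmentations, positions = data
    return {
        "inputs": tokens["inputs"],
        "targets": tokens["targets"],
        "inputs_segmentation": segmentations["inputs"],
        "targets_segmentation": segmentations["targets"],
        "inputs_position": positions["inputs"],
        "targets_position": positions["targets"],
    }


packed = (
    {"inputs": np.array([3, 4, 9]), "targets": np.array([3, 4, 9])},
    {"inputs": np.array([1, 1, 2]), "targets": np.array([1, 1, 2])},
    {"inputs": np.array([0, 1, 0]), "targets": np.array([0, 1, 0])},
)
flat = reformat_packing_sketch(packed)
```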

class xlstm_jax.dataset.grain_transforms.ReformatLazyPacking

Bases: grain.python.MapTransform

Reformat packing outputs for the lazy API.

static map(data)
Parameters:

data (dict[str, numpy.ndarray])

Return type:

dict[str, numpy.ndarray]

class xlstm_jax.dataset.grain_transforms.PadToMaxLength(max_length)

Bases: grain.python.MapTransform

Pads each input to the specified length.

Parameters:

max_length (int)

max_length
map(data)

Pad each entry of data to max_length.

Return type:

dict[str, numpy.ndarray]
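A minimal sketch of per-key padding. It assumes every value is a one-dimensional array, pads at the end with 0, and raises on sequences longer than max_length; the real transform may use a different pad value or truncate instead. pad_to_max_length_sketch is hypothetical.

```python
import numpy as np


def pad_to_max_length_sketch(
    data: dict[str, np.ndarray], max_length: int, pad_value: int = 0
) -> dict[str, np.ndarray]:
    """Hypothetical: right-pad every 1-D array in data to max_length."""
    padded = {}
    for key, arr in data.items():
        missing = max_length - arr.shape[0]
        if missing < 0:
            raise ValueError(f"{key} is longer than max_length={max_length}")
        # Append `missing` copies of pad_value at the end of the sequence.
        padded[key] = np.pad(arr, (0, missing), constant_values=pad_value)
    return padded


out = pad_to_max_length_sketch({"inputs": np.array([4, 2, 7])}, max_length=5)
```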

xlstm_jax.dataset.grain_transforms.shift_right(x, axis=1, padding_value=0, pad_by_first_element=False)

Shift the input to the right by padding and slicing on axis.

Parameters:
  • x (numpy.ndarray) – Input array to shift.

  • axis (int) – Axis to shift along.

  • padding_value (int) – Value to use for padding.

  • pad_by_first_element (bool) – If True, pad with the first element of the array along axis instead of padding_value.

Returns:

Shifted array.

Return type:

numpy.ndarray
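A NumPy sketch of the behavior described above, written independently of the module's source: pad one slot at the start of axis, then keep the original length along that axis. The name shift_right_sketch is hypothetical.

```python
import numpy as np


def shift_right_sketch(x, axis=1, padding_value=0, pad_by_first_element=False):
    """Hypothetical: shift x one step to the right along axis."""
    pad_widths = [(0, 0)] * x.ndim
    pad_widths[axis] = (1, 0)  # one extra slot at the start of axis
    if pad_by_first_element:
        # "edge" repeats the first element along axis.
        padded = np.pad(x, pad_widths, mode="edge")
    else:
        padded = np.pad(x, pad_widths, mode="constant", constant_values=padding_value)
    # Drop the last element along axis to restore the original shape.
    index = [slice(None)] * x.ndim
    index[axis] = slice(0, x.shape[axis])
    return padded[tuple(index)]


x = np.array([[1, 2, 3], [4, 5, 6]])
right = shift_right_sketch(x)
right_first = shift_right_sketch(x, pad_by_first_element=True)
```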

xlstm_jax.dataset.grain_transforms.shift_left(x, axis=1, padding_value=0)

Shift the input to the left by padding and slicing on axis.

Parameters:
  • x (numpy.ndarray) – Input array to shift.

  • axis (int) – Axis to shift along.

  • padding_value (int) – Value to use for padding.

Returns:

Shifted array.

Return type:

numpy.ndarray
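The mirror-image operation, again as an independent NumPy sketch with the hypothetical name shift_left_sketch: pad one slot at the end of axis, then drop the first element.

```python
import numpy as np


def shift_left_sketch(x, axis=1, padding_value=0):
    """Hypothetical: shift x one step to the left along axis."""
    pad_widths = [(0, 0)] * x.ndim
    pad_widths[axis] = (0, 1)  # one extra slot at the end of axis
    padded = np.pad(x, pad_widths, mode="constant", constant_values=padding_value)
    # Drop the first element along axis to restore the original shape.
    index = [slice(None)] * x.ndim
    index[axis] = slice(1, None)
    return padded[tuple(index)]


x = np.array([[1, 2, 3], [4, 5, 6]])
left = shift_left_sketch(x)
left_neg = shift_left_sketch(x, padding_value=-1)
```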

xlstm_jax.dataset.grain_transforms.shift_and_refine(x, shift_target=True, axis=1, padding_value=0)

Shift inputs or targets, and adjust segmentation.

Parameters:
  • x

  • shift_target (bool)

  • axis (int)

  • padding_value (int)

Return type:

dict[str, numpy.ndarray]
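This page does not spell out the refinement rule. One plausible reading for shift_target=True, sketched under stated assumptions: targets are shifted left so that each position predicts the next token, and positions whose new target comes from a different document or from padding are masked out. The key names targets and targets_segmentation, the use of segment id 0 for padding, the masking rule, and the helper names are all assumptions, not the module's confirmed behavior.

```python
import numpy as np


def _shift_left(x, axis=1, padding_value=0):
    # Pad one slot at the end of axis, then drop the first element.
    pad = [(0, 0)] * x.ndim
    pad[axis] = (0, 1)
    padded = np.pad(x, pad, constant_values=padding_value)
    index = [slice(None)] * x.ndim
    index[axis] = slice(1, None)
    return padded[tuple(index)]


def shift_targets_and_refine_sketch(x, axis=1, padding_value=0):
    """Hypothetical shift_target=True path with document-boundary masking."""
    seg = x["targets_segmentation"]
    # Position t now holds the token from position t + 1.
    targets = _shift_left(x["targets"], axis, padding_value)
    next_seg = _shift_left(seg, axis, 0)
    # Keep a position only if its new target belongs to the same, non-padding document.
    keep = (seg == next_seg) & (seg != 0)
    out = dict(x)
    out["targets"] = np.where(keep, targets, padding_value)
    out["targets_segmentation"] = np.where(keep, seg, 0)
    return out


batch = {
    "targets": np.array([[5, 6, 7, 8, 9]]),
    "targets_segmentation": np.array([[1, 1, 1, 2, 2]]),
}
refined = shift_targets_and_refine_sketch(batch)
```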

class xlstm_jax.dataset.grain_transforms.ShiftData(shift_target=True, eod_token_id=0, pad_token_id=0, axis=1)

Bases: grain.python.MapTransform

Shift inputs/targets and refine annotations.

Parameters:
  • shift_target (bool)

  • eod_token_id (int)

  • pad_token_id (int)

  • axis (int)

shift_target = True
eod_token_id = 0
pad_token_id = 0
axis = 1
map(data)
Parameters:

data (dict[str, numpy.ndarray])

Return type:

dict[str, numpy.ndarray]

class xlstm_jax.dataset.grain_transforms.InferSegmentations(eod_token_id)

Bases: grain.python.MapTransform

Infer the segmentation, i.e. document numbers, from the inputs.

Uses the end-of-document token to infer breaks between documents. This is not needed when packing, where the segmentations are already set correctly, but it is useful for datasets preprocessed by grouping texts, which do not have segmentations set.

Parameters:

eod_token_id (int) – The token ID to use for the end-of-document token.

eod_token_id
map(data)

Map to infer segmentations.

Parameters:

data (dict[str, numpy.ndarray])

Return type:

dict[str, numpy.ndarray]

static _get_positions(eod_mask)

Infer positions from end-of-document mask.

Parameters:

eod_mask (numpy.ndarray)

Return type:

numpy.ndarray
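A sketch of how document ids and positions can be derived from an end-of-document mask along the last axis. It assumes the EOD token belongs to the document it closes, that document ids start at 1, and that positions restart at 0 in each document; infer_segmentations_sketch is a hypothetical helper, not the module's _get_positions.

```python
import numpy as np


def infer_segmentations_sketch(inputs: np.ndarray, eod_token_id: int):
    """Hypothetical: derive document ids and in-document positions from EOD tokens."""
    eod_mask = inputs == eod_token_id
    # A new document starts at index 0 and right after every EOD token.
    starts = np.zeros_like(eod_mask)
    starts[..., 0] = True
    starts[..., 1:] = eod_mask[..., :-1]
    # Document ids count up from 1 within each row.
    segmentation = np.cumsum(starts, axis=-1)
    # Position = index minus the index where the current document started.
    idx = np.broadcast_to(np.arange(inputs.shape[-1]), inputs.shape)
    doc_start = np.maximum.accumulate(np.where(starts, idx, 0), axis=-1)
    positions = idx - doc_start
    return segmentation, positions


seg, pos = infer_segmentations_sketch(np.array([[5, 7, 0, 3, 4, 0, 9]]), eod_token_id=0)
```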

class xlstm_jax.dataset.grain_transforms.CollateToBatch(batch_class, key_map=None)

Bases: grain.python.MapTransform

Collate data to batch.

Parameters:
  • batch_class (NamedTuple) – A NamedTuple or dataclass to hold the batch data.

  • key_map (dict[str, str] | None) – Dictionary to map input to batch keys. Keys that are not found in the dictionary will be used as is.

batch_class
key_map = None
map(data)

Map to collate data to batch.

Parameters:

data (dict[str, numpy.ndarray | jax.Array])

Return type:

xlstm_jax.dataset.batch.LLMBatch
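A sketch of the key renaming and construction, using a hypothetical ToyBatch NamedTuple in place of LLMBatch. It assumes every key in data, after renaming, matches a field of batch_class; extra keys would make the constructor raise TypeError.

```python
from typing import NamedTuple

import numpy as np


class ToyBatch(NamedTuple):
    """Hypothetical batch class standing in for LLMBatch."""

    inputs: np.ndarray
    targets: np.ndarray


def collate_to_batch_sketch(data, batch_class, key_map=None):
    key_map = key_map or {}
    # Rename keys listed in key_map; keep all other keys unchanged.
    renamed = {key_map.get(key, key): value for key, value in data.items()}
    return batch_class(**renamed)


batch = collate_to_batch_sketch(
    {"input_ids": np.array([1, 2]), "targets": np.array([2, 3])},
    ToyBatch,
    key_map={"input_ids": "inputs"},
)
```

Because both NamedTuples and dataclasses accept keyword arguments, the same construction works for either kind of batch_class.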

class xlstm_jax.dataset.grain_transforms.HFTokenize(create_tokenizer_fn, column_name='text', max_length=None, add_eod=True, eod_token_id=None)

Bases: grain.python.MapTransform

Tokenize text feature keys.

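Parameters:
  • create_tokenizer_fn

  • column_name (str)

  • max_length (int | None)

  • add_eod (bool)

  • eod_token_id (int | None)

create_tokenizer_fn
tokenizer = None
column_name = 'text'
max_length = None
add_eod = True
eod_token_id = None
_lazy_init()
_tokenize(example)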
Parameters:

example (str)

Return type:

transformers.tokenization_utils_base.BatchEncoding[str, list[int]]

map(data)
Parameters:

data (dict[str, str])

Return type:

dict[str, list[int]]
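The signature pairs a create_tokenizer_fn with a tokenizer attribute that defaults to None and a _lazy_init method, which suggests the tokenizer is built on first use. A common motivation (an assumption here) is that transforms get pickled and sent to data-loading worker processes, and deferring construction avoids pickling the tokenizer itself. The sketch uses a toy word-length tokenizer instead of a Hugging Face tokenizer and returns a plain list rather than a BatchEncoding; all names are hypothetical.

```python
from typing import Callable


class LazyTokenizeSketch:
    """Hypothetical: build the tokenizer on first use instead of in __init__."""

    def __init__(self, create_tokenizer_fn: Callable, column_name: str = "text"):
        self.create_tokenizer_fn = create_tokenizer_fn
        self.column_name = column_name
        self.tokenizer = None  # created lazily on the first call to map

    def _lazy_init(self):
        if self.tokenizer is None:
            self.tokenizer = self.create_tokenizer_fn()

    def map(self, data: dict[str, str]) -> dict[str, list[int]]:
        self._lazy_init()
        return {self.column_name: self.tokenizer(data[self.column_name])}


def make_toy_tokenizer():
    # Hypothetical stand-in tokenizer: each word becomes its length.
    return lambda text: [len(word) for word in text.split()]


transform = LazyTokenizeSketch(make_toy_tokenizer)
before = transform.tokenizer
out = transform.map({"text": "hello big world"})
```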

class xlstm_jax.dataset.grain_transforms.AddEODToken(eod_token_id, add_eod=True, max_length=None)

Bases: grain.python.MapTransform

Add an end-of-document token to the inputs and targets.

Parameters:
  • eod_token_id (int) – The token ID to use for the end-of-document token.

  • add_eod (bool) – Whether to add the EOD token. If False, the transform is a no-op.

  • max_length (int | None) – Maximum length of the sequence. If None, no truncation is performed.

eod_token_id
add_eod = True
max_length = None
map(data)

Map to add EOD token.

Parameters:

data (dict[str, list[int]])

Return type:

dict[str, list[int]]
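A minimal sketch, assuming the EOD token is appended to every list in data and that truncation reserves the final slot for it; the documented no-op behavior for add_eod=False is kept. add_eod_sketch is hypothetical.

```python
def add_eod_sketch(data, eod_token_id, add_eod=True, max_length=None):
    """Hypothetical: append an EOD token to every token list in data."""
    if not add_eod:
        return data  # documented no-op
    out = {}
    for key, tokens in data.items():
        if max_length is not None:
            # Assumption: keep room for the EOD token within max_length.
            tokens = tokens[: max_length - 1]
        out[key] = list(tokens) + [eod_token_id]
    return out


full = add_eod_sketch({"inputs": [3, 4, 5]}, eod_token_id=0)
truncated = add_eod_sketch({"inputs": [3, 4, 5]}, eod_token_id=0, max_length=3)
skipped = add_eod_sketch({"inputs": [3, 4, 5]}, eod_token_id=0, add_eod=False)
```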