xlstm_jax.dataset.batch#

Attributes#

Classes#

Batch

Batch of training data.

LLMBatch

Batch for LLM training.

LLMIndexedBatch

Batch for LLM data with document indices and sequence indices for correct ordering.

Module Contents#

xlstm_jax.dataset.batch.dataclass_kwonly#
class xlstm_jax.dataset.batch.Batch#

Batch of training data.

inputs: jax.Array#

The input data.

targets: jax.Array#

The target data.

class xlstm_jax.dataset.batch.LLMBatch#

Bases: Batch

Batch for LLM training.

Contains inputs and targets along with their respective positions and segmentations.

Padding is token 0. Segmentations are used to separate the subsequences.

Note that we use grain packing (FirstFitPackIterDataset) to group shorter sequences into one sequence. If the last subsequence does not fit into the current sequence, it is used in the next sequence (not shown below).

Example with packing and inputs shifted right:

targets:              [1, 2, 3, E, 4, 5, E, 0]
inputs:               [E, 1, 2, 3, E, 4, 5, E]  # The first token is always E (marking the beginning).
targets_segmentation: [1, 1, 1, 1, 2, 2, 2, 0]
inputs_segmentation:  [1, 1, 1, 1, 2, 2, 2, 0]
targets_position:     [0, 1, 2, 3, 0, 1, 2, 0]
inputs_position:      [0, 1, 2, 3, 0, 1, 2, 0]
doc borders:          [1, 0, 0, 0, 1, 0, 0, 1]

Note that the two segmentations are identical. They would be different for prediction with multiple prefix tokens.
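The layout above can be reproduced with a small sketch. This is not the grain FirstFitPackIterDataset implementation; it is a simplified greedy packer for a single row, and `pack_row` and the token id `EOD` (standing in for `E`) are hypothetical names chosen for illustration.

```python
import numpy as np

EOD = 9  # hypothetical end-of-document token id standing in for "E"


def pack_row(docs, max_length, pad_id=0):
    """Greedily pack token lists into one row of length max_length.

    Returns (tokens, segmentation, position). Segment ids start at 1,
    and 0 marks padding, matching the convention described above.
    """
    tokens = np.full(max_length, pad_id, dtype=np.int32)
    segmentation = np.zeros(max_length, dtype=np.int32)
    position = np.zeros(max_length, dtype=np.int32)
    cursor = 0
    for segment_id, doc in enumerate(docs, start=1):
        end = cursor + len(doc)
        if end > max_length:
            break  # a real packer would carry this document over to the next row
        tokens[cursor:end] = doc
        segmentation[cursor:end] = segment_id
        position[cursor:end] = np.arange(len(doc))  # positions restart per document
        cursor = end
    return tokens, segmentation, position


targets, targets_segmentation, targets_position = pack_row(
    [[1, 2, 3, EOD], [4, 5, EOD]], max_length=8
)
```

With these two documents, the segmentation and position arrays match the `targets_segmentation` and `targets_position` rows of the example.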

inputs_position: jax.Array#

Positions of the input tokens. dtype: np.int32

inputs_segmentation: jax.Array#

Segmentation of the input tokens. 0 to indicate padding. dtype: np.int32

targets_position: jax.Array#

Positions of the target tokens. dtype: np.int32

targets_segmentation: jax.Array#

Segmentation of the target tokens. 0 to indicate padding. dtype: np.int32

_document_borders: jax.Array | None = None#

Document borders for the input data. This buffer should only be used to explicitly override the standard algorithm for calculating the document borders, for instance when slicing the batch. Otherwise, use get_document_borders() to get the document borders. dtype: bool
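One plausible reason an override is useful when slicing: if borders are recomputed on a slice, the first token of the slice looks like the start of a document even when it continues one. The sketch below illustrates this with NumPy; `document_borders` is a hypothetical stand-in for the standard calculation, not the library's code.

```python
import numpy as np


def document_borders(segmentation):
    """Mark positions whose segment id differs from the previous position."""
    seg = np.asarray(segmentation)
    changed = seg[:, 1:] != seg[:, :-1]
    first = np.ones((seg.shape[0], 1), dtype=bool)  # first token always counts as a border
    return np.concatenate([first, changed], axis=1)


segmentation = np.array([[1, 1, 1, 2, 2, 2]], dtype=np.int32)
full_borders = document_borders(segmentation)

window = slice(2, 6)
# Recomputing on the slice wrongly marks the continuation of document 1 as a border.
recomputed = document_borders(segmentation[:, window])
# Slicing the borders of the full batch keeps the original information.
carried_over = full_borders[:, window]
```

Here `recomputed` starts with a border while `carried_over` does not, which is the kind of value one would store explicitly for the sliced batch.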

get_document_borders()#

Get the document borders for the input data.

A token represents a document border if its previous target token has a different target segmentation. For instance, if the input segmentation is [1, 1, 2, 2, 2, 3], the document borders are [1, 0, 1, 0, 0, 1]. This mask can be useful for processing documents separately in a recurrent model, i.e. when to reset the hidden state.

Note: If the last tokens are padding (marking invalid tokens), the border between the last document and the padding is also marked as a document border.

Returns:

A boolean array indicating the document borders.

Return type:

jax.Array
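As a rough illustration of the rule and of the hidden-state reset it enables, here is a NumPy sketch. It is written from the description above rather than from the library source; `document_borders` and `running_sum_with_reset` are hypothetical helpers, and the running sum is only a toy stand-in for a recurrent state.

```python
import numpy as np


def document_borders(segmentation):
    """Mark positions whose segment id differs from the previous position."""
    seg = np.asarray(segmentation)
    changed = seg[:, 1:] != seg[:, :-1]
    first = np.ones((seg.shape[0], 1), dtype=bool)  # first token always counts as a border
    return np.concatenate([first, changed], axis=1)


def running_sum_with_reset(values, borders):
    """Toy recurrence: the carried state is cleared wherever a border is set."""
    values = np.asarray(values)
    out = np.zeros_like(values)
    state = np.zeros(values.shape[0], dtype=values.dtype)
    for t in range(values.shape[1]):
        state = np.where(borders[:, t], 0, state) + values[:, t]
        out[:, t] = state
    return out


segmentation = np.array([[1, 1, 2, 2, 2, 3]], dtype=np.int32)
borders = document_borders(segmentation)
sums = running_sum_with_reset(np.ones_like(segmentation), borders)
```

The running sum restarts at every border, so each document is accumulated independently of the one packed before it.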

static from_inputs(inputs, targets=None)#

Create LLMBatch from inputs.

Helper function for quickly creating a default LLMBatch.

Parameters:
  • inputs (jax.Array) – The input data.

  • targets (jax.Array, optional) – The target data. If not provided, the inputs are used as targets and the inputs are shifted right by one.

Returns:

An LLMBatch with respective inputs and targets.

Return type:

LLMBatch
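The shift described for the `targets=None` case can be sketched as follows. Several details are assumptions made for illustration and are not confirmed by this page: the fill value `0` for the first input position, positions that simply count up along the sequence, and a single segment with id 1. `shift_right` and `default_fields` are hypothetical helpers.

```python
import numpy as np


def shift_right(tokens, fill_id=0):
    """Shift each row right by one, dropping the last token (fill_id is an assumption)."""
    tokens = np.asarray(tokens, dtype=np.int32)
    shifted = np.full_like(tokens, fill_id)
    shifted[:, 1:] = tokens[:, :-1]
    return shifted


def default_fields(raw_tokens):
    """Build default fields from raw tokens when no targets are given."""
    targets = np.asarray(raw_tokens, dtype=np.int32)  # raw tokens become the targets
    inputs = shift_right(targets)                     # inputs are the shifted tokens
    position = np.broadcast_to(
        np.arange(targets.shape[1], dtype=np.int32), targets.shape
    )
    segmentation = np.ones_like(targets)              # assumed: one document per row
    return {
        "inputs": inputs,
        "targets": targets,
        "inputs_position": position,
        "targets_position": position,
        "inputs_segmentation": segmentation,
        "targets_segmentation": segmentation,
    }


fields = default_fields([[5, 6, 7, 8]])
```

Each input position then holds the token that precedes the corresponding target, which is the usual next-token prediction layout.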

static get_dtype_struct(batch_size, max_length)#

Get the shape and dtype structure for LLMBatch.

Parameters:
  • batch_size (int) – The size of the batch.

  • max_length (int) – The maximum length of the sequences.

Returns:

An LLMBatch with jax.ShapeDtypeStruct typed components.

Return type:

LLMBatch
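A structure of `jax.ShapeDtypeStruct` leaves lets shapes and dtypes be traced without allocating data. The sketch below uses a plain dictionary instead of the LLMBatch class; the shape `(batch_size, max_length)` and the `int32` dtype for `inputs` and `targets` are assumptions, and `llm_batch_struct` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

FIELDS = (
    "inputs",
    "targets",
    "inputs_position",
    "inputs_segmentation",
    "targets_position",
    "targets_segmentation",
)


def llm_batch_struct(batch_size, max_length):
    """Describe every field by shape and dtype only (no device memory is used)."""
    spec = jax.ShapeDtypeStruct((batch_size, max_length), jnp.int32)
    return {name: spec for name in FIELDS}


structs = llm_batch_struct(batch_size=4, max_length=16)

# jax.eval_shape traces the function abstractly and returns the output's ShapeDtypeStruct.
out = jax.eval_shape(lambda batch: batch["inputs"] + batch["targets"], structs)
```

Such a structure is typically enough for shape inference or ahead-of-time lowering, since only the abstract values are needed there.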

classmethod get_sample(batch_size, max_length)#

Get a real sample of an LLMBatch. Needed for compilation when using jax.debug.* in the model or anywhere else in the pipeline.

Parameters:
  • batch_size (int) – The size of the batch.

  • max_length (int) – The maximum length of the sequences.

Return type:

LLMBatch
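In contrast to a shape-only structure, a real sample holds concrete arrays. The following sketch builds an all-zeros sample as a plain dictionary and passes it through a jitted function that calls `jax.debug.print`. The zero fill, the `int32` dtype, and the helper names `zeros_sample` and `debug_step` are assumptions for illustration; the actual contents returned by get_sample are not specified here.

```python
import jax
import jax.numpy as jnp

FIELDS = (
    "inputs",
    "targets",
    "inputs_position",
    "inputs_segmentation",
    "targets_position",
    "targets_segmentation",
)


def zeros_sample(batch_size, max_length):
    """Create concrete arrays for every field (all zeros, i.e. all padding)."""
    return {
        name: jnp.zeros((batch_size, max_length), dtype=jnp.int32) for name in FIELDS
    }


@jax.jit
def debug_step(batch):
    # jax.debug.print runs at execution time with concrete values.
    jax.debug.print("first input token: {}", batch["inputs"][0, 0])
    return batch["inputs"].sum()


sample = zeros_sample(batch_size=2, max_length=8)
total = debug_step(sample)
```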

inputs: jax.Array#

The input data.

targets: jax.Array#

The target data.

class xlstm_jax.dataset.batch.LLMIndexedBatch#

Bases: LLMBatch

Batch for LLM data with document indices and sequence indices for correct ordering.

A document_idx of zero indicates padding.

document_idx: jax.Array#

Document indices for batch sequences. dtype: np.int32

sequence_idx: jax.Array#

Sequence indices within documents for batch sequences. dtype: np.int32

static from_inputs(inputs, document_idx, sequence_idx, targets=None)#

Create LLMIndexedBatch from inputs.

Helper function for quickly creating a default LLMIndexedBatch.

Parameters:
  • inputs (jax.Array) – The input data.

  • targets (jax.Array, optional) – The target data.

  • sequence_idx (jax.Array) – The sequence idx for each sample.

  • document_idx (jax.Array) – The document idx for each sample. A document might be composed of multiple sequences.

Returns:

An LLMIndexedBatch with respective inputs and targets.

Return type:

LLMIndexedBatch
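To illustrate how the two index arrays can restore a correct ordering, the sketch below sorts the rows of a batch by document and then by sequence within the document, dropping padding rows. It assumes one `document_idx` and one `sequence_idx` entry per row, which is an interpretation of "for each sample" and not confirmed by this page; the variable names are illustrative.

```python
import numpy as np

# Five rows arriving out of order; row 2 is padding (document_idx == 0).
document_idx = np.array([2, 1, 0, 1, 2], dtype=np.int32)
sequence_idx = np.array([1, 0, 0, 1, 0], dtype=np.int32)
inputs = np.arange(5 * 4, dtype=np.int32).reshape(5, 4)

# np.lexsort uses the LAST key as the primary key: sort by document, then by sequence.
order = np.lexsort((sequence_idx, document_idx))

# Drop padding rows while keeping the sorted order.
order = order[document_idx[order] != 0]

ordered_inputs = inputs[order]
```

After sorting, the sequences of each document appear consecutively and in their original order, so per-document outputs can be concatenated directly.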

static get_dtype_struct(batch_size, max_length)#

Get the shape and dtype structure for LLMIndexedBatch.

Parameters:
  • batch_size (int) – The size of the batch.

  • max_length (int) – The maximum length of the sequences.

Returns:

An LLMIndexedBatch with jax.ShapeDtypeStruct typed components.

Return type:

LLMIndexedBatch

inputs_position: jax.Array#

Positions of the input tokens. dtype: np.int32

inputs_segmentation: jax.Array#

Segmentation of the input tokens. 0 to indicate padding. dtype: np.int32

targets_position: jax.Array#

Positions of the target tokens. dtype: np.int32

targets_segmentation: jax.Array#

Segmentation of the target tokens. 0 to indicate padding. dtype: np.int32

_document_borders: jax.Array | None = None#

Document borders for the input data. This buffer should only be used to explicitly override the standard algorithm for calculating the document borders, for instance when slicing the batch. Otherwise, use get_document_borders() to get the document borders. dtype: bool

get_document_borders()#

Get the document borders for the input data.

A token represents a document border if its previous target token has a different target segmentation. For instance, if the input segmentation is [1, 1, 2, 2, 2, 3], the document borders are [1, 0, 1, 0, 0, 1]. This mask can be useful for processing documents separately in a recurrent model, i.e. when to reset the hidden state.

Note: If the last tokens are padding (marking invalid tokens), the border between the last document and the padding is also marked as a document border.

Returns:

A boolean array indicating the document borders.

Return type:

jax.Array

classmethod get_sample(batch_size, max_length)#

Get a real sample of an LLMBatch. Needed for compilation when using jax.debug.* in the model or anywhere else in the pipeline.

Parameters:
  • batch_size (int) – The size of the batch.

  • max_length (int) – The maximum length of the sequences.

Return type:

LLMBatch

inputs: jax.Array#

The input data.

targets: jax.Array#

The target data.