xlstm_jax.dataset.batch
Attributes
Classes
- Batch: Batch of training data.
- LLMBatch: Batch for LLM training.
- LLMIndexedBatch: Batch for LLM data with document indices and sequence indices for correct ordering.
Module Contents
- xlstm_jax.dataset.batch.dataclass_kwonly
- class xlstm_jax.dataset.batch.Batch
Batch of training data.
- class xlstm_jax.dataset.batch.LLMBatch
Bases: Batch
Batch for LLM training.
Contains inputs and targets along with their respective positions and segmentations.
Padding uses token 0. Segmentations separate the subsequences (documents) packed into one sequence.
Note that we use grain packing (FirstFitPackIterDataset) to group smaller sequences into one sequence. If the last subsequence does not fit into the current sequence, it is placed in the next sequence (not shown below).
- Using packing and right-shifted inputs:
  targets:              [1, 2, 3, E, 4, 5, E, 0]
  inputs:               [E, 1, 2, 3, E, 4, 5, E]  # The first input token is always E (marking the beginning).
  targets_segmentation: [1, 1, 1, 1, 2, 2, 2, 0]
  inputs_segmentation:  [1, 1, 1, 1, 2, 2, 2, 0]
  targets_position:     [0, 1, 2, 3, 0, 1, 2, 0]
  inputs_position:      [0, 1, 2, 3, 0, 1, 2, 0]
  doc borders:          [1, 0, 0, 0, 1, 0, 0, 1]
Note that the two segmentations are identical. They would be different for prediction with multiple prefix tokens.
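The packed layout in the example above can be reproduced with a small NumPy sketch. This is not grain's FirstFitPackIterDataset: `pack_documents` is a hypothetical helper that assumes all documents fit into one sequence, appends the token E to each document, pads with 0, and shifts the targets right by one to obtain the inputs. The integer used for E is an arbitrary assumption.

```python
import numpy as np


def pack_documents(docs: list[list[int]], eos: int, length: int) -> dict[str, np.ndarray]:
    """Hypothetical helper: pack documents into one row of `length` tokens.

    Assumes all documents (plus one separator token each) fit into `length`.
    """
    targets, segmentation, positions = [], [], []
    for seg_id, doc in enumerate(docs, start=1):
        tokens = list(doc) + [eos]  # each document ends with the token E
        targets.extend(tokens)
        segmentation.extend([seg_id] * len(tokens))  # segment ids start at 1
        positions.extend(range(len(tokens)))  # positions restart in every document
    if len(targets) > length:
        raise ValueError("documents do not fit into one sequence")
    pad = length - len(targets)
    targets = np.array(targets + [0] * pad, dtype=np.int32)  # padding token is 0
    segmentation = np.array(segmentation + [0] * pad, dtype=np.int32)  # 0 marks padding
    positions = np.array(positions + [0] * pad, dtype=np.int32)
    # Shift right: the first input is E, every other input is the previous target.
    inputs = np.concatenate([np.array([eos], dtype=np.int32), targets[:-1]])
    return {
        "inputs": inputs,
        "targets": targets,
        "inputs_segmentation": segmentation.copy(),
        "targets_segmentation": segmentation,
        "inputs_position": positions.copy(),
        "targets_position": positions,
    }


E = 99  # arbitrary stand-in id for the token E
batch = pack_documents([[1, 2, 3], [4, 5]], eos=E, length=8)
print(batch["inputs"].tolist())   # [99, 1, 2, 3, 99, 4, 5, 99]
print(batch["targets"].tolist())  # [1, 2, 3, 99, 4, 5, 99, 0]
```

As in the example, the final padding slot receives E as its input token but keeps segmentation 0 in both segmentation arrays.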
- inputs_segmentation: jax.Array
Segmentation of the input tokens. 0 indicates padding. dtype: np.int32
- targets_segmentation: jax.Array
Segmentation of the target tokens. 0 indicates padding. dtype: np.int32
- _document_borders: jax.Array | None = None
Document borders for the input data. This buffer should only be set to explicitly override the standard algorithm for calculating the document borders, for instance when slicing the batch. Otherwise, use get_document_borders() to obtain the document borders. dtype: bool
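Why slicing calls for the explicit override can be seen in a short NumPy sketch. It assumes the default algorithm compares each token's segmentation with that of the previous token and always marks the first token as a border, which matches the examples on this page; `default_borders` is a hypothetical stand-in, not the library's implementation. Recomputing borders on a slice that starts in the middle of a document wrongly marks its first token as a border, whereas slicing the borders computed on the full sequence keeps the correct value.

```python
import numpy as np


def default_borders(segmentation: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the default border computation (last axis)."""
    first = np.ones_like(segmentation[..., :1], dtype=bool)  # first token is always a border
    changed = segmentation[..., 1:] != segmentation[..., :-1]  # segment id changes
    return np.concatenate([first, changed], axis=-1)


segmentation = np.array([[1, 1, 1, 1, 2, 2, 2, 0]], dtype=np.int32)
full_borders = default_borders(segmentation)

# Slice from position 2, which lies inside document 1.
recomputed = default_borders(segmentation[:, 2:])  # first token wrongly marked as border
sliced = full_borders[:, 2:]  # keeps "position 2 continues document 1"
print(recomputed.astype(int).tolist())  # [[1, 0, 1, 0, 0, 1]]
print(sliced.astype(int).tolist())      # [[0, 0, 1, 0, 0, 1]]
```

Storing something like `sliced` in `_document_borders` is the kind of override the field description refers to.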
- get_document_borders()
Get the document borders for the input data.
A token marks a document border if the target segmentation at the previous position differs from the one at its own position; the first token is always a border. For instance, if the input segmentation is [1, 1, 2, 2, 2, 3], the document borders are [1, 0, 1, 0, 0, 1]. This mask is useful for processing documents separately in a recurrent model, e.g. to decide when to reset the hidden state. Note: if the sequence ends with padding tokens, the transition from the last document to the padding is also marked as a document border.
- Returns:
A boolean array indicating the document borders.
- Return type:
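The border rule and its use in a recurrent model can be sketched as follows. This NumPy illustration assumes borders are computed along the last axis exactly as described above; `document_borders` and `running_sum_with_resets` are hypothetical names, and the running sum merely stands in for a recurrent hidden state that is reset at every border.

```python
import numpy as np


def document_borders(segmentation: np.ndarray) -> np.ndarray:
    """Mark tokens whose segmentation differs from the previous token's."""
    first = np.ones_like(segmentation[..., :1], dtype=bool)  # first token is always a border
    changed = segmentation[..., 1:] != segmentation[..., :-1]
    return np.concatenate([first, changed], axis=-1)


def running_sum_with_resets(values: np.ndarray, borders: np.ndarray) -> np.ndarray:
    """Toy recurrence over one sequence: h_t = x_t at a border, else h_{t-1} + x_t."""
    state = 0
    outputs = []
    for x, is_border in zip(values.tolist(), borders.tolist()):
        if is_border:
            state = 0  # start of a new document: reset the hidden state
        state += x
        outputs.append(state)
    return np.array(outputs)


segmentation = np.array([1, 1, 2, 2, 2, 3])
borders = document_borders(segmentation)
print(borders.astype(int).tolist())  # [1, 0, 1, 0, 0, 1]
print(running_sum_with_resets(np.ones(6, dtype=int), borders).tolist())  # [1, 2, 1, 2, 3, 1]
```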
- static from_inputs(inputs, targets=None)
Create an LLMBatch from inputs.
Helper function for quickly creating a default LLMBatch.
- static get_dtype_struct(batch_size, max_length)
Get the shape and dtype structure for LLMBatch.
- Parameters:
- Returns:
An LLMBatch with jax.ShapeDtypeStruct-typed components.
- Return type:
- classmethod get_sample(batch_size, max_length)
Get a concrete sample of an LLMBatch. This is needed for compilation when jax.debug.* is used in the model or elsewhere in the pipeline.
- class xlstm_jax.dataset.batch.LLMIndexedBatch
Bases: LLMBatch
Batch for LLM data with document indices and sequence indices for correct ordering.
A document_idx of 0 indicates padding.
- sequence_idx: jax.Array
Sequence indices within documents for the batch sequences. dtype: np.int32
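How document and sequence indices allow restoring the original order can be sketched in NumPy. The sketch assumes one document_idx and one sequence_idx per batch row, which may not match the actual array shapes; `restore_order` is a hypothetical helper. Rows with document_idx 0 are treated as padding and dropped, and the remaining rows are sorted by document index first and by sequence index within each document second.

```python
import numpy as np


def restore_order(document_idx: np.ndarray, sequence_idx: np.ndarray) -> np.ndarray:
    """Return row indices sorted by (document_idx, sequence_idx), skipping padding."""
    valid = np.flatnonzero(document_idx != 0)  # document_idx == 0 marks padding
    # np.lexsort sorts by the last key first, so document_idx is the primary key.
    order = np.lexsort((sequence_idx[valid], document_idx[valid]))
    return valid[order]


document_idx = np.array([2, 1, 0, 2, 1], dtype=np.int32)
sequence_idx = np.array([1, 0, 0, 0, 1], dtype=np.int32)
print(restore_order(document_idx, sequence_idx).tolist())  # [1, 4, 3, 0]
```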
- static from_inputs(inputs, document_idx, sequence_idx, targets=None)
Create an LLMBatch from inputs.
Helper function for quickly creating a default LLMBatch.
- Parameters:
- Returns:
An LLMBatch with respective inputs and targets.
- Return type:
- static get_dtype_struct(batch_size, max_length)
Get the shape and dtype structure for LLMIndexedBatch.
- Parameters:
- Returns:
An LLMBatch with jax.ShapeDtypeStruct-typed components.
- Return type:
- Inherited from LLMBatch (documented above): inputs_segmentation, targets_segmentation, _document_borders, get_document_borders(), and get_sample(batch_size, max_length).