xlstm_jax.dataset.grain_transforms
==================================

.. py:module:: xlstm_jax.dataset.grain_transforms

.. autoapi-nested-parse::

   Copyright 2023 Google LLC.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

        https://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

   ---

   This file contains code from the maxtext project (https://github.com/AI-Hypercomputer/maxtext),
   in particular from the file https://github.com/AI-Hypercomputer/maxtext/MaxText/input_pipeline/_input_pipeline_utils.py.

   Data transforms and helper operations used in Grain input pipelines.



Classes
-------

.. autoapisummary::

   xlstm_jax.dataset.grain_transforms.HFNormalizeFeatures
   xlstm_jax.dataset.grain_transforms.HFPrefixTokenize
   xlstm_jax.dataset.grain_transforms.ParseFeatures
   xlstm_jax.dataset.grain_transforms.ParseArrayRecords
   xlstm_jax.dataset.grain_transforms.ParseTokenizedArrayRecords
   xlstm_jax.dataset.grain_transforms.NormalizeFeatures
   xlstm_jax.dataset.grain_transforms.ReformatPacking
   xlstm_jax.dataset.grain_transforms.ReformatLazyPacking
   xlstm_jax.dataset.grain_transforms.PadToMaxLength
   xlstm_jax.dataset.grain_transforms.ShiftData
   xlstm_jax.dataset.grain_transforms.InferSegmentations
   xlstm_jax.dataset.grain_transforms.CollateToBatch
   xlstm_jax.dataset.grain_transforms.HFTokenize
   xlstm_jax.dataset.grain_transforms.AddEODToken


Functions
---------

.. autoapisummary::

   xlstm_jax.dataset.grain_transforms.shift_right
   xlstm_jax.dataset.grain_transforms.shift_left
   xlstm_jax.dataset.grain_transforms.shift_and_refine


Module Contents
---------------

.. py:class:: HFNormalizeFeatures(column_name)

   Bases: :py:obj:`grain.python.MapTransform`


   Normalize feature keys for HuggingFace input.


   .. py:attribute:: column_name


   .. py:method:: map(features)


.. py:class:: HFPrefixTokenize(tokenizer, prefix_tokenizer, prefix_column_name = 'prefix', text_column_name = 'text', add_bos_token = True, add_eos_token = False, bos_token_id = None, eos_token_id = None, max_length = None, max_length_prefix = None)

   Bases: :py:obj:`grain.python.MapTransform`


   Tokenize the prefix and the text to be predicted, and merge them into one sequence.


   .. py:attribute:: tokenizer


   .. py:attribute:: prefix_tokenizer


   .. py:attribute:: prefix_column_name
      :value: 'prefix'



   .. py:attribute:: text_column_name
      :value: 'text'



   .. py:attribute:: max_length
      :value: None



   .. py:attribute:: max_length_prefix
      :value: None



   .. py:attribute:: add_bos_token
      :value: True



   .. py:attribute:: add_eos_token
      :value: False



   .. py:attribute:: eos_token_id
      :value: None



   .. py:attribute:: bos_token_id
      :value: None



   .. py:method:: map(features)

      Map the prefix and text strings to a fully padded, tokenized sequence.
      Prefixes are aligned in the array.

      :param features: Dictionary of inputs

      :returns: Dictionary of the outputs



.. py:class:: ParseFeatures(data_column, tokenize)

   Bases: :py:obj:`grain.python.MapTransform`


   Parse a serialized example.


   .. py:attribute:: data_column


   .. py:method:: map(features)


.. py:class:: ParseArrayRecords(column_name)

   Bases: :py:obj:`grain.python.MapTransform`


   Parse a serialized example from an array_records dataset.


   .. py:attribute:: column_name


   .. py:method:: map(data)

      Map to parse array records.

      :param data: The bytestring-serialized example, e.g. b'Some Text'.

      :returns: Parsed data, a dictionary mapping the column_name to the deserialized string (text).
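
The parsing step described for ``ParseArrayRecords.map`` can be pictured with a small standalone sketch. The helper name ``parse_array_record_sketch`` and the UTF-8 encoding are assumptions made for illustration; the actual transform may decode differently.

```python
def parse_array_record_sketch(data: bytes, column_name: str) -> dict:
    """Hypothetical stand-in for the parsing step, not the library code."""
    # Assumption: the record is UTF-8 encoded text.
    return {column_name: data.decode("utf-8")}


parsed = parse_array_record_sketch(b"Some Text", column_name="text")
print(parsed)  # {'text': 'Some Text'}
```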



.. py:class:: ParseTokenizedArrayRecords(column_name)

   Bases: :py:obj:`grain.python.MapTransform`


   Parse a serialized, tokenized example from an array_records dataset.


   .. py:attribute:: column_name


   .. py:method:: map(data)

      Map to parse array records.

      :param data: The bytestring-serialized data that has been tokenized, e.g. `b'[0, 9392, 1823]'`.

      :returns: Parsed data, a dictionary mapping the column_name to the deserialized token sequence.



   .. py:method:: sequence_to_bytestring(sequence)
      :staticmethod:


      Convert a token sequence to a numpy bytestring.

      :param sequence: The sequence of tokens. If a numpy array is provided, it must be one-dimensional.

      :returns: The bytestring.



   .. py:method:: bytestring_to_sequence(bytestring)
      :staticmethod:


      Convert a numpy bytestring to a token sequence.

      :param bytestring: The bytestring.

      :returns: The token sequence.
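
The round trip between a token sequence and its bytestring form might look roughly like the sketch below. It assumes the serialized form is the UTF-8 encoding of a JSON-style list, which matches the ``b'[0, 9392, 1823]'`` example above; the real static methods may use a different encoding (for instance a NumPy bytes type), and the ``*_sketch`` names are hypothetical.

```python
import json

import numpy as np


def sequence_to_bytestring_sketch(sequence) -> bytes:
    """Hypothetical serializer: token sequence -> bytes like b'[0, 9392, 1823]'."""
    tokens = np.asarray(sequence)
    if tokens.ndim != 1:
        raise ValueError("Expected a one-dimensional token sequence.")
    # tolist() converts NumPy integers to plain Python ints for json.dumps.
    return json.dumps(tokens.tolist()).encode("utf-8")


def bytestring_to_sequence_sketch(bytestring: bytes) -> list:
    """Hypothetical inverse: bytes -> list of token IDs."""
    return json.loads(bytestring.decode("utf-8"))


encoded = sequence_to_bytestring_sketch(np.array([0, 9392, 1823]))
decoded = bytestring_to_sequence_sketch(encoded)
print(encoded)  # b'[0, 9392, 1823]'
print(decoded)  # [0, 9392, 1823]
```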



.. py:class:: NormalizeFeatures(column_name, tokenize)

   Bases: :py:obj:`grain.python.MapTransform`


   Normalize text feature keys.


   .. py:attribute:: column_name


   .. py:attribute:: tokenize


   .. py:method:: map(features)


.. py:class:: ReformatPacking

   Bases: :py:obj:`grain.python.MapTransform`


   Reformat packing outputs.


   .. py:method:: map(data)
      :staticmethod:



.. py:class:: ReformatLazyPacking

   Bases: :py:obj:`grain.python.MapTransform`


   Reformat packing outputs for the lazy API.


   .. py:method:: map(data)
      :staticmethod:



.. py:class:: PadToMaxLength(max_length)

   Bases: :py:obj:`grain.python.MapTransform`


   Pads each input to the specified length.


   .. py:attribute:: max_length


   .. py:method:: map(data)

      Pad each element of the input to ``max_length``.
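
As a rough illustration of padding every entry of an example to a fixed length, consider the sketch below. The function name, the padding value of ``0``, the restriction to one-dimensional arrays, and the choice to leave longer sequences untouched are all assumptions; ``PadToMaxLength`` itself may behave differently.

```python
import numpy as np


def pad_to_max_length_sketch(data: dict, max_length: int, pad_value: int = 0) -> dict:
    """Hypothetical sketch: right-pad each 1-D array in `data` to `max_length`."""
    padded = {}
    for key, value in data.items():
        value = np.asarray(value)
        # Assumption: sequences already at or above max_length are left as they are.
        pad_amount = max(max_length - value.shape[0], 0)
        padded[key] = np.pad(
            value, (0, pad_amount), mode="constant", constant_values=pad_value
        )
    return padded


example = {"inputs": np.array([5, 6, 7]), "targets": np.array([6, 7, 8])}
padded_example = pad_to_max_length_sketch(example, max_length=5)
print(padded_example["inputs"])  # [5 6 7 0 0]
```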



.. py:function:: shift_right(x, axis = 1, padding_value = 0, pad_by_first_element = False)

   Shift the input to the right by padding and slicing on axis.

   :param x: Input array to shift.
   :param axis: Axis to shift along.
   :param padding_value: Value to use for padding.
   :param pad_by_first_element: If True, does not use padding_value but instead the first element of the array on the
                                axis.

   :returns: Shifted array.
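
The pad-and-slice behaviour described for ``shift_right`` can be sketched in plain NumPy as follows. ``shift_right_sketch`` is a hypothetical re-implementation written for illustration; it mirrors the documented parameters but is not the module's code.

```python
import numpy as np


def shift_right_sketch(x, axis=1, padding_value=0, pad_by_first_element=False):
    """Hypothetical NumPy sketch of a right shift by one along `axis`."""
    x = np.asarray(x)
    axis = axis % x.ndim
    # Add one slot at the front of `axis` only.
    pad_widths = [(0, 0)] * x.ndim
    pad_widths[axis] = (1, 0)
    if pad_by_first_element:
        # "edge" mode repeats the first element along the axis.
        padded = np.pad(x, pad_widths, mode="edge")
    else:
        padded = np.pad(x, pad_widths, mode="constant", constant_values=padding_value)
    # Drop the last element so the output keeps the input shape.
    index = [slice(None)] * x.ndim
    index[axis] = slice(0, -1)
    return padded[tuple(index)]


tokens = np.array([[5, 6, 7, 8]])
shifted = shift_right_sketch(tokens)
shifted_first = shift_right_sketch(tokens, pad_by_first_element=True)
print(shifted)        # [[0 5 6 7]]
print(shifted_first)  # [[5 5 6 7]]
```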


.. py:function:: shift_left(x, axis = 1, padding_value = 0)

   Shift the input to the left by padding and slicing on axis.

   :param x: Input array to shift.
   :param axis: Axis to shift along.
   :param padding_value: Value to use for padding.

   :returns: Shifted array.
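
A left shift is the mirror image: pad at the end of the axis and drop the first element. The sketch below uses a hypothetical helper, ``shift_left_sketch``, to show how next-token targets could be derived from inputs under that reading of the description.

```python
import numpy as np


def shift_left_sketch(x, axis=1, padding_value=0):
    """Hypothetical NumPy sketch of a left shift by one along `axis`."""
    x = np.asarray(x)
    axis = axis % x.ndim
    # Add one slot at the end of `axis` only.
    pad_widths = [(0, 0)] * x.ndim
    pad_widths[axis] = (0, 1)
    padded = np.pad(x, pad_widths, mode="constant", constant_values=padding_value)
    # Drop the first element so the output keeps the input shape.
    index = [slice(None)] * x.ndim
    index[axis] = slice(1, None)
    return padded[tuple(index)]


inputs = np.array([[5, 6, 7, 8]])
targets = shift_left_sketch(inputs)
print(targets)  # [[6 7 8 0]]
```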


.. py:function:: shift_and_refine(x, shift_target = True, axis = 1, padding_value = 0)

   Shift inputs or targets, and adjust segmentation.


.. py:class:: ShiftData(shift_target = True, eod_token_id = 0, pad_token_id = 0, axis = 1)

   Bases: :py:obj:`grain.python.MapTransform`


   Shift inputs/targets and refine annotations.


   .. py:attribute:: shift_target
      :value: True



   .. py:attribute:: eod_token_id
      :value: 0



   .. py:attribute:: pad_token_id
      :value: 0



   .. py:attribute:: axis
      :value: 1



   .. py:method:: map(data)


.. py:class:: InferSegmentations(eod_token_id)

   Bases: :py:obj:`grain.python.MapTransform`


   Infer the segmentation, i.e. document numbers, from the inputs.

   Uses the end-of-document token to infer breaks between documents. This is not needed
   when performing packing, where the segmentations are already set correctly, but is
   useful for grouped text preprocessed datasets, which do not have the segmentations set.

   :param eod_token_id: The token ID to use for the end-of-document token.


   .. py:attribute:: eod_token_id


   .. py:method:: map(data)

      Map to infer segmentations.



   .. py:method:: _get_positions(eod_mask)
      :staticmethod:


      Infer positions from end-of-document mask.
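
One way to infer document numbers and in-document positions from end-of-document tokens is a cumulative sum over the EOD mask, as sketched below. This is an illustrative reading of the ``InferSegmentations`` description, not its implementation: the function name, the 1-based segment IDs, and the choice to count each EOD token as part of the document it closes are assumptions.

```python
import numpy as np


def infer_segmentation_sketch(tokens, eod_token_id):
    """Hypothetical sketch: derive segment IDs and positions from EOD tokens."""
    tokens = np.asarray(tokens)
    eod_mask = (tokens == eod_token_id).astype(np.int32)
    # Count EOD tokens strictly before each position, so that an EOD token
    # still belongs to the document it terminates (assumption).
    eod_before = np.cumsum(eod_mask, axis=-1) - eod_mask
    segment_ids = eod_before + 1  # assumption: IDs start at 1
    # A document starts at index 0 and right after every EOD token.
    is_start = np.ones_like(eod_mask)
    is_start[..., 1:] = eod_mask[..., :-1]
    index = np.arange(tokens.shape[-1])
    # Carry the most recent document start forward along the sequence.
    doc_start = np.maximum.accumulate(np.where(is_start == 1, index, 0), axis=-1)
    positions = index - doc_start
    return segment_ids, positions


segment_ids, positions = infer_segmentation_sketch([3, 4, 0, 5, 0, 6, 7], eod_token_id=0)
print(segment_ids)  # [1 1 1 2 2 3 3]
print(positions)    # [0 1 2 0 1 0 1]
```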



.. py:class:: CollateToBatch(batch_class, key_map = None)

   Bases: :py:obj:`grain.python.MapTransform`


   Collate data to batch.

   :param batch_class: A NamedTuple or dataclass to hold the batch data.
   :param key_map: Dictionary mapping input keys to batch keys. Keys not found in the dictionary are used as is.


   .. py:attribute:: batch_class


   .. py:attribute:: key_map
      :value: None



   .. py:method:: map(data)

      Map to collate data to batch.
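
The key-renaming behaviour described for ``key_map`` can be illustrated with a ``NamedTuple`` batch class. ``ToyBatch`` and ``collate_sketch`` are hypothetical names introduced here; the sketch only shows the documented rule that unmapped keys are passed through unchanged.

```python
from typing import NamedTuple

import numpy as np


class ToyBatch(NamedTuple):
    """Hypothetical batch container used only for this example."""

    inputs: np.ndarray
    targets: np.ndarray


def collate_sketch(data: dict, batch_class, key_map=None):
    """Hypothetical sketch: rename keys via `key_map`, then build the batch."""
    key_map = key_map or {}
    # Keys missing from key_map are used as is.
    renamed = {key_map.get(key, key): value for key, value in data.items()}
    return batch_class(**renamed)


data = {"input_ids": np.array([[1, 2, 3]]), "targets": np.array([[2, 3, 0]])}
batch = collate_sketch(data, ToyBatch, key_map={"input_ids": "inputs"})
print(batch.inputs)   # [[1 2 3]]
print(batch.targets)  # [[2 3 0]]
```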



.. py:class:: HFTokenize(create_tokenizer_fn, column_name = 'text', max_length = None, add_eod = True, eod_token_id = None)

   Bases: :py:obj:`grain.python.MapTransform`


   Tokenize text feature keys.


   .. py:attribute:: create_tokenizer_fn


   .. py:attribute:: tokenizer
      :value: None



   .. py:attribute:: column_name
      :value: 'text'



   .. py:attribute:: max_length
      :value: None



   .. py:attribute:: add_eod
      :value: True



   .. py:attribute:: eod_token_id
      :value: None



   .. py:method:: _lazy_init()


   .. py:method:: _tokenize(example)


   .. py:method:: map(data)


.. py:class:: AddEODToken(eod_token_id, add_eod = True, max_length = None)

   Bases: :py:obj:`grain.python.MapTransform`


   Add an end-of-document token to the inputs and targets.

   :param eod_token_id: The token ID to use for the end-of-document token.
   :param add_eod: Whether to add the EOD token. If False, the transform is a no-op.
   :param max_length: Maximum length of the sequence. If None, no truncation is performed.


   .. py:attribute:: eod_token_id


   .. py:attribute:: add_eod
      :value: True



   .. py:attribute:: max_length
      :value: None



   .. py:method:: map(data)

      Map to add EOD token.
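
The interaction of ``add_eod`` and ``max_length`` can be sketched on a single token sequence as below. How the real transform combines truncation with the appended token is not stated above, so this sketch assumes truncation leaves room for the EOD token; the function name and the restriction to one 1-D sequence (rather than a dictionary of inputs and targets) are also assumptions.

```python
import numpy as np


def add_eod_sketch(tokens, eod_token_id, add_eod=True, max_length=None):
    """Hypothetical sketch: append an EOD token to a 1-D token sequence."""
    tokens = np.asarray(tokens)
    if not add_eod:
        # Documented behaviour: the transform is a no-op.
        return tokens
    if max_length is not None:
        # Assumption: truncate so the EOD token still fits within max_length.
        tokens = tokens[: max(max_length - 1, 0)]
    eod = np.array([eod_token_id], dtype=tokens.dtype)
    return np.concatenate([tokens, eod])


with_eod = add_eod_sketch([5, 6, 7], eod_token_id=0)
truncated = add_eod_sketch([5, 6, 7], eod_token_id=0, max_length=3)
unchanged = add_eod_sketch([5, 6, 7], eod_token_id=0, add_eod=False)
print(with_eod)   # [5 6 7 0]
print(truncated)  # [5 6 0]
print(unchanged)  # [5 6 7]
```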



