xlstm_jax.dataset.multihost_dataloading#

Copyright 2023 Google LLC

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

SPMD Multihost Dataloading Utilities.

See sholtodouglas/multihost_dataloading for a similar approach.

Attributes#

Classes#

MultiHostDataLoadIterator

Create a MultiHostDataLoadIterator.

Functions#

_build_global_shape_and_sharding(local_shape, global_mesh)

Create the global_shape and sharding based on the local_shape and global_mesh.

_form_global_array(path, array, global_mesh)

Put host sharded array into devices within a global sharded array.

_pad_array_to_shape(array_and_shape[, pad_value])

Pad an array to a given shape by given values. Array and shape are inside a shared tuple to enable easier mapping from a zip().

get_next_batch_sharded(local_iterator, global_mesh[, ...])

Splits the host loaded data equally over all devices. Optionally pad arrays for equal sizes.

Module Contents#

xlstm_jax.dataset.multihost_dataloading.LOGGER#
xlstm_jax.dataset.multihost_dataloading._build_global_shape_and_sharding(local_shape, global_mesh)#

Create the global_shape and sharding based on the local_shape and global_mesh.

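Parameters:
  • local_shape – Shape of the host-local array.

  • global_mesh (jax.sharding.Mesh) – Global mesh for the distributed array.

Returns: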

Global tensor shape, Named Sharding of the mesh

Return type:

tuple[tuple[int, ...], jax.sharding.NamedSharding]
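
The relationship between the local and global shape can be illustrated with a small sketch. It assumes that every host contributes an equally sized slice along the leading (batch) axis, so the global batch dimension is the local one multiplied by the number of processes. That rule, the helper name global_shape_sketch and its num_processes argument are assumptions for illustration and are not taken from the module itself.

```python
def global_shape_sketch(local_shape, num_processes):
    """Hypothetical helper: scale the leading axis by the number of hosts."""
    # Assumption: only the first (batch) axis is split across hosts;
    # all remaining axes are identical on every host.
    return (num_processes * local_shape[0],) + tuple(local_shape[1:])


# Four hosts, each loading a (8, 128) batch, form a (32, 128) global batch.
global_shape = global_shape_sketch((8, 128), num_processes=4)
print(global_shape)  # (32, 128)
```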

xlstm_jax.dataset.multihost_dataloading._form_global_array(path, array, global_mesh)#

Put host sharded array into devices within a global sharded array.

Parameters:
  • path – Tree def path of the array in a PyTree struct (for debugging purposes only)

  • array (numpy.ndarray) – Distributed host array.

  • global_mesh (jax.sharding.Mesh) – Global mesh for the distributed array.

Returns:

Distributed device array

Return type:

jax.Array
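
One possible way to place a host-local array into a globally sharded jax.Array is sketched below. It is not the module's implementation: the function name form_global_array_sketch is hypothetical, and the sketch assumes that all mesh axes jointly shard the leading axis and that each process holds an equal slice of the global batch. It only uses public JAX calls (jax.device_put and jax.make_array_from_single_device_arrays).

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def form_global_array_sketch(array: np.ndarray, global_mesh: Mesh) -> jax.Array:
    """Hypothetical stand-in that shards the leading axis over the whole mesh."""
    # Assumption: all mesh axes together shard the first (batch) dimension.
    sharding = NamedSharding(global_mesh, PartitionSpec(global_mesh.axis_names))
    # Assumption: every process contributes an equally sized slice.
    global_shape = (jax.process_count() * array.shape[0],) + array.shape[1:]
    # Split the host array evenly over the devices attached to this host.
    local_devices = global_mesh.local_devices
    shards = np.split(array, len(local_devices), axis=0)
    buffers = [jax.device_put(s, d) for s, d in zip(shards, local_devices)]
    # Assemble the per-device buffers into one logical global array.
    return jax.make_array_from_single_device_arrays(global_shape, sharding, buffers)


devices = np.array(jax.devices())
mesh = Mesh(devices, ("data",))
# Use a batch size divisible by the device count so np.split succeeds.
host_batch = np.arange(devices.size * 2 * 3, dtype=np.float32).reshape(-1, 3)
global_batch = form_global_array_sketch(host_batch, mesh)
```

In a single-process run the global array has the same shape and contents as the host batch; with several hosts, each process would call the function with its own slice.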

xlstm_jax.dataset.multihost_dataloading._pad_array_to_shape(array_and_shape, pad_value=0)#

Pad an array to a given shape by given values. Array and shape are inside a shared tuple to enable easier mapping from a zip().

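Parameters:
  • array_and_shape – Tuple of the array to pad and the target shape.

  • pad_value – Value to use for padding. Defaults to zero.

Returns: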

Padded array.

>>> np.allclose(
...     _pad_array_to_shape((np.array([[1], [2]]), (3, 2)), pad_value=0),
...     np.array([[1, 0], [2, 0], [0, 0]]))
True
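
The behaviour shown in the doctest above can be reproduced with np.pad. The following is a minimal sketch and not the module's code: the name pad_to_shape_sketch is hypothetical, and it assumes that padding is only ever appended at the end of each axis, which is what the doctest output suggests.

```python
import numpy as np


def pad_to_shape_sketch(array_and_shape, pad_value=0):
    """Hypothetical re-implementation: pad at the end of each axis."""
    array, target_shape = array_and_shape
    if any(t < c for c, t in zip(array.shape, target_shape)):
        raise ValueError("Target shape must not be smaller than the array shape.")
    # One (before, after) pair per axis; nothing is added in front.
    pad_width = [(0, t - c) for c, t in zip(array.shape, target_shape)]
    return np.pad(array, pad_width, mode="constant", constant_values=pad_value)


arrays = [np.array([[1], [2]]), np.ones((1, 2), dtype=int)]
shapes = [(3, 2), (2, 2)]
# Taking a single tuple argument lets the function be mapped over a zip().
padded = list(map(pad_to_shape_sketch, zip(arrays, shapes)))
```

Here padded[0] equals [[1, 0], [2, 0], [0, 0]], matching the doctest.
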
xlstm_jax.dataset.multihost_dataloading.get_next_batch_sharded(local_iterator, global_mesh, pad=False, pad_value=0)#

Splits the host loaded data equally over all devices. Optionally pad arrays for equal sizes.

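Parameters:
  • local_iterator – Iterator over the host-loaded data.

  • global_mesh (jax.sharding.Mesh) – The mesh to shard the data over.

  • pad (bool) – Whether to pad arrays to equal sizes. Defaults to False.

  • pad_value – Value to use for padding. Defaults to zero.

Returns: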

Optionally padded, sharded data array.

Return type:

xlstm_jax.common_types.PyTree

class xlstm_jax.dataset.multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh, iterator_length=None, dataset_size=None, reset_after_epoch=False, pad_shapes=False, pad_value=0)#

Create a MultiHostDataLoadIterator.

Wrapper around a tf.data.Dataset or Iterable to iterate over data in a multi-host setup. Folds get_next_batch_sharded into an iterator class and supports breaking an indefinite iterator into epochs.

Parameters:
  • dataloader (tf.data.Dataset | collections.abc.Iterable) – The dataloader to iterate over.

  • global_mesh (jax.sharding.Mesh) – The mesh to shard the data over.

  • iterator_length (int | None) – The length of the iterator. If provided, the iterator stops after this many steps with a StopIteration exception. Otherwise, iteration continues until the underlying iterator raises an exception itself.

  • dataset_size (int | None) – Size of the dataset. If provided, it is returned by get_dataset_size; otherwise, get_dataset_size returns None. Can be used to communicate the dataset size to functions that use the iterator.

  • reset_after_epoch (bool) – Whether to reset the iterator between epochs or not. If True, the iterator will reset after each epoch, otherwise it will continue from where it left off. If you have an indefinite iterator (e.g. train iterator with grain and shuffle), this should be set to False. For un-shuffled iterators in grain (e.g. validation), this should be set to True.

  • pad_shapes (bool) – Whether to pad arrays to a common shape across all devices before merging.

  • pad_value (int | float) – Value to use for padding. Defaults to zero.
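
The interplay of iterator_length and reset_after_epoch described above can be sketched in plain Python. The class EpochIteratorSketch below is hypothetical and leaves out all multi-host sharding and padding; it only illustrates the assumed epoch semantics, namely that each new pass resets the step counter and, only when reset_after_epoch is True, also restarts the underlying iterator.

```python
class EpochIteratorSketch:
    """Hypothetical wrapper illustrating epoch handling only (no sharding)."""

    def __init__(self, dataloader, iterator_length=None, reset_after_epoch=False):
        self.dataloader = dataloader
        self.iterator_length = iterator_length
        self.reset_after_epoch = reset_after_epoch
        self.step_counter = 0
        self._local_iterator = iter(dataloader)

    def __iter__(self):
        # A new epoch begins: the step counter always starts from zero.
        self.step_counter = 0
        if self.reset_after_epoch:
            # Restart the underlying data from the beginning.
            self._local_iterator = iter(self.dataloader)
        return self

    def __next__(self):
        # Stop after iterator_length steps if a length was given.
        if self.iterator_length is not None and self.step_counter >= self.iterator_length:
            raise StopIteration
        self.step_counter += 1
        return next(self._local_iterator)


data = list(range(6))

continuing = EpochIteratorSketch(data, iterator_length=2, reset_after_epoch=False)
first_epoch = list(continuing)   # [0, 1]
second_epoch = list(continuing)  # [2, 3], continues where it left off

resetting = EpochIteratorSketch(data, iterator_length=2, reset_after_epoch=True)
epochs = [list(resetting), list(resetting)]  # [[0, 1], [0, 1]]
```

Under these assumptions, the first variant suits an indefinite, shuffled training stream, while the second replays the same batches each epoch, as one would want for validation.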

global_mesh#
dataloader#
iterator_length = None#
dataset_size = None#
reset_after_epoch = False#
state_set = False#
step_counter = 0#
pad_shapes = False#
pad_value = 0#
reset()#
get_state()#
Return type:

dict[str, Any]

set_state(state)#
Parameters:

state (dict[str, Any])