xlstm_jax.dataset.input_pipeline_interface

Copyright 2023 Google LLC

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

This file is a modified version of the file input_pipeline_interface.py from the maxtext project AI-Hypercomputer/maxtext.

Input pipeline

Attributes

Functions

get_process_loading_data(config, mesh)

Get list of processes loading data.

create_data_iterator(config, mesh)

create_mixed_data_iterator(configs, mesh[, ...])

Create a data iterator that mixes multiple datasets.

Module Contents

xlstm_jax.dataset.input_pipeline_interface.LOGGER
xlstm_jax.dataset.input_pipeline_interface.GRAIN_AVAILABLE = True
xlstm_jax.dataset.input_pipeline_interface.DataIterator
xlstm_jax.dataset.input_pipeline_interface.get_process_loading_data(config, mesh)

Get list of processes loading data.

Parameters:
Returns:

List of process indices that will load real data.

Return type:

list[int]

xlstm_jax.dataset.input_pipeline_interface.create_data_iterator(config, mesh)
Parameters:
Return type:

DataIterator

xlstm_jax.dataset.input_pipeline_interface.create_mixed_data_iterator(configs, mesh, dataset_weights=None)

Create a data iterator that mixes multiple datasets.

Each dataset is loaded individually, and the iterator returns batches in which every batch element comes from one of the datasets. How often each dataset is sampled is determined by dataset_weights.

Parameters:
Returns:

DataIterator object that can be used to iterate over the mixed dataset.

Return type:

DataIterator
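
The mixing behaviour described for create_mixed_data_iterator can be illustrated with a small standalone sketch. It is not the library's implementation: mix_examples, tagged, and the per-element sampling strategy are hypothetical names and choices made for illustration, and treating missing weights as uniform is an assumption, since the page does not say what dataset_weights=None means. It uses only the Python standard library.

```python
from __future__ import annotations

import itertools
import random
from collections.abc import Iterator, Sequence


def mix_examples(
    sources: Sequence[Iterator[dict]],
    weights: Sequence[float] | None = None,
    batch_size: int = 4,
    seed: int = 0,
) -> Iterator[list[dict]]:
    """Yield batches whose elements are each drawn from one of `sources`.

    Hypothetical helper: the source for every batch element is sampled
    independently, with probability proportional to its weight.
    """
    if weights is None:
        # Assumption: no weights means every source is equally likely.
        weights = [1.0] * len(sources)
    if len(weights) != len(sources):
        raise ValueError("need exactly one weight per source")
    rng = random.Random(seed)
    indices = range(len(sources))
    while True:
        batch = []
        for _ in range(batch_size):
            (idx,) = rng.choices(indices, weights=weights, k=1)
            try:
                batch.append(next(sources[idx]))
            except StopIteration:
                # Stop cleanly as soon as any sampled source runs dry.
                return
        yield batch


def tagged(name: str) -> Iterator[dict]:
    """Endless toy dataset whose examples record where they came from."""
    return ({"source": name, "index": i} for i in itertools.count())


# Source "a" is sampled about three times as often as source "b".
mixed = mix_examples([tagged("a"), tagged("b")], weights=[3.0, 1.0], batch_size=8)
first_batch = next(mixed)
```

Sampling the source per element rather than per batch matches the statement that each batch element is from one of the datasets. A real pipeline would additionally have to handle sharding across processes and batching into arrays, which this sketch leaves out.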