xlstm_jax.dataset.input_pipeline_interface
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
—
This file is a modified version of the file input_pipeline_interface.py from the maxtext project AI-Hypercomputer/maxtext.
Input pipeline
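Attributes
LOGGER |
GRAIN_AVAILABLE |
DataIterator |
Functions
get_process_loading_data(config, mesh) | Get list of processes loading data.
create_data_iterator(config, mesh) |
create_mixed_data_iterator(configs, mesh, dataset_weights=None) | Create a data iterator that mixes multiple datasets.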
Module Contents
- xlstm_jax.dataset.input_pipeline_interface.LOGGER
- xlstm_jax.dataset.input_pipeline_interface.GRAIN_AVAILABLE = True
- xlstm_jax.dataset.input_pipeline_interface.DataIterator
- xlstm_jax.dataset.input_pipeline_interface.get_process_loading_data(config, mesh)
Get list of processes loading data.
- Parameters:
config (xlstm_jax.dataset.configs.DataConfig) – Config of the dataset to load.
mesh (jax.sharding.Mesh) – Global device mesh for sharding.
- Returns:
List of process indices that will load real data.
- Return type:
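The idea of "processes that load real data" can be illustrated with a small, self-contained sketch. It is not taken from the xlstm_jax source. It assumes that each device holds a contiguous slice of the global batch, and that a process loads real data when at least one of its devices holds a slice that lies entirely within the part of the batch that is trained on. The names `DeviceShard`, `processes_loading_real_data`, and `batch_cutoff` are hypothetical and used only for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DeviceShard:
    """Hypothetical description of the batch slice held by one device."""

    process_index: int  # host process that owns the device
    batch_start: int    # first batch row held by the device (inclusive)
    batch_stop: int     # last batch row held by the device (exclusive)


def processes_loading_real_data(shards, batch_cutoff):
    """Return sorted process indices owning at least one shard below the cutoff.

    Assumption: rows at index >= batch_cutoff are padding, so a process whose
    devices only hold such rows does not need to load real data.
    """
    procs = {s.process_index for s in shards if s.batch_stop <= batch_cutoff}
    return sorted(procs)


# Two processes with two devices each; the global batch has 8 rows.
shards = [
    DeviceShard(process_index=0, batch_start=0, batch_stop=2),
    DeviceShard(process_index=0, batch_start=2, batch_stop=4),
    DeviceShard(process_index=1, batch_start=4, batch_stop=6),
    DeviceShard(process_index=1, batch_start=6, batch_stop=8),
]

only_first = processes_loading_real_data(shards, batch_cutoff=4)
both = processes_loading_real_data(shards, batch_cutoff=8)
```

With a cutoff of 4 rows only process 0 holds real rows under this assumption; with a cutoff of 8 both processes do.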
- xlstm_jax.dataset.input_pipeline_interface.create_data_iterator(config, mesh)
- Parameters:
config (xlstm_jax.dataset.configs.DataConfig)
mesh (jax.sharding.Mesh)
- Return type:
DataIterator
- xlstm_jax.dataset.input_pipeline_interface.create_mixed_data_iterator(configs, mesh, dataset_weights=None)
Create a data iterator that mixes multiple datasets.
Each dataset is loaded individually, and the iterator returns batches in which every batch element comes from one of the datasets. How often each dataset is sampled is determined by dataset_weights.
- Parameters:
configs (list[xlstm_jax.dataset.configs.HFHubDataConfig | xlstm_jax.dataset.configs.GrainArrayRecordsDataConfig]) – List of DataConfig objects, determining the datasets to load.
mesh (jax.sharding.Mesh) – JAX mesh object. Used to distribute the data over multiple devices.
dataset_weights (list[float] | None) – Mixing weights for the datasets. If None, all datasets will have equal weight.
- Returns:
DataIterator object that can be used to iterate over the mixed dataset.
- Return type:
DataIterator
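The mixing behaviour described for create_mixed_data_iterator can be sketched in plain Python. The code below is a minimal illustration under stated assumptions, not the library's implementation: `mix_batches` is a hypothetical helper, the datasets are ordinary Python iterables rather than DataConfig objects, and no mesh or sharding is involved. It only shows per-element sampling by weight, with equal weights used when `dataset_weights` is None.

```python
import itertools
import random


def mix_batches(datasets, batch_size, dataset_weights=None, seed=0):
    """Yield batches whose elements are each drawn from one of `datasets`.

    Hypothetical sketch: a dataset is chosen independently for every batch
    element, with probability proportional to its weight.
    """
    if dataset_weights is None:
        # Equal weight for all datasets, as described for the None case.
        dataset_weights = [1.0] * len(datasets)
    if len(dataset_weights) != len(datasets):
        raise ValueError("Need exactly one weight per dataset.")
    total = sum(dataset_weights)
    probs = [w / total for w in dataset_weights]

    rng = random.Random(seed)
    iterators = [iter(d) for d in datasets]
    while True:
        # Pick the source dataset for each element of the next batch.
        picks = rng.choices(range(len(iterators)), weights=probs, k=batch_size)
        try:
            yield [next(iterators[i]) for i in picks]
        except StopIteration:
            # Stop as soon as any selected dataset is exhausted.
            return


# Two endless toy datasets mixed with weights 3:1.
mixed = mix_batches(
    [itertools.repeat("a"), itertools.repeat("b")],
    batch_size=8,
    dataset_weights=[3.0, 1.0],
)
batches = list(itertools.islice(mixed, 1000))
fraction_a = sum(batch.count("a") for batch in batches) / (8 * len(batches))
```

Over many batches, `fraction_a` should land close to 0.75, the normalized weight of the first dataset. Stopping at the first exhausted dataset is one possible design choice for this sketch; a real pipeline might instead repeat or re-shuffle the exhausted dataset.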