xlstm_jax.distributed.mesh_utils
Functions
- initialize_mesh(parallel_config, device_array=None, init_distributed_on_slurm=True) – Initialize the mesh for parallel training.
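The reference entry documents the parameters but not how the mesh is assembled. Below is a minimal, JAX-free sketch of the kind of bookkeeping such a function might do. It is not the library's implementation. It resolves a mesh shape from per-axis sizes, where one axis may be inferred with -1. It then lays a flat device list out in that shape and checks for a SLURM environment. The following are assumptions, not part of the documented API:

- ToyParallelConfig, its field names, and the axis names "dp", "fsdp", and "tp".
- The -1 convention for an inferred axis.
- Using SLURM_JOB_ID as the SLURM signal.

Plain strings stand in for JAX devices.

```python
import math
import os
from dataclasses import dataclass

import numpy as np


@dataclass
class ToyParallelConfig:
    """Hypothetical stand-in for ParallelConfig; all field names are assumptions."""

    data_axis_size: int = -1  # -1 means: infer this axis from the device count
    fsdp_axis_size: int = 1
    model_axis_size: int = 1
    axis_names: tuple[str, ...] = ("dp", "fsdp", "tp")  # hypothetical names


def resolve_mesh_shape(config: ToyParallelConfig, num_devices: int) -> tuple[int, ...]:
    """Replace at most one -1 axis size so that all sizes multiply to num_devices."""
    sizes = [config.data_axis_size, config.fsdp_axis_size, config.model_axis_size]
    if sizes.count(-1) > 1:
        raise ValueError("At most one axis size may be -1.")
    known = math.prod(s for s in sizes if s != -1)
    if num_devices % known != 0:
        raise ValueError(f"{num_devices} devices cannot be split into axes {sizes}.")
    sizes = [num_devices // known if s == -1 else s for s in sizes]
    if math.prod(sizes) != num_devices:
        raise ValueError(f"Mesh shape {sizes} does not use all {num_devices} devices.")
    return tuple(sizes)


def layout_devices(config: ToyParallelConfig, devices: list) -> np.ndarray:
    """Reshape a flat device list (here: placeholder strings) into the mesh's device array."""
    shape = resolve_mesh_shape(config, len(devices))
    return np.asarray(devices, dtype=object).reshape(shape)


def running_under_slurm(env=None) -> bool:
    """Assumed SLURM check: treat the presence of SLURM_JOB_ID as a SLURM job."""
    env = os.environ if env is None else env
    return "SLURM_JOB_ID" in env


# Usage: 8 placeholder devices, 2-way FSDP, 2-way model parallelism, data axis inferred.
cfg = ToyParallelConfig(fsdp_axis_size=2, model_axis_size=2)
mesh_devices = layout_devices(cfg, [f"dev{i}" for i in range(8)])
print(mesh_devices.shape, cfg.axis_names)
```

With real devices, an array shaped this way and the axis names could be passed to jax.sharding.Mesh(device_array, axis_names). Multi-process setup would typically go through jax.distributed.initialize(). How initialize_mesh itself performs these steps is not documented here.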
Module Contents
- xlstm_jax.distributed.mesh_utils.initialize_mesh(parallel_config, device_array=None, init_distributed_on_slurm=True)
Initialize the mesh for parallel training.
- Parameters:
parallel_config (xlstm_jax.models.configs.ParallelConfig) – The configuration object holding the parallelization parameters.
device_array (numpy.ndarray | None) – A NumPy array of devices that defines the device structure of the mesh. If None, all global devices are used.
init_distributed_on_slurm (bool) – Whether to initialize the JAX distributed system (i.e. multi-process training) when SLURM environment variables are present. If False, the JAX distributed system is not initialized.
- Returns:
The initialized mesh.
- Return type: