xlstm_jax.distributed.mesh_utils


Functions

initialize_mesh(parallel_config[, device_array, ...])

Initialize the mesh for parallel training.

Module Contents

xlstm_jax.distributed.mesh_utils.initialize_mesh(parallel_config, device_array=None, init_distributed_on_slurm=True)

Initialize the mesh for parallel training.

Parameters:
  • parallel_config (xlstm_jax.models.configs.ParallelConfig) – The configuration object holding the parallelization parameters.

  • device_array (numpy.ndarray | None) – A numpy array containing the device structure. If None, all global devices are used.

  • init_distributed_on_slurm (bool) – Whether to initialize the JAX distributed system (multi-process training) when SLURM environment variables are present. If False, the JAX distributed system is not initialized.

Returns:

The initialized mesh.

Return type:

jax.sharding.Mesh
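
Example (illustrative sketch):

The following is a minimal sketch of the two steps described above: optionally starting the JAX distributed system under SLURM, then arranging devices into a grid and wrapping it in a jax.sharding.Mesh. It is not the implementation of initialize_mesh. The helper names (should_init_distributed, build_device_array, build_mesh), the SLURM_JOB_ID check, and the axis names "data" and "model" are assumptions made for illustration; the library may check different environment variables and takes its axis names and sizes from parallel_config.

```python
import os

import numpy as np


def should_init_distributed(init_distributed_on_slurm, env=None):
    """Decide whether to start multi-process JAX.

    Assumption: the presence of SLURM_JOB_ID stands in for "SLURM environment
    variables are present"; the library may look at other variables.
    """
    env = os.environ if env is None else env
    return bool(init_distributed_on_slurm) and "SLURM_JOB_ID" in env


def build_device_array(devices, axis_sizes):
    """Arrange a flat sequence of devices into a grid, one dimension per mesh axis.

    At most one entry of axis_sizes may be -1; it is inferred from the device count.
    """
    flat = np.empty(len(devices), dtype=object)
    for i, device in enumerate(devices):
        flat[i] = device
    # Raises ValueError if the requested sizes do not match the device count.
    return flat.reshape(axis_sizes)


def build_mesh(axis_sizes=(-1, 1), axis_names=("data", "model"),
               init_distributed_on_slurm=True):
    """Build a jax.sharding.Mesh over all global devices (hypothetical helper)."""
    import jax  # imported here so the helpers above work without JAX installed
    from jax.sharding import Mesh

    if should_init_distributed(init_distributed_on_slurm):
        # Must run before any other JAX call that touches devices.
        jax.distributed.initialize()

    device_array = build_device_array(jax.devices(), axis_sizes)
    return Mesh(device_array, axis_names)
```

Calling build_mesh() on a single host places every visible device along the first axis. A pre-built grid such as the one returned by build_device_array is the kind of numpy array the device_array parameter above refers to.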