xlstm_jax.distributed.data_parallel
Functions
| shard_params | Shard parameters across the given mesh axis. |
| gather_array_with_mean_grads | Gather an array, averaging its gradients across replicas. |
| gather_params | Gather parameters from all replicas across the given axis. |
| shard_module_params | Shard parameters of a module across replicas. |
| sync_gradients | Synchronize gradients across devices. |
Module Contents
- xlstm_jax.distributed.data_parallel.shard_params(params, axis_name, min_weight_size=2**18)
Shard parameters across the given mesh axis.
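- Parameters:
params (xlstm_jax.common_types.PyTree) – The parameters to shard.
axis_name (str) – The axis name of the mesh to shard parameters across.
min_weight_size (int) – The minimum size of a parameter to shard. Parameters with fewer values will not be sharded.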
- Returns:
PyTree with the same structure as params, whose leaves are sharded over axis_name where possible.
- Return type:
xlstm_jax.common_types.PyTree
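The rule for picking which dimension of a leaf gets sharded is not spelled out above. The following is a minimal sketch under an assumption: a leaf is split along its largest dimension that divides evenly by the size of the mesh axis, and leaves with fewer than min_weight_size values (or with no evenly divisible dimension) stay replicated. choose_shard_dim is a hypothetical helper for illustration, not part of xlstm_jax.

```python
import math
from typing import Optional, Sequence


def choose_shard_dim(
    shape: Sequence[int], axis_size: int, min_weight_size: int = 2**18
) -> Optional[int]:
    """Hypothetical helper: return the dimension to shard over, or None to replicate."""
    # Assumption: small parameters are cheap to replicate, so they are never sharded.
    if math.prod(shape) < min_weight_size:
        return None
    # Try dimensions from largest to smallest; the first one that splits evenly wins.
    for dim in sorted(range(len(shape)), key=lambda d: shape[d], reverse=True):
        if shape[dim] % axis_size == 0:
            return dim
    # No dimension divides evenly across the mesh axis: keep the parameter replicated.
    return None


# A (1024, 4096) kernel on an 8-device axis is split along its larger dimension.
print(choose_shard_dim((1024, 4096), axis_size=8))
# A (1024,) bias has far fewer than 2**18 values and stays replicated.
print(choose_shard_dim((1024,), axis_size=8))
```

For scale, the default min_weight_size=2**18 is 262,144 values, or 1 MiB for a float32 parameter, so biases and normalization scales typically stay replicated while large projection matrices are sharded.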
- xlstm_jax.distributed.data_parallel.gather_array_with_mean_grads(x, axis, axis_name, gather_dtype=None, grad_scatter_dtype=None)
Gather an array across replicas and average its gradients over them in the backward pass.
- Parameters:
x (jax.Array) – The array to gather.
axis (int) – The axis of the array to gather across.
axis_name (str) – The axis name of the mesh to gather across.
gather_dtype (jax.numpy.dtype | None) – The dtype to cast the array to before gathering. If None, no casting is performed.
grad_scatter_dtype (jax.numpy.dtype | None) – The dtype to cast the gradients to before scattering. If None, the dtype of x is used.
- Returns:
The gathered array with a gradient function that averages across replicas.
- Return type:
jax.Array
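A minimal sketch of the gather-with-mean-gradients idea, written with jax.custom_vjp: the forward pass all-gathers the local shards along axis 0, and the backward pass reduce-scatters the cotangents with jax.lax.psum_scatter and divides by the number of replicas. gather_with_mean_grads is a hypothetical name; the sketch fixes axis=0, omits gather_dtype and grad_scatter_dtype, and is not the library's implementation. The usage part simulates four replicas with a named jax.vmap axis so it runs on a single CPU.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.custom_vjp, nondiff_argnums=(1,))
def gather_with_mean_grads(x, axis_name):
    """Hypothetical sketch: all-gather shards along axis 0, average gradients on the way back."""
    return jax.lax.all_gather(x, axis_name, axis=0, tiled=True)


def _fwd(x, axis_name):
    # No residuals are needed: the backward pass depends only on the cotangent.
    return gather_with_mean_grads(x, axis_name), None


def _bwd(axis_name, _residuals, g):
    # psum of a Python constant yields the number of replicas along the axis.
    n = jax.lax.psum(1, axis_name)
    # Sum all replicas' cotangents, give each replica the slice it owns, then divide by n.
    return (jax.lax.psum_scatter(g, axis_name, scatter_dimension=0, tiled=True) / n,)


gather_with_mean_grads.defvjp(_fwd, _bwd)

# Simulate 4 replicas; each holds a (2,) shard of an (8,) parameter.
shards = jnp.arange(8.0).reshape(4, 2)
weights = jnp.arange(8.0)


def loss(shard):
    full = gather_with_mean_grads(shard, "data")  # (8,) on every replica
    return jnp.sum(full * weights)


gathered = jax.vmap(lambda s: gather_with_mean_grads(s, "data"), axis_name="data")(shards)
grads = jax.vmap(jax.grad(loss), axis_name="data")(shards)
```

Dividing by the replica count turns the summed cotangents into a mean, which matches data-parallel training where each replica computes the loss on its own slice of the batch and the global loss is the average over replicas.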
- xlstm_jax.distributed.data_parallel.gather_params(params, axis_name, gather_dtype=None, grad_scatter_dtype=None)
Gather parameters from all replicas across the given axis.
- Parameters:
params (xlstm_jax.common_types.PyTree) – The parameters to gather.
axis_name (str) – The axis to gather parameters across.
gather_dtype (jax.numpy.dtype | None) – The dtype to cast the parameters to before gathering. If None, no casting is performed.
grad_scatter_dtype (jax.numpy.dtype | None) – The dtype to cast the gradients to before scattering. If None, the dtype of the parameters is used.
- Returns:
PyTree with the same structure as params, in which every leaf that was an nn.Partitioned object is gathered across the axis.
- Return type:
xlstm_jax.common_types.PyTree
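The traversal can be sketched with a flat dict standing in for the PyTree and a hypothetical shard_dims mapping standing in for the nn.Partitioned metadata (the dimension each leaf is sharded over, or None if it is replicated). This is an illustration under those assumptions, not the library's code; for brevity it calls plain jax.lax.all_gather, whose gradient is a sum rather than the mean that gather_array_with_mean_grads provides, and it omits grad_scatter_dtype. A named jax.vmap axis simulates four devices.

```python
import jax
import jax.numpy as jnp


def gather_param_dict(params, shard_dims, axis_name, gather_dtype=None):
    """Hypothetical sketch: all-gather the sharded leaves of a flat parameter dict.

    shard_dims maps each parameter name to the dimension it is sharded over,
    or to None if the parameter is replicated.
    """
    gathered = {}
    for name, value in params.items():
        dim = shard_dims[name]
        if dim is None:
            # Replicated parameters already hold the full value on every device.
            gathered[name] = value
            continue
        if gather_dtype is not None:
            # Casting before the collective (e.g. to bfloat16) reduces bytes communicated.
            value = value.astype(gather_dtype)
        gathered[name] = jax.lax.all_gather(value, axis_name, axis=dim, tiled=True)
    return gathered


kernel = jnp.arange(24.0).reshape(3, 8)  # the full, unsharded parameter
# Split the kernel into 4 shards of shape (3, 2) along dim 1, one per simulated device.
kernel_shards = jnp.stack(jnp.split(kernel, 4, axis=1))  # (4, 3, 2)
bias = jnp.ones((4, 3))  # every device holds the same (3,) bias
params = {"kernel": kernel_shards, "bias": bias}
shard_dims = {"kernel": 1, "bias": None}

full = jax.vmap(
    lambda p: gather_param_dict(p, shard_dims, "data", gather_dtype=jnp.bfloat16),
    axis_name="data",
)(params)
```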
- xlstm_jax.distributed.data_parallel.shard_module_params(target, axis_name, min_weight_size=2**18, gather_dtype=None, grad_scatter_dtype=None)
Shard parameters of a module across replicas.
- Parameters:
target (flax.linen.Module | collections.abc.Callable) – The module to shard.
axis_name (str) – The axis name to shard parameters across.
min_weight_size (int) – The minimum size of a parameter to shard. Parameters with fewer values will not be sharded.
gather_dtype (jax.numpy.dtype | None) – The dtype to cast the parameters to before gathering. If None, no casting is performed.
grad_scatter_dtype (jax.numpy.dtype | None) – The dtype to cast the gradients to before scattering. If None, the dtype of the parameters is used.
- Returns:
The module with sharded parameters.
- Return type:
flax.linen.Module | collections.abc.Callable
- xlstm_jax.distributed.data_parallel.sync_gradients(grads, axis_names)
Synchronize gradients across devices.
Gradients of parameters that are replicated over a given axis are averaged across the devices along that axis. Gradients of parameters that are partitioned over a given axis are left unchanged for that axis, since they are considered to already hold the mean gradient on each device (the backward pass of gather_array_with_mean_grads averages across replicas).
- Parameters:
grads (xlstm_jax.common_types.PyTree) – The gradients to synchronize.
axis_names (collections.abc.Sequence[str]) – The axis names to synchronize gradients across.
- Returns:
The gradients, with each leaf averaged over those of the specified axes it is replicated on.
- Return type:
xlstm_jax.common_types.PyTree
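The rule can be sketched on a flat dict of gradients. The hypothetical sharded_axes mapping stands in for the nn.Partitioned metadata and lists, for each parameter, the mesh axes it is partitioned over; each gradient is then averaged with jax.lax.pmean over the requested axes it is replicated on. A named jax.vmap axis simulates four devices so the example runs on one CPU; this is a sketch under these assumptions, not the library's implementation.

```python
import jax
import jax.numpy as jnp


def sync_grad_dict(grads, sharded_axes, axis_names):
    """Hypothetical sketch: average each gradient over the mesh axes it is replicated on.

    sharded_axes maps each parameter name to the tuple of mesh axis names the
    parameter is partitioned over (an empty tuple for fully replicated ones).
    """
    synced = {}
    for name, g in grads.items():
        replicated_over = tuple(a for a in axis_names if a not in sharded_axes[name])
        if replicated_over:
            # Each device saw different data, so replicated gradients are averaged.
            synced[name] = jax.lax.pmean(g, axis_name=replicated_over)
        else:
            # Partitioned over every requested axis: assumed already averaged by the gather.
            synced[name] = g
    return synced


grads = {
    # Gradient of a replicated bias: different on each of the 4 simulated devices.
    "bias": jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]),
    # Gradient of a kernel shard: each device owns its own slice.
    "kernel": jnp.arange(8.0).reshape(4, 2),
}
sharded_axes = {"bias": (), "kernel": ("data",)}

synced = jax.vmap(
    lambda g: sync_grad_dict(g, sharded_axes, ("data",)), axis_name="data"
)(grads)
```

Every device ends up with the same bias gradient, the column-wise mean [4.0, 5.0], while the kernel shards are returned untouched.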