xlstm_jax.distributed.data_parallel

Functions

shard_params(params, axis_name[, min_weight_size])

Shard parameters across the given mesh axis.

gather_array_with_mean_grads(x, axis, axis_name[, ...])

Gather an array across replicas, averaging its gradients across replicas in the backward pass.

gather_params(params, axis_name[, gather_dtype, ...])

Gather parameters from all replicas across the given axis.

shard_module_params(target, axis_name[, ...])

Shard parameters of a module across replicas.

sync_gradients(grads, axis_names)

Synchronize gradients across devices.

Module Contents

xlstm_jax.distributed.data_parallel.shard_params(params, axis_name, min_weight_size=2**18)

Shard parameters across the given mesh axis.

Parameters:
  • params (xlstm_jax.common_types.PyTree) – The parameters to shard.

  • axis_name (str) – The axis to shard parameters across.

  • min_weight_size (int) – The minimum number of elements a parameter must have to be sharded. Parameters with fewer elements are not sharded.

Returns:

A PyTree with the same structure as params, with leaves sharded over the axis_name axis where possible.

Return type:

xlstm_jax.common_types.PyTree
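The size threshold can be illustrated with a small sketch in plain Python. The helper `is_shard_candidate` and the example shapes are hypothetical and not part of xlstm_jax. The sketch only mirrors the documented rule that parameters with fewer than `min_weight_size` elements are left unsharded. Whether a large enough parameter is actually sharded may depend on further conditions (for example, whether one of its dimensions is divisible by the axis size), which are not modeled here.

```python
from math import prod


def is_shard_candidate(shape, min_weight_size=2**18):
    """Hypothetical helper: True if a parameter has enough elements to be sharded."""
    return prod(shape) >= min_weight_size


# Illustrative parameter shapes (assumed, not taken from the library).
shapes = {
    "embedding": (32000, 1024),
    "dense_kernel": (512, 512),    # exactly 2**18 elements, so not "fewer"
    "layer_norm_scale": (1024,),   # far below the threshold
}
candidates = {name: is_shard_candidate(shape) for name, shape in shapes.items()}
```

With the default threshold, small parameters such as normalization scales and biases stay replicated, while large weight matrices become candidates for sharding.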

xlstm_jax.distributed.data_parallel.gather_array_with_mean_grads(x, axis, axis_name, gather_dtype=None, grad_scatter_dtype=None)

Gather an array across replicas, averaging its gradients across replicas in the backward pass.

Parameters:
  • x (jax.Array) – The array to gather.

  • axis (int) – The axis of the array to gather along.

  • axis_name (str) – The axis name of the mesh to gather across.

  • gather_dtype (jax.numpy.dtype | None) – The dtype to cast the array to before gathering. If None, no casting is performed.

  • grad_scatter_dtype (jax.numpy.dtype | None) – The dtype to cast the gradients to before scattering. If None, the dtype of x is used.

Returns:

The gathered array with a gradient function that averages across replicas.

Return type:

jax.Array
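The forward and backward behavior described above can be pictured with a plain-Python simulation of two replicas. This is a sketch under assumptions, not the library's implementation: it assumes the backward pass averages the full-size gradients over replicas and then returns to each replica the slice that matches its own shard (a mean reduce-scatter). The names `all_gather` and `mean_scatter` are hypothetical, and dtype casting is omitted.

```python
def all_gather(shards):
    """Forward pass: every replica receives the concatenation of all shards."""
    full = [value for shard in shards for value in shard]
    return [list(full) for _ in shards]


def mean_scatter(full_grads):
    """Assumed backward pass: average the full gradients over replicas,
    then give each replica the slice corresponding to its own shard."""
    n = len(full_grads)
    mean = [sum(values) / n for values in zip(*full_grads)]
    size = len(mean) // n
    return [mean[i * size:(i + 1) * size] for i in range(n)]


shards = [[1.0, 2.0], [3.0, 4.0]]  # two replicas, one shard each
gathered = all_gather(shards)

# Each replica computes its own gradient with respect to the gathered array.
full_grads = [[2.0, 4.0, 6.0, 8.0], [4.0, 8.0, 10.0, 12.0]]
shard_grads = mean_scatter(full_grads)
```

In this picture each replica ends up with a gradient of the same shape as its local shard, already averaged over the replicas.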

xlstm_jax.distributed.data_parallel.gather_params(params, axis_name, gather_dtype=None, grad_scatter_dtype=None)

Gather parameters from all replicas across the given axis.

Parameters:
  • params (xlstm_jax.common_types.PyTree) – The parameters to gather.

  • axis_name (str) – The axis to gather parameters across.

  • gather_dtype (jax.numpy.dtype | None) – The dtype to cast the parameters to before gathering. If None, no casting is performed.

  • grad_scatter_dtype (jax.numpy.dtype | None) – The dtype to cast the gradients to before scattering. If None, the dtype of the parameters is used.

Returns:

A PyTree with the same structure as params, in which every leaf that was an nn.Partitioned object has been gathered.

Return type:

xlstm_jax.common_types.PyTree

xlstm_jax.distributed.data_parallel.shard_module_params(target, axis_name, min_weight_size=2**18, gather_dtype=None, grad_scatter_dtype=None)

Shard parameters of a module across replicas.

Parameters:
  • target (flax.linen.Module | collections.abc.Callable) – The module to shard.

  • axis_name (str) – The axis name to shard parameters across.

  • min_weight_size (int) – The minimum number of elements a parameter must have to be sharded. Parameters with fewer elements are not sharded.

  • gather_dtype (jax.numpy.dtype | None) – The dtype to cast the parameters to before gathering. If None, no casting is performed.

  • grad_scatter_dtype (jax.numpy.dtype | None) – The dtype to cast the gradients to before scattering. If None, the dtype of the parameters is used.

Returns:

The module with sharded parameters.

Return type:

flax.linen.Module | collections.abc.Callable

xlstm_jax.distributed.data_parallel.sync_gradients(grads, axis_names)

Synchronize gradients across devices.

Gradients of parameters that are replicated over a given axis are averaged across the devices on that axis. Parameters that are partitioned over a given axis are assumed to already hold the mean gradient on each device, so their gradients are left unchanged.

Parameters:
  • grads (xlstm_jax.common_types.PyTree) – The gradients to synchronize.

  • axis_names (collections.abc.Sequence[str]) – The axis names to synchronize gradients across.

Returns:

The gradients, with those of replicated parameters averaged over the specified axes and those of partitioned parameters unchanged.

Return type:

xlstm_jax.common_types.PyTree
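The distinction between replicated and partitioned parameters can be sketched in plain Python for a single axis with two devices. The function `sync_gradients_sketch`, the `partitioned` set, and the parameter names are hypothetical illustrations. How the real function detects partitioning is not shown here; the sketch simply takes the set of partitioned names as an input.

```python
def sync_gradients_sketch(per_device_grads, partitioned):
    """Average gradients of replicated parameters across devices and
    leave gradients of partitioned parameters untouched."""
    n = len(per_device_grads)
    synced = []
    for grads in per_device_grads:
        out = {}
        for name, grad in grads.items():
            if name in partitioned:
                # Assumed to already hold the mean gradient for this shard.
                out[name] = grad
            else:
                # Replicated parameter: take the mean over all devices.
                out[name] = sum(d[name] for d in per_device_grads) / n
        synced.append(out)
    return synced


per_device_grads = [
    {"bias": 1.0, "kernel_shard": 10.0},  # device 0
    {"bias": 3.0, "kernel_shard": 20.0},  # device 1
]
synced = sync_gradients_sketch(per_device_grads, partitioned={"kernel_shard"})
```

After synchronization, every device holds the same gradient for the replicated `bias`, while each device keeps its own gradient for its shard of the kernel.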