xlstm_jax.distributed.array_utils

Functions

fold_rng_over_axis(rng, axis_name)

Folds the random number generator over the given axis.

split_array_over_mesh(x, axis_name, split_axis)

Split an array over the given mesh axis.

stack_params(params, axis_name[, axis, mask_except])

Stacks sharded parameters along a given axis name.

unstack_params(params, axis_name)

Unstacks parameters along a given axis name.

Module Contents

xlstm_jax.distributed.array_utils.fold_rng_over_axis(rng, axis_name)

Folds the random number generator over the given axis.

This is useful for obtaining a distinct random key on each device along a given mesh axis (e.g. the model axis).

Parameters:
  • rng (xlstm_jax.common_types.PRNGKeyArray) – The random number generator.

  • axis_name (str) – The axis name to fold the random number generator over.

Returns:

A new random number generator, different for each device index along the axis.

Return type:

xlstm_jax.common_types.PRNGKeyArray
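A minimal sketch of such a function, assuming it is built from jax.lax.axis_index and jax.random.fold_in; this is an illustration, not necessarily the library's implementation. To run on a single CPU, the example emulates a 4-device "model" axis with jax.vmap(..., axis_name="model"); in practice the call would sit inside shard_map (or pmap) over the device mesh.

```python
import jax
import jax.numpy as jnp


def fold_rng_over_axis(rng: jax.Array, axis_name: str) -> jax.Array:
    """Sketch (assumed implementation): derive a per-device key from a shared key."""
    # Position of the current device (here: the vmap lane) along the named axis.
    axis_index = jax.lax.axis_index(axis_name)
    # fold_in mixes the index into the key, giving a distinct, reproducible key per index.
    return jax.random.fold_in(rng, axis_index)


rng = jax.random.PRNGKey(0)
# Emulate a 4-device "model" axis; every lane starts from the same key.
keys = jax.vmap(lambda _: fold_rng_over_axis(rng, "model"), axis_name="model")(jnp.arange(4))
# Each lane now samples different noise from its own key.
noise = jax.vmap(lambda k: jax.random.normal(k, (3,)))(keys)
```

Folding the axis index into a shared key keeps the result reproducible: device i always receives fold_in(rng, i), so reruns with the same seed give the same per-device keys.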

xlstm_jax.distributed.array_utils.split_array_over_mesh(x, axis_name, split_axis)

Split an array over the given mesh axis.

Parameters:
  • x (jax.Array) – The array to split.

  • axis_name (str) – The axis name of the mesh to split over.

  • split_axis (int) – The axis of the array to split.

Returns:

The slice of x belonging to the current device, taken along split_axis.

Return type:

jax.Array
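A minimal sketch, assuming the split is an even, contiguous division of split_axis and that the mesh axis size is obtained with the common jax.lax.psum(1, axis_name) idiom; the divisibility check is an added assumption of this sketch, not documented behavior. The example emulates a 2-device "model" axis with jax.vmap so it runs on a single CPU.

```python
import jax
import jax.numpy as jnp


def split_array_over_mesh(x: jax.Array, axis_name: str, split_axis: int) -> jax.Array:
    """Sketch (assumed implementation): return this device's chunk of x along split_axis."""
    # psum of the constant 1 yields the number of devices on the mesh axis as a static value.
    axis_size = int(jax.lax.psum(1, axis_name))
    axis_index = jax.lax.axis_index(axis_name)
    if x.shape[split_axis] % axis_size != 0:
        raise ValueError(
            f"Dimension {split_axis} of size {x.shape[split_axis]} is not divisible by {axis_size}."
        )
    slice_size = x.shape[split_axis] // axis_size
    # Take the axis_index-th contiguous block of length slice_size.
    return jax.lax.dynamic_slice_in_dim(x, axis_index * slice_size, slice_size, axis=split_axis)


x = jnp.arange(8.0).reshape(2, 4)
# Emulate a 2-device "model" axis; each lane sees the full array and keeps its half of the columns.
shards = jax.vmap(
    lambda _: split_array_over_mesh(x, "model", split_axis=1), axis_name="model"
)(jnp.arange(2))
```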

xlstm_jax.distributed.array_utils.stack_params(params, axis_name, axis=0, mask_except=None)

Stacks sharded parameters along a given axis name.

Parameters:
  • params (xlstm_jax.common_types.PyTree) – PyTree of parameters.

  • axis_name (str) – Name of the axis to stack along.

  • axis (int) – Index of the axis to stack along.

  • mask_except (jax.Array | int | None) – If not None, only the shard at index mask_except along axis_name keeps its values; all other shards are zero.

Returns:

PyTree of parameters with the same structure as params, but with the leaf nodes replaced by nn.Partitioned objects whose sharding over axis_name is added at the axis-th dimension of the parameters.

Return type:

xlstm_jax.common_types.PyTree

xlstm_jax.distributed.array_utils.unstack_params(params, axis_name)

Unstacks parameters along a given axis name.

Inverse operation to stack_params.

Parameters:
  • params (xlstm_jax.common_types.PyTree) – PyTree of parameters.

  • axis_name (str) – Name of the axis to unstack along.

Returns:

PyTree of parameters with the same structure as params, but with the leaf nodes having the sharding over the axis name removed.

Return type:

xlstm_jax.common_types.PyTree