xlstm_jax.distributed.array_utils
Functions
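| fold_rng_over_axis(rng, axis_name) | Folds the random number generator over the given axis. |
| split_array_over_mesh(x, axis_name, split_axis) | Split an array over the given mesh axis. |
| stack_params(params, axis_name, axis=0, mask_except=None) | Stacks sharded parameters along a given axis name. |
| unstack_params(params, axis_name) | Unstacks parameters along a given axis name. |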
Module Contents
- xlstm_jax.distributed.array_utils.fold_rng_over_axis(rng, axis_name)
Folds the random number generator over the given axis.
This is useful for generating a different random number for each device across a certain axis (e.g. the model axis).
- Parameters:
rng (xlstm_jax.common_types.PRNGKeyArray) – The random number generator.
axis_name (str) – The axis name to fold the random number generator over.
- Returns:
A new random number generator, different for each device index along the axis.
- Return type:
xlstm_jax.common_types.PRNGKeyArray
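As a rough illustration of the idea, the sketch below folds the index along a named axis into a key. It is not the library's implementation: `fold_rng_sketch` and `sample_on_device` are hypothetical names, and `jax.vmap` with `axis_name="model"` is used only to emulate a mesh axis on a single device.

```python
import jax
import jax.numpy as jnp


def fold_rng_sketch(rng, axis_name):
    # Hypothetical stand-in, not the library function: mix the position
    # along the named axis into the key so each position gets its own key.
    axis_index = jax.lax.axis_index(axis_name)
    return jax.random.fold_in(rng, axis_index)


def sample_on_device(rng, _):
    # Every position receives the same input key but derives a different one.
    device_rng = fold_rng_sketch(rng, "model")
    return jax.random.normal(device_rng, (2,))


rng = jax.random.PRNGKey(0)
# Assumption: vmap over 4 dummy entries emulates 4 devices along a "model" axis.
samples = jax.vmap(sample_on_device, in_axes=(None, 0), axis_name="model")(
    rng, jnp.arange(4)
)
```

Here `samples` has one row per emulated device, and the rows differ even though all positions started from the same key.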
- xlstm_jax.distributed.array_utils.split_array_over_mesh(x, axis_name, split_axis)
Split an array over the given mesh axis.
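The source gives no further detail for this function. One plausible reading, stated here as an assumption rather than documented behavior, is that each device keeps only its own equally sized slice of `x` along `split_axis`. The sketch below follows that reading; `split_sketch` is a hypothetical name, and `jax.vmap` with an axis name again stands in for a real mesh axis.

```python
import jax
import jax.numpy as jnp


def split_sketch(x, axis_name, split_axis):
    # Hypothetical stand-in: assumes x.shape[split_axis] is divisible by the
    # number of devices along axis_name.
    axis_size = jax.lax.psum(1, axis_name)      # number of positions on the axis
    axis_index = jax.lax.axis_index(axis_name)  # this position's index
    slice_size = x.shape[split_axis] // axis_size
    # Keep only the slice that belongs to this position.
    return jax.lax.dynamic_slice_in_dim(
        x, axis_index * slice_size, slice_size, axis=split_axis
    )


x = jnp.arange(8)
# Assumption: 4 dummy entries emulate 4 devices along a "model" axis.
shards = jax.vmap(lambda _: split_sketch(x, "model", 0), axis_name="model")(
    jnp.arange(4)
)
```

Under this reading, row `i` of `shards` holds the two consecutive elements of `x` assigned to emulated device `i`.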
- xlstm_jax.distributed.array_utils.stack_params(params, axis_name, axis=0, mask_except=None)
Stacks sharded parameters along a given axis name.
- Parameters:
- Returns:
PyTree of parameters with the same structure as params, but with the leaf nodes replaced by nn.Partitioned objects in which sharding over axis_name is added at the axis-th axis of each parameter.
- Return type:
xlstm_jax.common_types.PyTree
- xlstm_jax.distributed.array_utils.unstack_params(params, axis_name)
Unstacks parameters along a given axis name.
Inverse operation to stack_params.
- Parameters:
params (xlstm_jax.common_types.PyTree) – PyTree of parameters.
axis_name (str) – Name of the axis to unstack along.
- Returns:
PyTree of parameters with the same structure as params, but with the sharding over axis_name removed from the leaf nodes.
- Return type:
xlstm_jax.common_types.PyTree