xlstm_jax.distributed.tensor_parallel#
Classes#
| ModelParallelismWrapper | Wrapper for adding model parallelism to a module. |
| TPDense | Dense layer with Tensor Parallelism support. |
| TPAsyncDense | Tensor-Parallel Dense Layer with Asynchronous Communication. |
Functions#
| scale_init | Scales the output of the given init function by the given factor. |
| async_gather | All gather using ring permutation. |
| async_gather_bidirectional | All gather using ring permutation with bidirectional communication. |
| async_gather_split | All gather using ring permutation with features split for bidirectional communication. |
| async_scatter | Scatter sum using ring permutation. |
| async_scatter_split | Scatter sum using ring permutation with features split for bidirectional communication. |
Module Contents#
- class xlstm_jax.distributed.tensor_parallel.ModelParallelismWrapper#
Bases: flax.linen.Module
Wrapper for adding model parallelism to a module.
This wrapper adds sharding over the model axis to the parameters of the module and initializes the module with different parameters across the model axis.
- Parameters:
model_axis_name – Name of the model axis to shard over.
module_fn – Function that returns the Flax module to wrap.
mask_except_model_idx – If not None, only the mask_except_model_idx-th shard will be non-zero.
split_rngs – If True, split the random number generators across the model axis.
module_kwargs – Additional keyword arguments to pass to the module function.
- module_fn: collections.abc.Callable[Ellipsis, flax.linen.Module]#
- module_kwargs: flax.core.frozen_dict.FrozenDict[str, Any]#
- xlstm_jax.distributed.tensor_parallel.scale_init(init_fn, scale_factor=1.0)#
Scales the output of the given init function by the given factor.
- Parameters:
init_fn (collections.abc.Callable) – The init function to scale.
scale_factor (float) – The factor to scale the output of the init function by.
- Returns:
A new init function that scales the output of the given init function by the given factor.
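As an illustration of the behaviour described above, a minimal sketch of such a wrapper could look as follows. The names scale_init_sketch and constant_init are hypothetical and are not part of xlstm_jax; the toy init returns a scalar standing in for an array, which scales the same way.

```python
def scale_init_sketch(init_fn, scale_factor=1.0):
    """Return a new init function whose output is init_fn's output times scale_factor."""

    def scaled_init(*args, **kwargs):
        # Forward all arguments unchanged and scale only the result.
        return scale_factor * init_fn(*args, **kwargs)

    return scaled_init


def constant_init(key, shape):
    # Hypothetical toy initializer: ignores key and shape and returns a scalar.
    return 2.0


half_init = scale_init_sketch(constant_init, scale_factor=0.5)
result = half_init(None, (3, 3))
```

Because the wrapper only multiplies the returned value, it should work with any initializer whose output supports multiplication by a float.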
- class xlstm_jax.distributed.tensor_parallel.TPDense#
Bases: flax.linen.Module
Dense layer with Tensor Parallelism support.
The wrapped dense layer is constructed by dense_fn and applied under the Tensor Parallelism mode selected by tp_mode.
- dense_fn#
Constructor function of the dense layer to use. Needs to support the keyword argument kernel_init.
- model_axis_name#
The name of the model axis.
- tp_mode#
The Tensor Parallelism mode to use. Can be “scatter”, “gather”, or “none”.
- skip_communication#
Whether to skip communication in the Tensor Parallelism strategy. Useful for layers with custom communication or where input has been already gathered beforehand.
- kernel_init#
The initializer to use for the kernel of the dense layer.
- kernel_init_adjustment#
The adjustment factor to use for the kernel initializer.
- use_bias#
Whether to use a bias in the dense layer.
- dense_name#
The name of the dense layer module.
- dense_fn: Any#
- tp_mode: Literal['scatter', 'gather', 'none'] = 'none'#
- kernel_init: collections.abc.Callable#
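The reference above does not spell out what the "gather" and "scatter" modes compute. The sketch below assumes the common tensor-parallel pattern: "gather" all-gathers the input shards and applies a column slice of the kernel, while "scatter" applies a row slice of the kernel to the local shard and then scatter-sums the partial outputs. It simulates two devices in plain Python; matvec and all variable names are hypothetical, and no xlstm_jax code is used.

```python
def matvec(x, w):
    """Compute x @ w for a vector x (length k) and a matrix w (k rows, m columns)."""
    return [sum(x[r] * w[r][c] for r in range(len(x))) for c in range(len(w[0]))]


# Unsharded reference: 4 input features, 2 output features.
x = [1.0, 2.0, 3.0, 4.0]
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]
reference = matvec(x, w)

n_dev = 2
x_shards = [x[0:2], x[2:4]]  # each simulated device holds half of the input features

# Assumed "gather" mode: all-gather the input, then use one kernel column per device.
x_full = [v for shard in x_shards for v in shard]
gather_out = [matvec(x_full, [[row[d]] for row in w]) for d in range(n_dev)]

# Assumed "scatter" mode: use two kernel rows per device, then scatter-sum the partials.
partials = [matvec(x_shards[d], w[2 * d:2 * d + 2]) for d in range(n_dev)]
summed = [sum(p[c] for p in partials) for c in range(len(w[0]))]
scatter_out = [[summed[d]] for d in range(n_dev)]  # device d keeps output feature d
```

In both modes each simulated device ends up with one of the two output features, and concatenating the per-device outputs reproduces the unsharded result.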
- xlstm_jax.distributed.tensor_parallel.async_gather(x, axis_name, shift_up=True)#
All gather using ring permutation.
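- Parameters:
x (xlstm_jax.common_types.PyTree) – The input to gather.
axis_name (str) – The axis name to gather along.
shift_up (bool) – Whether to shift up (device 0 sends to device 1) or down (device 1 sends to device 0).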
- Returns:
List of gathered inputs.
- Return type:
list[xlstm_jax.common_types.PyTree]
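To make the ring permutation concrete, the following plain-Python simulation passes values around a ring of devices. It is only a sketch: ring_all_gather_sim is a hypothetical name, the devices are list positions rather than real accelerators, and the order in which the library returns the gathered items is an assumption. A real implementation would presumably use a collective permutation such as jax.lax.ppermute.

```python
def ring_all_gather_sim(values, shift_up=True):
    """Simulate an all-gather over a ring; values[i] is the input held by device i."""
    n = len(values)
    gathered = [[v] for v in values]  # each device starts with its own input
    in_flight = list(values)          # what each device forwards in the next step
    step = -1 if shift_up else 1      # shift up: device i receives from device i - 1
    for _ in range(n - 1):
        in_flight = [in_flight[(i + step) % n] for i in range(n)]
        for i in range(n):
            gathered[i].append(in_flight[i])
    return gathered


up = ring_all_gather_sim(["a", "b", "c", "d"], shift_up=True)
down = ring_all_gather_sim(["a", "b", "c", "d"], shift_up=False)
```

After n - 1 steps every device holds all n inputs. In this sketch the list on each device is ordered by arrival time, so the order differs between devices and between the two shift directions.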
- xlstm_jax.distributed.tensor_parallel.async_gather_bidirectional(x, axis_name, shift_up=True)#
All gather using ring permutation with bidirectional communication.
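- Parameters:
x (xlstm_jax.common_types.PyTree) – The input to gather.
axis_name (str) – The axis name to gather along.
shift_up (bool) – Whether to shift up (device 0 sends to device 1) or down (device 1 sends to device 0).
- Returns:
List of gathered inputs.
- Return type:
list[xlstm_jax.common_types.PyTree]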
- xlstm_jax.distributed.tensor_parallel.async_gather_split(x, axis_name)#
All gather using ring permutation with features split for bidirectional communication.
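The feature-split idea can be sketched in plain Python: each input is cut into two halves along the feature dimension, one half travels up the ring and the other travels down, so both link directions carry data at the same time. All names below are hypothetical, and ordering the result by source device is an assumption about the output layout, not the library's documented behaviour.

```python
def ring_gather_by_source(values, shift_up):
    """Ring all-gather that records which device each payload came from."""
    n = len(values)
    step = -1 if shift_up else 1  # shift up: device i receives from device i - 1
    in_flight = list(enumerate(values))
    received = [{src: payload} for src, payload in in_flight]
    for _ in range(n - 1):
        in_flight = [in_flight[(i + step) % n] for i in range(n)]
        for i, (src, payload) in enumerate(in_flight):
            received[i][src] = payload
    return received


def gather_split_sim(values):
    """Gather feature lists, sending the first half up and the second half down."""
    n = len(values)
    first = [v[: len(v) // 2] for v in values]
    second = [v[len(v) // 2:] for v in values]
    up = ring_gather_by_source(first, shift_up=True)
    down = ring_gather_by_source(second, shift_up=False)
    # Re-join the two halves of each source device's features.
    return [[up[i][s] + down[i][s] for s in range(n)] for i in range(n)]


inputs = [[0, 1, 2, 3], [10, 11, 12, 13], [20, 21, 22, 23]]
split_result = gather_split_sim(inputs)
```

Each simulated device ends up with the full feature list of every device, while each ring direction only had to move half of the features.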
- xlstm_jax.distributed.tensor_parallel.async_scatter(xs, axis_name, shift_up=True)#
Scatter sum using ring permutation.
- Parameters:
xs (collections.abc.Sequence[xlstm_jax.common_types.PyTree]) – The inputs to scatter sum. The length of the list should match the size of the axis.
axis_name (str) – The axis name to scatter sum along.
shift_up (bool) – Whether to shift up (device 0 sends to device 1) or down (device 1 sends to device 0).
- Returns:
The scatter summed output.
- Return type:
xlstm_jax.common_types.PyTree
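A scatter sum over a ring can be simulated in plain Python as follows. This is a sketch under stated assumptions: ring_scatter_sum_sim is a hypothetical name, payloads are plain numbers, only the shift-up direction is shown, and assigning chunk i to device i is an assumption about the output layout.

```python
def ring_scatter_sum_sim(xs):
    """Simulate a shift-up ring scatter sum.

    xs[d][c] is device d's contribution to chunk c. The result holds, for each
    device i, the sum of chunk i over all devices.
    """
    n = len(xs)
    # Device j starts the accumulator for chunk j - 1, so that after n - 1 hops
    # upward the accumulator for chunk c arrives at device c.
    acc = [xs[j][(j - 1) % n] for j in range(n)]
    for t in range(1, n):
        received = [acc[(i - 1) % n] for i in range(n)]  # device i receives from i - 1
        # Add the local contribution to the chunk that just arrived.
        acc = [received[i] + xs[i][(i - t - 1) % n] for i in range(n)]
    return acc


# Three devices; device d contributes 10 * d + c to chunk c.
contributions = [[10 * d + c for c in range(3)] for d in range(3)]
scatter_result = ring_scatter_sum_sim(contributions)
```

Each accumulator visits every device exactly once, which is why the length of xs has to match the size of the axis.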
- xlstm_jax.distributed.tensor_parallel.async_scatter_split(xs, axis_name)#
Scatter sum using ring permutation with features split for bidirectional communication.
- Parameters:
xs (collections.abc.Sequence[xlstm_jax.common_types.PyTree]) – The inputs to scatter sum. The length of the list should match the size of the axis.
axis_name (str) – The axis name to scatter sum along.
- Returns:
The scatter summed output.
- Return type:
xlstm_jax.common_types.PyTree
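The feature-split variant can be sketched by running two ring scatter sums in opposite directions, one per half of the features. The simulation below is plain Python with hypothetical names; the chunk-to-device assignment and the way the halves are re-joined are assumptions, not the library's documented behaviour.

```python
def add(a, b):
    """Elementwise sum of two equally long feature lists."""
    return [p + q for p, q in zip(a, b)]


def ring_scatter_sum_dir(xs, shift_up):
    """Ring scatter sum of feature lists; device i ends up with the sum of chunk i."""
    n = len(xs)
    d = 1 if shift_up else -1  # direction in which accumulators are sent
    acc = [xs[j][(j - d) % n] for j in range(n)]
    for t in range(1, n):
        received = [acc[(i - d) % n] for i in range(n)]
        acc = [add(received[i], xs[i][(i - d * (t + 1)) % n]) for i in range(n)]
    return acc


def scatter_sum_split_sim(xs):
    """Send the first half of the features up the ring and the second half down."""
    first = [[c[: len(c) // 2] for c in dev] for dev in xs]
    second = [[c[len(c) // 2:] for c in dev] for dev in xs]
    up = ring_scatter_sum_dir(first, shift_up=True)
    down = ring_scatter_sum_dir(second, shift_up=False)
    return [u + v for u, v in zip(up, down)]


# Three devices, three chunks, two features per chunk.
xs = [[[d + c, 10 * d + c] for c in range(3)] for d in range(3)]
split_sum = scatter_sum_split_sim(xs)
expected = [[sum(xs[d][i][f] for d in range(3)) for f in range(2)] for i in range(3)]
```

The result matches a direct sum over devices, while each ring direction only carries half of the features per step.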
- class xlstm_jax.distributed.tensor_parallel.TPAsyncDense#
Bases: flax.linen.Module
Tensor-Parallel Dense Layer with Asynchronous Communication.
This layer applies a dense layer with Tensor Parallelism support and overlaps communication with computation whenever possible.
- dense_fn#
Constructor function of the dense layer to use. Needs to support the keyword argument kernel_init.
- model_axis_name#
The name of the model axis.
- tp_mode#
The Tensor Parallelism mode to use. Can be “scatter”, “gather”, or “none”.
- kernel_init#
The initializer to use for the kernel of the dense layer.
- kernel_init_adjustment#
The adjustment factor to use for the kernel initializer.
- use_bias#
Whether to use a bias in the dense layer.
- dense_name#
The name of the dense layer module.
- use_bidirectional_gather#
Whether to use bidirectional or unidirectional gather over the device ring for communication.
- use_bidirectional_scatter#
Whether to use bidirectional or unidirectional scatter over the device ring for communication.
- dense_fn: Any#
- tp_mode: Literal['scatter', 'gather', 'none'] = 'none'#
- kernel_init: collections.abc.Callable#