xlstm_jax.distributed.tensor_parallel

Classes

ModelParallelismWrapper

Wrapper for adding model parallelism to a module.

TPDense

Dense layer with Tensor Parallelism support.

TPAsyncDense

Tensor-Parallel Dense Layer with Asynchronous Communication.

Functions

scale_init(init_fn[, scale_factor])

Scales the output of the given init function by the given factor.

async_gather(x, axis_name[, shift_up])

All gather using ring permutation.

async_gather_bidirectional(x, axis_name[, shift_up])

All gather using ring permutation with bidirectional communication.

async_gather_split(x, axis_name)

All gather using ring permutation with features split for bidirectional communication.

async_scatter(xs, axis_name[, shift_up])

Scatter sum using ring permutation.

async_scatter_split(xs, axis_name)

Scatter sum using ring permutation with features split for bidirectional communication.

Module Contents

class xlstm_jax.distributed.tensor_parallel.ModelParallelismWrapper

Bases: flax.linen.Module

Wrapper for adding model parallelism to a module.

This wrapper adds sharding over the model axis to the parameters of the module and initializes the module with different parameters across the model axis.

Parameters:
  • model_axis_name – Name of the model axis to shard over.

  • module_fn – Function that returns the Flax module to wrap.

  • mask_except_model_idx – If not None, only the shard at index mask_except_model_idx will be non-zero; all other shards are masked to zero.

  • split_rngs – If True, split the random number generators across the model axis.

  • module_kwargs – Additional keyword arguments to pass to the module function.

model_axis_name: str
module_fn: collections.abc.Callable[..., flax.linen.Module]
mask_except_model_idx: int | None = None
split_rngs: bool = True
module_kwargs: flax.core.frozen_dict.FrozenDict[str, Any]

xlstm_jax.distributed.tensor_parallel.scale_init(init_fn, scale_factor=1.0)

Scales the output of the given init function by the given factor.

Parameters:
  • init_fn (collections.abc.Callable) – The init function to scale.

  • scale_factor (float) – The factor to scale the output of the init function by.

Returns:

A new init function that scales the output of the given init function by the given factor.
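A minimal sketch of how such a wrapper could be written, assuming the init function takes an RNG key as its first argument, as JAX initializers conventionally do. The names scale_init_sketch and constant_init are hypothetical, and the list-based output only keeps the example dependency-free; this is not the library's implementation.

```python
from collections.abc import Callable


def scale_init_sketch(init_fn: Callable, scale_factor: float = 1.0) -> Callable:
    """Return a new init function whose output is scaled by scale_factor."""

    def _scaled_init(rng, *args, **kwargs):
        # Call the wrapped initializer, then scale every element of its output.
        values = init_fn(rng, *args, **kwargs)
        return [scale_factor * v for v in values]

    return _scaled_init


def constant_init(rng, size):
    """Hypothetical stand-in initializer: ignores rng and returns `size` ones."""
    return [1.0] * size


# Wrap the stand-in initializer so that its output is halved.
half_init = scale_init_sketch(constant_init, scale_factor=0.5)
scaled = half_init(None, 4)
```

With array-valued initializers, the list comprehension would reduce to a single multiplication of the returned array by scale_factor.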

class xlstm_jax.distributed.tensor_parallel.TPDense

Bases: flax.linen.Module

Dense layer with Tensor Parallelism support.

This layer wraps a dense layer and adds Tensor Parallelism support.

dense_fn

Constructor function of the dense layer to use. Needs to support the keyword argument kernel_init.

model_axis_name

The name of the model axis.

tp_mode

The Tensor Parallelism mode to use. Can be “scatter”, “gather”, or “none”.

skip_communication

Whether to skip communication in the Tensor Parallelism strategy. Useful for layers with custom communication or where the input has already been gathered.

kernel_init

The initializer to use for the kernel of the dense layer.

kernel_init_adjustment

The adjustment factor to use for the kernel initializer.

use_bias

Whether to use a bias in the dense layer.

dense_name

The name of the dense layer module.
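The two communicating modes can be pictured with plain matrix arithmetic. The sketch below assumes the common tensor-parallel convention, which this page does not spell out: in "gather" mode each device sees the full input and multiplies it with its own slice of the kernel's output columns, while in "scatter" mode each device multiplies its input slice with the matching kernel rows and the partial results are summed across devices. The helper matvec and all variable names are hypothetical, and devices are simulated with plain lists.

```python
def matvec(x, w):
    """Multiply a vector x (length m) with a matrix w (m rows, k columns)."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


x = [1.0, 2.0, 3.0, 4.0]                                # full input features
w = [[1.0, 0.0], [0.0, 1.0], [2.0, 1.0], [1.0, 3.0]]    # full kernel, 4 x 2
reference = matvec(x, w)                                 # unsharded result

num_devices = 2

# "gather"-style (assumed): every device sees the full input and owns one
# slice of the kernel's output columns, so it produces a slice of the output.
column_shards = [[[row[d]] for row in w] for d in range(num_devices)]
gather_slices = [matvec(x, shard) for shard in column_shards]
gather_result = [value for piece in gather_slices for value in piece]

# "scatter"-style (assumed): every device owns a slice of the input and the
# matching kernel rows; the partial outputs must be summed across devices.
rows = len(x) // num_devices
partials = [
    matvec(x[d * rows:(d + 1) * rows], w[d * rows:(d + 1) * rows])
    for d in range(num_devices)
]
scatter_result = [sum(p[j] for p in partials) for j in range(len(w[0]))]
```

Both sharded variants reproduce the unsharded product. A real scatter sum would additionally leave each device holding only its slice of the summed output.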

dense_fn: Any
model_axis_name: str
tp_mode: Literal['scatter', 'gather', 'none'] = 'none'
skip_communication: bool = False
kernel_init: collections.abc.Callable
kernel_init_adjustment: float = 1.0
use_bias: bool = True
dense_name: str = 'module'

xlstm_jax.distributed.tensor_parallel.async_gather(x, axis_name, shift_up=True)

All gather using ring permutation.

Parameters:
  • x (xlstm_jax.common_types.PyTree) – The input to gather.

  • axis_name (str) – The axis name to gather along.

  • shift_up (bool) – Whether to shift up (device 0 sends to device 1) or down (device 1 sends to device 0).

Returns:

List of gathered inputs.

Return type:

list[xlstm_jax.common_types.PyTree]
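To illustrate the ring pattern, here is a dependency-free simulation in which a Python list stands in for the devices on the axis. It assumes that each device starts with its own input and, in every step, forwards the value it received last to its ring neighbour. The function name ring_gather_sim is hypothetical, and the ordering of the real function's output is not confirmed by this page.

```python
def ring_gather_sim(inputs, shift_up=True):
    """Simulate an all-gather over a ring; inputs[d] is the value on device d."""
    n = len(inputs)
    step = 1 if shift_up else -1
    gathered = [[x] for x in inputs]   # each device starts with its own input
    current = list(inputs)             # value each device forwards next
    for _ in range(1, n):
        # One ring permutation: device j sends its value to device (j + step) % n.
        received = [None] * n
        for j in range(n):
            received[(j + step) % n] = current[j]
        current = received
        for d in range(n):
            gathered[d].append(current[d])
    return gathered


up = ring_gather_sim(["a", "b", "c", "d"], shift_up=True)
down = ring_gather_sim(["a", "b", "c", "d"], shift_up=False)
```

After axis size minus one steps every simulated device holds all inputs; only the order in which they arrive depends on shift_up.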

xlstm_jax.distributed.tensor_parallel.async_gather_bidirectional(x, axis_name, shift_up=True)

All gather using ring permutation with bidirectional communication.

Parameters:
  • x (jax.Array) – The input to gather.

  • axis_name (str) – The axis name to gather along.

  • shift_up (bool) – Whether the returned list follows the order of the unidirectional version with shift up (device 0 sends to device 1) or shift down (device 1 sends to device 0).

Returns:

List of gathered inputs.

Return type:

list[jax.Array]
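One way to picture bidirectional communication is to alternate the send direction between steps and then reorder the results so they match the unidirectional output. The simulation below is a sketch under that assumption, with a Python list standing in for the devices; shift and bidirectional_gather_sim are hypothetical names, not the library's code.

```python
def shift(values, step):
    """One ring permutation: device j sends its value to device (j + step) % n."""
    n = len(values)
    received = [None] * n
    for j, value in enumerate(values):
        received[(j + step) % n] = value
    return received


def bidirectional_gather_sim(inputs, shift_up=True):
    """Simulate an all-gather that sends alternately up and down the ring."""
    n = len(inputs)
    up, down = list(inputs), list(inputs)
    ups, downs = [], []   # per-step snapshots of what every device received
    for i in range(1, n):
        if i % 2 == 1:
            up = shift(up, +1)
            ups.append(up)
        else:
            down = shift(down, -1)
            downs.append(down)
    # Reorder the snapshots to match the unidirectional order for this direction.
    if shift_up:
        steps = [list(inputs)] + ups + downs[::-1]
    else:
        steps = [list(inputs)] + downs + ups[::-1]
    # Transpose from per-step snapshots to one list per device.
    return [[state[d] for state in steps] for d in range(n)]


result_up = bidirectional_gather_sim(["a", "b", "c", "d", "e"], shift_up=True)
result_down = bidirectional_gather_sim(["a", "b", "c", "d", "e"], shift_up=False)
```

In this scheme no value travels more than about half the ring in either direction, while the returned order is the same as in the unidirectional simulation.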

xlstm_jax.distributed.tensor_parallel.async_gather_split(x, axis_name)

All gather using ring permutation with features split for bidirectional communication.

Parameters:
  • x (jax.Array) – The input to gather.

  • axis_name (str) – The axis name to gather along.

Returns:

List of gathered inputs. Length is 2 * axis size - 1.

Return type:

list[jax.Array]

xlstm_jax.distributed.tensor_parallel.async_scatter(xs, axis_name, shift_up=True)

Scatter sum using ring permutation.

Parameters:
  • xs (collections.abc.Sequence[xlstm_jax.common_types.PyTree]) – The inputs to scatter sum. The length of the list should match the size of the axis.

  • axis_name (str) – The axis name to scatter sum along.

  • shift_up (bool) – Whether to shift up (device 0 sends to device 1) or down (device 1 sends to device 0).

Returns:

The scatter summed output.

Return type:

xlstm_jax.common_types.PyTree
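A ring scatter sum can be simulated without any devices. The sketch below assumes that every device starts an accumulator with its first list element, and that in each step the accumulator is passed to the ring neighbour, which adds its next element. The name ring_scatter_sum_sim and the choice of which element is added in which step are assumptions for illustration, not the library's implementation.

```python
def ring_scatter_sum_sim(xs_per_device, shift_up=True):
    """Simulate a scatter sum over a ring; xs_per_device[d] is the list on device d."""
    n = len(xs_per_device)
    # Every device must hold exactly one element per device on the axis.
    assert all(len(xs) == n for xs in xs_per_device)
    step = 1 if shift_up else -1
    acc = [xs[0] for xs in xs_per_device]
    for k in range(1, n):
        # Pass the running sum to the neighbour: device j sends to (j + step) % n.
        received = [None] * n
        for j in range(n):
            received[(j + step) % n] = acc[j]
        # Each device adds its k-th element to the sum it just received.
        acc = [received[d] + xs_per_device[d][k] for d in range(n)]
    return acc


# Device d holds (d + 1) * [1, 10, 100], so each decimal digit of the result
# shows which device contributed which list position.
xs = [[1, 10, 100], [2, 20, 200], [3, 30, 300]]
summed = ring_scatter_sum_sim(xs, shift_up=True)
```

Each simulated device ends with a sum that contains exactly one element from every device, and together the outputs cover every element once.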

xlstm_jax.distributed.tensor_parallel.async_scatter_split(xs, axis_name)

Scatter sum using ring permutation with features split for bidirectional communication.

Parameters:
  • xs (collections.abc.Sequence[xlstm_jax.common_types.PyTree]) – The inputs to scatter sum. The length of the list should match the size of the axis.

  • axis_name (str) – The axis name to scatter sum along.

Returns:

The scatter summed output.

Return type:

xlstm_jax.common_types.PyTree

class xlstm_jax.distributed.tensor_parallel.TPAsyncDense

Bases: flax.linen.Module

Tensor-Parallel Dense Layer with Asynchronous Communication.

This layer wraps a dense layer, adds Tensor Parallelism support, and overlaps communication with computation whenever possible.

dense_fn

Constructor function of the dense layer to use. Needs to support the keyword argument kernel_init.

model_axis_name

The name of the model axis.

tp_mode

The Tensor Parallelism mode to use. Can be “scatter”, “gather”, or “none”.

kernel_init

The initializer to use for the kernel of the dense layer.

kernel_init_adjustment

The adjustment factor to use for the kernel initializer.

use_bias

Whether to use a bias in the dense layer.

dense_name

The name of the dense layer module.

use_bidirectional_gather

Whether to use bidirectional or unidirectional gather over the device ring for communication.

use_bidirectional_scatter

Whether to use bidirectional or unidirectional scatter over the device ring for communication.

dense_fn: Any
model_axis_name: str
tp_mode: Literal['scatter', 'gather', 'none'] = 'none'
kernel_init: collections.abc.Callable
kernel_init_adjustment: float = 1.0
use_bias: bool = True
dense_name: str = 'module'
use_bidirectional_gather: bool = True
use_bidirectional_scatter: bool = False