xlstm_jax.distributed.pipeline_parallel

Classes

PipelineModule

Module wrapper for executing a pipeline of stages.

Functions

execute_pipeline_step(module, state, input, *args, ...)

Single micro-batch pipeline step.

execute_pipeline(module, x, *args, num_microbatches, ...)

Execute a pipeline of stages on a batch of data.

Module Contents

xlstm_jax.distributed.pipeline_parallel.execute_pipeline_step(module, state, input, *args, model_axis_name, **kwargs)

Single micro-batch pipeline step.

Parameters:
  • module (flax.linen.Module) – Flax module representing the stage to execute.

  • state (jax.Array) – Last communicated features between stages. Used as input to the module for all stages except the first.

  • input (jax.Array) – Original micro-batch input to the pipeline stage. Used as input to the module for the first stage.

  • *args – Additional arguments to the module.

  • model_axis_name (str) – Name of the model axis in the mesh/shard_map.

  • **kwargs – Additional keyword arguments to the module.

Returns:

Tuple of the new state (after communication) and the output of the module.

Return type:

tuple[jax.Array, jax.Array]
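
The step semantics described above can be illustrated with a single-process NumPy sketch. This is not the library's implementation: the stage axis is simulated with Python lists, `simulate_pipeline_step` is a hypothetical helper, and plain callables stand in for the Flax stage module. The cyclic hand-off between stages and the zero output on all stages except the last are assumptions inferred from the parameter descriptions.

```python
import numpy as np


def simulate_pipeline_step(stage_fns, states, microbatch):
    """Simulate one pipeline step for every stage at once (hypothetical helper).

    stage_fns:  one callable per stage, standing in for the Flax stage module.
    states:     one array per stage, the features last received from the previous stage.
    microbatch: original micro-batch input, consumed only by the first stage.
    """
    num_stages = len(stage_fns)
    results = []
    for stage_index, stage_fn in enumerate(stage_fns):
        # The first stage reads the micro-batch; every other stage reads its state.
        stage_input = microbatch if stage_index == 0 else states[stage_index]
        results.append(stage_fn(stage_input))
    # Assumption: only the last stage exposes its result; other stages report zeros.
    outputs = [
        result if stage_index == num_stages - 1 else np.zeros_like(result)
        for stage_index, result in enumerate(results)
    ]
    # Simulated communication: stage i hands its result to stage (i + 1) % num_stages.
    new_states = [results[(stage_index - 1) % num_stages] for stage_index in range(num_stages)]
    return new_states, outputs


# Two toy stages: the first adds 1, the second doubles its input.
stage_fns = [lambda x: x + 1.0, lambda x: x * 2.0]
states = [np.zeros((2, 3)), np.full((2, 3), 5.0)]
microbatch = np.ones((2, 3))
new_states, outputs = simulate_pipeline_step(stage_fns, states, microbatch)
```

In this sketch the second stage receives the first stage's result as its new state, which mirrors the "new state (after communication)" in the return value. In the real function the hand-off would be a collective over `model_axis_name` rather than a list shuffle.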

xlstm_jax.distributed.pipeline_parallel.execute_pipeline(module, x, *args, num_microbatches, model_axis_name, **kwargs)

Execute a pipeline of stages on a batch of data.

Follows the GPipe approach: the batch is split into micro-batches, which are fed through the pipeline one after another so that different stages can work on different micro-batches in parallel.
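
As a rough guide to how micro-batching affects utilization, the idle ("bubble") fraction of an idealized GPipe forward pass can be computed as below. This is a sketch under the assumption that every stage takes one equal time slot per micro-batch; `pipeline_bubble_fraction` is a hypothetical helper, not part of this module.

```python
def pipeline_bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Fraction of time slots a stage sits idle in an idealized GPipe forward pass.

    Assumes every stage needs one equal time slot per micro-batch, so the whole
    pass takes num_microbatches + num_stages - 1 slots, of which each stage is
    busy for num_microbatches.
    """
    total_slots = num_microbatches + num_stages - 1
    return (num_stages - 1) / total_slots


for num_microbatches in (1, 4, 16):
    fraction = pipeline_bubble_fraction(num_stages=4, num_microbatches=num_microbatches)
    print(f"num_microbatches={num_microbatches}: idle fraction {fraction:.2f}")
```

Under these assumptions, more micro-batches shrink the idle fraction, at the cost of smaller per-step workloads.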

Parameters:
  • module (flax.linen.Module) – Flax module representing the pipeline stage to execute.

  • x (jax.Array) – Batch of input data. Only needed on the device holding the first stage. The data is split into micro-batches.

  • *args – Additional arguments to the module.

  • num_microbatches (int) – Number of micro-batches to split the batch into.

  • model_axis_name (str) – Name of the model axis in the mesh/shard_map.

  • **kwargs – Additional keyword arguments to the module.

Returns:

Output of the last stage of the pipeline. On devices that do not hold the last stage, the output is all zeros.

Return type:

jax.Array
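
The overall schedule can be sketched as a single-process NumPy simulation. This is an illustration, not the library's code: `simulate_pipeline` is a hypothetical helper, stages are plain shape-preserving callables, and the zero-padding used to drain the pipeline is an assumption about how such a schedule is commonly arranged.

```python
import numpy as np


def simulate_pipeline(stage_fns, x, num_microbatches):
    """Single-process sketch of a GPipe-style schedule (hypothetical helper)."""
    num_stages = len(stage_fns)
    batch_size = x.shape[0]
    assert batch_size % num_microbatches == 0, "batch must split evenly"
    # Split the batch into micro-batches along the leading axis.
    microbatches = np.split(x, num_microbatches, axis=0)
    # Assumption: pad with dummy micro-batches so the last real one can drain
    # through all remaining stages.
    padding = [np.zeros_like(microbatches[0]) for _ in range(num_stages - 1)]
    inputs = microbatches + padding
    states = [np.zeros_like(microbatches[0]) for _ in range(num_stages)]
    last_stage_outputs = []
    for microbatch in inputs:
        # Every stage runs once per step: stage 0 on the micro-batch, others on their state.
        results = [
            fn(microbatch if i == 0 else states[i]) for i, fn in enumerate(stage_fns)
        ]
        last_stage_outputs.append(results[-1])
        # Stage i passes its result on to stage i + 1 for the next step.
        states = [results[i - 1] for i in range(num_stages)]
    # The first num_stages - 1 outputs of the last stage stem from warm-up steps.
    return np.concatenate(last_stage_outputs[num_stages - 1:], axis=0)


# Three toy stages applied to a batch of 4 rows, split into 2 micro-batches.
stage_fns = [lambda h: h + 1.0, lambda h: h * 2.0, lambda h: h - 3.0]
x = np.arange(8.0).reshape(4, 2)
pipelined = simulate_pipeline(stage_fns, x, num_microbatches=2)
sequential = (x + 1.0) * 2.0 - 3.0
```

Here `pipelined` matches `sequential`, showing that the micro-batched schedule computes the same result as running the stages back to back on the full batch.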

class xlstm_jax.distributed.pipeline_parallel.PipelineModule

Bases: flax.linen.Module

Module wrapper for executing a pipeline of stages.

This module wraps a single pipeline stage so that it is executed with pipeline parallelism.

Parameters:
  • model_axis_name (str) – Name of the model axis in the mesh/shard_map.

  • num_microbatches (int) – Number of micro-batches to split the batch into.

  • module_fn (collections.abc.Callable[..., flax.linen.Module]) – Function that returns the module to execute in the pipeline.

model_axis_name: str
num_microbatches: int
module_fn: collections.abc.Callable[..., flax.linen.Module]