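xlstm_jax.distributed.pipeline_parallel
Classes
PipelineModule | Module wrapper for executing a pipeline of stages.
Functions
execute_pipeline_step(module, state, input, *args, model_axis_name, **kwargs) | Single micro-batch pipeline step.
execute_pipeline(module, x, *args, num_microbatches, model_axis_name, **kwargs) | Execute a pipeline of stages on a batch of data.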
Module Contents
- xlstm_jax.distributed.pipeline_parallel.execute_pipeline_step(module, state, input, *args, model_axis_name, **kwargs)
Single micro-batch pipeline step.
- Parameters:
module (flax.linen.Module) – Flax module representing the stage to execute.
state (jax.Array) – Last communicated features between stages. Used as input to the module for all stages except the first.
input (jax.Array) – Original micro-batch input to the pipeline stage. Used as input to the module for the first stage.
*args – Additional arguments to the module.
model_axis_name (str) – Name of the model axis in the mesh/shard_map.
**kwargs – Additional keyword arguments to the module.
- Returns:
Tuple of the new state (after communication) and the output of the module.
- Return type:
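The step described above can be illustrated with a small single-host simulation. This is a minimal sketch, not the library's implementation: `simulate_pipeline_step` and `stage_fns` are hypothetical names, plain Python callables stand in for the Flax module on each device, and the ring shift at the end stands in for the inter-device communication (in JAX this would typically be a collective such as `jax.lax.ppermute` over `model_axis_name`, which is an assumption about the real code).

```python
def simulate_pipeline_step(stage_fns, states, first_stage_input):
    """Simulate one micro-batch step across all stages on a single host.

    stage_fns: one callable per stage (hypothetical stand-in for the module).
    states: features each stage received from its predecessor last step.
    first_stage_input: the micro-batch fed to the first stage.
    Returns (new_states, outputs), mirroring the documented
    (new state, module output) tuple, here listed for every stage.
    """
    num_stages = len(stage_fns)
    outputs = []
    for stage_idx, fn in enumerate(stage_fns):
        # The first stage reads the original input; every other stage
        # reads the state communicated by the previous stage.
        stage_in = first_stage_input if stage_idx == 0 else states[stage_idx]
        outputs.append(fn(stage_in))
    # Communication (assumed ring shift): stage i receives the output
    # of stage i - 1; the first stage receives the last stage's output.
    new_states = [outputs[(i - 1) % num_stages] for i in range(num_stages)]
    return new_states, outputs


# Two toy stages: add one, then double.
stage_fns = [lambda v: v + 1.0, lambda v: v * 2.0]
new_states, outputs = simulate_pipeline_step(stage_fns, [0.0, 5.0], 3.0)
```

In this toy run the first stage consumes the input `3.0` while the second stage consumes its previously received state `5.0`, which is the input selection the parameter descriptions above distinguish.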
- xlstm_jax.distributed.pipeline_parallel.execute_pipeline(module, x, *args, num_microbatches, model_axis_name, **kwargs)
Execute a pipeline of stages on a batch of data.
Follows the GPipe approach: the batch is split into micro-batches, which the pipeline stages then process in parallel.
- Parameters:
module (flax.linen.Module) – Flax module representing the pipeline stage to execute.
x (jax.Array) – Batch of input data, only needed on the device of the first stage. The data is split into micro-batches.
*args – Additional arguments to the module.
num_microbatches (int) – Number of micro-batches to split the batch into.
model_axis_name (str) – Name of the model axis in the mesh/shard_map.
**kwargs – Additional keyword arguments to the module.
- Returns:
Output of the last stage of the pipeline. For devices that are not the last stage, the output is zeros.
- Return type:
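The GPipe-style schedule can be sketched as a single-host simulation. This is an illustration under stated assumptions rather than the library's code: `simulate_gpipe` and `stage_fns` are hypothetical names, plain Python lists stand in for `jax.Array` values, and the number of steps (`num_microbatches + num_stages - 1`) follows from the standard GPipe schedule, not from this page. The sketch returns only the result of the last stage, in line with the documented behavior that other stages return zeros.

```python
def simulate_gpipe(stage_fns, batch, num_microbatches):
    """Run a batch through a chain of stages using a GPipe-like schedule."""
    num_stages = len(stage_fns)
    assert len(batch) % num_microbatches == 0, "batch must split evenly"
    mb_size = len(batch) // num_microbatches
    microbatches = [
        batch[i * mb_size:(i + 1) * mb_size] for i in range(num_microbatches)
    ]
    zero_mb = [0.0] * mb_size
    # Each stage starts with an all-zero state (nothing received yet).
    states = [zero_mb for _ in range(num_stages)]
    collected = []
    # Extra steps are needed to drain the pipeline after the last
    # micro-batch has entered the first stage.
    num_steps = num_microbatches + num_stages - 1
    for step in range(num_steps):
        # Once all micro-batches have entered, feed zeros (pipeline bubble).
        first_in = microbatches[step] if step < num_microbatches else zero_mb
        outputs = []
        for stage_idx, fn in enumerate(stage_fns):
            stage_in = first_in if stage_idx == 0 else states[stage_idx]
            outputs.append([fn(v) for v in stage_in])
        # Ring shift: stage i receives the output of stage i - 1.
        states = [outputs[(s - 1) % num_stages] for s in range(num_stages)]
        # The last stage holds a valid result from step num_stages - 1 on.
        if step >= num_stages - 1:
            collected.extend(outputs[-1])
    return collected, num_steps


# Two toy stages: add one, then double.
stage_fns = [lambda v: v + 1.0, lambda v: v * 2.0]
result, num_steps = simulate_gpipe(stage_fns, [1.0, 2.0, 3.0, 4.0], num_microbatches=2)
```

With two stages and two micro-batches the simulation takes three steps, and `result` equals what applying both stages sequentially to the full batch would give. Increasing `num_microbatches` shrinks the share of steps spent in the fill and drain phases, which is the usual motivation for micro-batching.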
- class xlstm_jax.distributed.pipeline_parallel.PipelineModule
Bases: flax.linen.Module
Module wrapper for executing a pipeline of stages.
This module wraps a single pipeline stage so that it is executed with pipeline parallelism.
- Parameters:
model_axis_name – Name of the model axis in the mesh/shard_map.
num_microbatches – Number of micro-batches to split the batch into.
module_fn – Function that returns the module to execute in the pipeline.
- module_fn: collections.abc.Callable[Ellipsis, flax.linen.Module]