xlstm_jax.distributed.single_gpu
Functions
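| accumulate_gradients_loop(state, batch, rng, num_minibatches, loss_fn) | Calculate gradients and metrics for a batch using gradient accumulation. |
| accumulate_gradients_scan(state, batch, rng, num_minibatches, loss_fn) | Calculate gradients and metrics for a batch using gradient accumulation. |
| accumulate_gradients(state, batch, rng, num_minibatches, loss_fn, use_scan=False) | Calculate gradients and metrics for a batch using gradient accumulation. |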
Module Contents
- xlstm_jax.distributed.single_gpu.accumulate_gradients_loop(state, batch, rng, num_minibatches, loss_fn)
Calculate gradients and metrics for a batch using gradient accumulation.
- Parameters:
state (xlstm_jax.common_types.TrainState) – Current training state.
batch (xlstm_jax.dataset.Batch) – Full training batch.
rng (xlstm_jax.common_types.PRNGKeyArray) – Random number generator to use.
num_minibatches (int) – Number of mini-batches to split the batch into. Equal to the number of gradient accumulation steps.
loss_fn (collections.abc.Callable) – Loss function to calculate gradients and metrics.
- Returns:
Tuple with accumulated gradients, metrics, and collected mutable variables over the mini-batches.
- Return type:
tuple[xlstm_jax.common_types.PyTree, xlstm_jax.common_types.Metrics, collections.abc.Sequence[xlstm_jax.common_types.PyTree]]
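The loop-based accumulation described above can be illustrated with a minimal sketch. It is not the library's implementation: the linear model, `loss_fn(params, x, y)`, and `accumulate_loop_sketch` are hypothetical names, a plain dict stands in for the training state, and the sketch assumes the batch size is divisible by `num_minibatches`. RNG handling, metrics, and mutable variables are left out.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Hypothetical mean-squared-error loss for a linear model.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def accumulate_loop_sketch(params, x, y, num_minibatches):
    # Assumes the leading batch axis is divisible by num_minibatches.
    minibatch_size = x.shape[0] // num_minibatches
    grads = jax.tree_util.tree_map(jnp.zeros_like, params)
    loss = 0.0
    for i in range(num_minibatches):
        start = i * minibatch_size
        end = start + minibatch_size
        # Gradients for one mini-batch only, so peak memory stays small.
        step_loss, step_grads = jax.value_and_grad(loss_fn)(
            params, x[start:end], y[start:end]
        )
        grads = jax.tree_util.tree_map(jnp.add, grads, step_grads)
        loss = loss + step_loss
    # Average so the result matches a single pass over the full batch.
    grads = jax.tree_util.tree_map(lambda g: g / num_minibatches, grads)
    return grads, loss / num_minibatches


params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
x = jnp.arange(24.0).reshape(8, 3) / 10.0
y = jnp.arange(8.0)

acc_grads, acc_loss = accumulate_loop_sketch(params, x, y, num_minibatches=4)
full_loss, full_grads = jax.value_and_grad(loss_fn)(params, x, y)
```

With equally sized mini-batches and a mean-reduced loss, the averaged gradients agree with the full-batch gradients up to floating-point error, which is the property gradient accumulation relies on.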
- xlstm_jax.distributed.single_gpu.accumulate_gradients_scan(state, batch, rng, num_minibatches, loss_fn)
Calculate gradients and metrics for a batch using gradient accumulation.
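This version loops over the mini-batches with jax.lax.scan. Because scan traces the loop body once instead of unrolling it, compilation is typically faster than with a Python for loop, particularly when num_minibatches is large.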
- Parameters:
state (xlstm_jax.common_types.TrainState) – Current training state.
batch (xlstm_jax.dataset.Batch) – Full training batch.
rng (xlstm_jax.common_types.PRNGKeyArray) – Random number generator to use.
num_minibatches (int) – Number of mini-batches to split the batch into. Equal to the number of gradient accumulation steps.
loss_fn (collections.abc.Callable) – Loss function to calculate gradients and metrics.
- Returns:
Tuple with accumulated gradients, metrics, and collected mutable variables over the mini-batches.
- Return type:
tuple[xlstm_jax.common_types.PyTree, xlstm_jax.common_types.Metrics, xlstm_jax.common_types.PyTree]
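A minimal sketch of the scan-based variant follows, again under stated assumptions rather than as the library's code: `loss_fn`, the dict of parameters, and `accumulate_scan_sketch` are hypothetical, and the batch size must be divisible by `num_minibatches`. The accumulated gradients and loss travel through the scan carry, and the mini-batches are the scanned inputs.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Hypothetical mean-squared-error loss for a linear model.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def accumulate_scan_sketch(params, x, y, num_minibatches):
    # Reshape (batch, ...) into (num_minibatches, minibatch_size, ...).
    xs = x.reshape(num_minibatches, -1, *x.shape[1:])
    ys = y.reshape(num_minibatches, -1, *y.shape[1:])

    def step(carry, minibatch):
        grads, loss = carry
        mb_x, mb_y = minibatch
        step_loss, step_grads = jax.value_and_grad(loss_fn)(params, mb_x, mb_y)
        grads = jax.tree_util.tree_map(jnp.add, grads, step_grads)
        # No per-step output is collected in this sketch.
        return (grads, loss + step_loss), None

    init = (jax.tree_util.tree_map(jnp.zeros_like, params), jnp.zeros(()))
    (grads, loss), _ = jax.lax.scan(step, init, (xs, ys))
    grads = jax.tree_util.tree_map(lambda g: g / num_minibatches, grads)
    return grads, loss / num_minibatches


params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
x = jnp.arange(24.0).reshape(8, 3) / 10.0
y = jnp.arange(8.0)

# num_minibatches determines array shapes, so it must be static under jit.
jitted = jax.jit(accumulate_scan_sketch, static_argnames=("num_minibatches",))
scan_grads, scan_loss = jitted(params, x, y, num_minibatches=4)
full_loss, full_grads = jax.value_and_grad(loss_fn)(params, x, y)
```

The body of `step` is traced a single time regardless of `num_minibatches`, whereas a Python loop inside a jitted function is unrolled once per iteration.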
- xlstm_jax.distributed.single_gpu.accumulate_gradients(state, batch, rng, num_minibatches, loss_fn, use_scan=False)
Calculate gradients and metrics for a batch using gradient accumulation.
The mini-batches are processed either with jax.lax.scan or with a Python for loop, selected by use_scan.
- Parameters:
state (xlstm_jax.common_types.TrainState) – Current training state.
batch (xlstm_jax.dataset.Batch) – Full training batch.
rng (xlstm_jax.common_types.PRNGKeyArray) – Random number generator to use.
num_minibatches (int) – Number of mini-batches to split the batch into. Equal to the number of gradient accumulation steps.
loss_fn (collections.abc.Callable) – Loss function to calculate gradients and metrics.
use_scan (bool) – Whether to use jax.lax.scan for looping over the mini-batches. Defaults to False, which uses a for loop.
- Returns:
Tuple with accumulated gradients, metrics, and collected mutable variables over the mini-batches.
- Return type:
tuple[xlstm_jax.common_types.PyTree, xlstm_jax.common_types.Metrics, collections.abc.Sequence[xlstm_jax.common_types.PyTree] | xlstm_jax.common_types.PyTree]
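The third element of the return type is either a sequence of PyTrees or a single PyTree. One plausible reading, offered here as an assumption rather than a statement about the library's internals, is that a for loop collects one PyTree per mini-batch in a list, while jax.lax.scan stacks the per-step outputs along a new leading axis. The sketch below illustrates that difference with hypothetical names (`per_minibatch_stats`, `collect_sketch`) standing in for the collected mutable variables.

```python
import jax
import jax.numpy as jnp


def per_minibatch_stats(minibatch):
    # Hypothetical stand-in for the mutable variables a model might emit.
    return {"mean": jnp.mean(minibatch), "max": jnp.max(minibatch)}


def collect_sketch(batch, num_minibatches, use_scan=False):
    # Assumes the leading batch axis is divisible by num_minibatches.
    minibatches = batch.reshape(num_minibatches, -1, *batch.shape[1:])
    if use_scan:
        def step(carry, mb):
            # The carry is a dummy step counter; the second output is stacked.
            return carry + 1, per_minibatch_stats(mb)

        _, stacked = jax.lax.scan(step, jnp.zeros((), dtype=jnp.int32), minibatches)
        return stacked
    # Python loop: one PyTree per mini-batch, collected in a list.
    return [per_minibatch_stats(mb) for mb in minibatches]


batch = jnp.arange(12.0).reshape(6, 2)
as_list = collect_sketch(batch, num_minibatches=3, use_scan=False)
as_tree = collect_sketch(batch, num_minibatches=3, use_scan=True)
```

Here `as_list` is a list of three dicts with scalar leaves, and `as_tree` is a single dict whose leaves have shape `(3,)`. Code that consumes the result therefore needs to handle both layouts, or stack the list itself.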