xlstm_jax.distributed.single_gpu

Functions

accumulate_gradients_loop(state, batch, rng, ...)

Calculate gradients and metrics for a batch using gradient accumulation with a Python for loop.

accumulate_gradients_scan(state, batch, rng, ...)

Calculate gradients and metrics for a batch using gradient accumulation with jax.lax.scan.

accumulate_gradients(state, batch, rng, ...[, use_scan])

Calculate gradients and metrics for a batch using gradient accumulation, with either a for loop or jax.lax.scan.

Module Contents

xlstm_jax.distributed.single_gpu.accumulate_gradients_loop(state, batch, rng, num_minibatches, loss_fn)

Calculate gradients and metrics for a batch using gradient accumulation.

In this version, we use a Python for loop to iterate over the mini-batches.

Parameters:
  • state (xlstm_jax.common_types.TrainState) – Current training state.

  • batch (xlstm_jax.dataset.Batch) – Full training batch.

  • rng (xlstm_jax.common_types.PRNGKeyArray) – Random number generator to use.

  • num_minibatches (int) – Number of mini-batches to split the batch into. Equal to the number of gradient accumulation steps.

  • loss_fn (collections.abc.Callable) – Loss function to calculate gradients and metrics.

Returns:

Tuple with accumulated gradients, metrics, and collected mutable variables over the mini-batches.

Return type:

tuple[xlstm_jax.common_types.PyTree, xlstm_jax.common_types.Metrics, collections.abc.Sequence[xlstm_jax.common_types.PyTree]]
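The loop-based accumulation described above can be sketched as follows. This is a minimal illustration and not the library's implementation: toy_loss_fn and accumulate_loop_sketch are hypothetical names, the batch is assumed to be a plain dict of arrays, and metrics and mutable variables are left out so that only the gradient averaging is shown.

```python
import jax
import jax.numpy as jnp


def toy_loss_fn(params, minibatch, rng):
    """Hypothetical loss: mean squared error of a linear model."""
    del rng  # unused here; a real loss might use it, e.g. for dropout
    preds = minibatch["x"] @ params["w"]
    return jnp.mean((preds - minibatch["y"]) ** 2)


def accumulate_loop_sketch(params, batch, rng, num_minibatches):
    """Average gradients over mini-batches with a Python for loop."""
    batch_size = batch["x"].shape[0]
    assert batch_size % num_minibatches == 0, "batch must split evenly"
    minibatch_size = batch_size // num_minibatches
    # One PRNG key per accumulation step.
    rngs = jax.random.split(rng, num_minibatches)
    grad_fn = jax.value_and_grad(toy_loss_fn)

    grads, loss_sum = None, 0.0
    for i in range(num_minibatches):
        start = i * minibatch_size
        # Slice every array in the batch along its leading axis.
        minibatch = jax.tree_util.tree_map(
            lambda x: x[start:start + minibatch_size], batch
        )
        loss, step_grads = grad_fn(params, minibatch, rngs[i])
        if grads is None:
            grads = step_grads
        else:
            grads = jax.tree_util.tree_map(jnp.add, grads, step_grads)
        loss_sum = loss_sum + loss
    # Average so the result matches a full-batch mean loss.
    grads = jax.tree_util.tree_map(lambda g: g / num_minibatches, grads)
    return grads, loss_sum / num_minibatches
```

With equally sized mini-batches and a mean-reduced loss, the averaged gradients match the gradients of the full batch up to floating-point error, which is the property gradient accumulation relies on.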

xlstm_jax.distributed.single_gpu.accumulate_gradients_scan(state, batch, rng, num_minibatches, loss_fn)

Calculate gradients and metrics for a batch using gradient accumulation.

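In this version, we use jax.lax.scan to loop over the mini-batches. This typically reduces compilation time compared to a Python for loop: a for loop is unrolled during tracing, so its body is compiled once per mini-batch, whereas the scan body is traced and compiled only once.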

Parameters:
  • state (xlstm_jax.common_types.TrainState) – Current training state.

  • batch (xlstm_jax.dataset.Batch) – Full training batch.

  • rng (xlstm_jax.common_types.PRNGKeyArray) – Random number generator to use.

  • num_minibatches (int) – Number of mini-batches to split the batch into. Equal to the number of gradient accumulation steps.

  • loss_fn (collections.abc.Callable) – Loss function to calculate gradients and metrics.

Returns:

Tuple with accumulated gradients, metrics, and collected mutable variables over the mini-batches.

Return type:

tuple[xlstm_jax.common_types.PyTree, xlstm_jax.common_types.Metrics, xlstm_jax.common_types.PyTree]
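A scan-based variant can be sketched in the same spirit. Again this is an illustration under assumptions rather than the library's code: toy_loss_fn and accumulate_scan_sketch are hypothetical names, the batch is assumed to be a dict of arrays whose leading axis divides evenly by num_minibatches, and metrics and mutable variables are omitted.

```python
import jax
import jax.numpy as jnp


def toy_loss_fn(params, minibatch, rng):
    """Hypothetical loss: mean squared error of a linear model."""
    del rng  # unused here; a real loss might use it, e.g. for dropout
    preds = minibatch["x"] @ params["w"]
    return jnp.mean((preds - minibatch["y"]) ** 2)


def accumulate_scan_sketch(params, batch, rng, num_minibatches):
    """Average gradients over mini-batches with jax.lax.scan."""
    # Reshape (batch, ...) -> (num_minibatches, minibatch_size, ...)
    # so that scan can iterate over the new leading axis.
    minibatches = jax.tree_util.tree_map(
        lambda x: x.reshape(num_minibatches, -1, *x.shape[1:]), batch
    )
    rngs = jax.random.split(rng, num_minibatches)
    grad_fn = jax.value_and_grad(toy_loss_fn)

    def step(carry, xs):
        grads_acc, loss_acc = carry
        minibatch, step_rng = xs
        loss, grads = grad_fn(params, minibatch, step_rng)
        grads_acc = jax.tree_util.tree_map(jnp.add, grads_acc, grads)
        # No per-step outputs are collected in this sketch.
        return (grads_acc, loss_acc + loss), None

    # The carry holds the running gradient sum and the running loss sum.
    init = (jax.tree_util.tree_map(jnp.zeros_like, params), jnp.zeros(()))
    (grads_sum, loss_sum), _ = jax.lax.scan(step, init, (minibatches, rngs))
    grads = jax.tree_util.tree_map(lambda g: g / num_minibatches, grads_sum)
    return grads, loss_sum / num_minibatches
```

The step function is traced once regardless of num_minibatches, which is why the compiled program does not grow with the number of accumulation steps.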

xlstm_jax.distributed.single_gpu.accumulate_gradients(state, batch, rng, num_minibatches, loss_fn, use_scan=False)

Calculate gradients and metrics for a batch using gradient accumulation.

Depending on use_scan, the mini-batches are processed either with jax.lax.scan or with a Python for loop.

Parameters:
  • state (xlstm_jax.common_types.TrainState) – Current training state.

  • batch (xlstm_jax.dataset.Batch) – Full training batch.

  • rng (xlstm_jax.common_types.PRNGKeyArray) – Random number generator to use.

  • num_minibatches (int) – Number of mini-batches to split the batch into. Equal to the number of gradient accumulation steps.

  • loss_fn (collections.abc.Callable) – Loss function to calculate gradients and metrics.

  • use_scan (bool) – Whether to use jax.lax.scan for looping over the mini-batches. If False (the default), a Python for loop is used.

Returns:

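Tuple with accumulated gradients, metrics, and collected mutable variables over the mini-batches. As the return type indicates, the mutable variables are a sequence of PyTrees when the for loop is used and a single PyTree when jax.lax.scan is used.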

Return type:

tuple[xlstm_jax.common_types.PyTree, xlstm_jax.common_types.Metrics, collections.abc.Sequence[xlstm_jax.common_types.PyTree] | xlstm_jax.common_types.PyTree]
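The two shapes of the third return value can be illustrated with a small sketch. It assumes, without confirmation from the source, that the difference comes from how per-step outputs are collected: a Python loop naturally builds a list, while jax.lax.scan stacks per-step outputs along a new leading axis. step_fn and collect are hypothetical names used only for this illustration.

```python
import jax
import jax.numpy as jnp


def step_fn(x):
    """Hypothetical per-mini-batch step returning an auxiliary PyTree."""
    return {"mean": jnp.mean(x)}


def collect(minibatches, use_scan=False):
    """Collect per-step outputs with either scan or a Python loop."""
    if use_scan:
        # scan stacks the per-step outputs into one PyTree whose leaves
        # gain a leading axis of length num_minibatches.
        _, stacked = jax.lax.scan(
            lambda carry, x: (carry, step_fn(x)), None, minibatches
        )
        return stacked
    # The loop keeps one PyTree per mini-batch in a list.
    return [step_fn(x) for x in minibatches]


minibatches = jnp.arange(12.0).reshape(3, 4)
as_list = collect(minibatches, use_scan=False)  # list of 3 dicts
as_tree = collect(minibatches, use_scan=True)   # dict with leaf of shape (3,)
```

Code that consumes the third return value therefore has to handle both layouts, or convert the list to a stacked PyTree first.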