xlstm_jax.kernels.mlstm_chunkwise.triton_stablef._triton_bw

Triton backend for the backward pass of the mLSTM chunkwise formulation.

This file has been adapted from the original PyTorch Triton implementation to JAX. For the Triton kernels, see mlstm_kernels/mlstm_kernels/mlstm/chunkwise/triton_fwbw_stablef.py.

In this file, we use the following notation:

Dimensions:

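  • B: batch size

  • H: number of heads (also written NH below)

  • S: sequence length (K, V)

  • T: sequence length (Q)

  • K: hidden dimension (Q, K) (also written DHQK below)

  • V: hidden dimension (H, V) (also written DHHV or DHV below)

  • NT: number of chunks (also written NC below)

  • BT: chunk size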
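As a quick illustration of this notation, the sketch below builds placeholder arrays with the shapes used throughout this page. The concrete sizes are arbitrary, S is taken equal to T, and the relation NT = T // BT assumes that the sequence length is a multiple of the chunk size, which this page does not state explicitly.

```python
import jax.numpy as jnp

# Arbitrary example sizes (not taken from the library).
B, H, T, K, V = 2, 4, 256, 64, 128
BT = 64        # chunk size
NT = T // BT   # number of chunks, assuming T is divisible by BT

matQ = jnp.zeros((B, H, T, K))  # queries
matK = jnp.zeros((B, H, T, K))  # keys
matV = jnp.zeros((B, H, T, V))  # values
vecI = jnp.zeros((B, H, T))     # input gate pre-activations
vecF = jnp.zeros((B, H, T))     # forget gate pre-activations

# Per-timestep gate vectors viewed chunk by chunk: (B, H, T) -> (B, H, NT, BT).
vecI_chunked = vecI.reshape(B, H, NT, BT)
```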

Functions

_mlstm_chunkwise__recurrent_bw_dC(matQ, vecF, ...[, ...])

Computes only the deltaC gradients for the backward pass.

_mlstm_chunkwise__parallel_bw_dQKV(matQ, matK, matV, ...)

Computes the gradients for the query, key and value matrices.

_mlstm_chunkwise_bw(matQ, matK, matV, vecI, vecF[, ...])

Computes the backward pass of the mLSTM chunkwise formulation.

Module Contents

xlstm_jax.kernels.mlstm_chunkwise.triton_stablef._triton_bw._mlstm_chunkwise__recurrent_bw_dC(matQ, vecF, scaM_inter, vecM, vecN_out, matDeltaH, matDeltaC_last=None, qk_scale=None, chunk_size=64, num_chunks=1, store_initial_state=False)

Computes only the deltaC gradients for the backward pass.

The remaining gradients are computed by _mlstm_chunkwise__parallel_bw_dQKV. No gradient is computed for the denominator, because it cancels out in the group normalization applied in the forward pass.

Parameters:
  • matQ (jax.Array) – Tensor containing the query vectors. Shape (B, H, T, K).

  • vecF (jax.Array) – Tensor containing the log forget gate activations. Shape (B, H, NT, BT).

  • scaM_inter (jax.Array) – States of the M scalar. Shape (B, H, NT+1).

  • vecM (jax.Array) – M states. Shape (B, H, T).

  • matDeltaH (jax.Array) – Tensor containing the H gradients. Shape (B, H, T, V).

  • vecN_out (jax.Array) – States of the N vector. Shape (B, H, NT * DHQK).

  • matDeltaC_last (jax.Array | None) – Tensor containing the last C gradients. Shape (B, H, DHQK, DHHV). Defaults to None.

  • qk_scale (float | None) – Scale factor for the QK matrix. Defaults to None.

  • chunk_size (int) – Chunk size. Defaults to 64.

  • num_chunks (int) – Number of chunks. Defaults to 1.

  • store_initial_state (bool) – Whether to store the initial state gradient and log-scale (m state). Defaults to False.

Returns:

Tensor containing the C gradients and the C_first gradients. Shapes (B, H, NT * DHQK, DHHV), (B, H, DHQK, DHHV).

Return type:

tuple[jax.Array, jax.Array | None, jax.Array | None]
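The vecF argument above is documented as log forget gate activations in chunked layout (B, H, NT, BT), whereas _mlstm_chunkwise_bw receives forget gate pre-activations of shape (B, H, T). The sketch below shows one way to get from one to the other. It assumes that the gate activation is a sigmoid and that chunks are contiguous along the sequence axis; neither is confirmed on this page, and the helper name is hypothetical.

```python
import jax
import jax.numpy as jnp


def chunk_log_forget_gates(vecF_preact, chunk_size):
    """Hypothetical helper: (B, H, T) pre-activations -> (B, H, NT, BT) log gates."""
    B, H, T = vecF_preact.shape
    if T % chunk_size != 0:
        raise ValueError("sequence length must be a multiple of chunk_size")
    num_chunks = T // chunk_size
    # log(sigmoid(x)), evaluated in a numerically stable way (assumed activation).
    vecF_logsig = jax.nn.log_sigmoid(vecF_preact)
    return vecF_logsig.reshape(B, H, num_chunks, chunk_size)


vecF_preact = jnp.linspace(-3.0, 3.0, 2 * 4 * 128).reshape(2, 4, 128)
vecF_chunked = chunk_log_forget_gates(vecF_preact, chunk_size=64)
```

With these inputs, num_chunks would be passed as 2 alongside chunk_size=64.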

xlstm_jax.kernels.mlstm_chunkwise.triton_stablef._triton_bw._mlstm_chunkwise__parallel_bw_dQKV(matQ, matK, matV, vecI, vecF, vecM_combine, scaM_inter, matC_states, matDeltaH, vecN_out, matDeltaC_states, qk_scale=None, chunk_size=64, num_chunks=1)

Computes the gradients for the query, key and value matrices.

Parameters:
  • matQ (jax.Array) – Tensor containing the query vectors. Shape (B, H, T, K).

  • matK (jax.Array) – Tensor containing the key vectors. Shape (B, H, T, K).

  • matV (jax.Array) – Tensor containing the value vectors. Shape (B, H, T, V).

  • vecF (jax.Array) – Tensor containing the summed log forget gate activations. Shape (B, H, NT, BT).

  • vecI (jax.Array) – Tensor containing the input gate pre-activations. Shape (B, H, NT, BT).

  • vecM_combine (jax.Array) – Combined M states. Shape (B, H, T).

  • scaM_inter (jax.Array) – States of the M scalar. Shape (B, H, NT+1).

  • matC_states (jax.Array) – States of the C matrix. Shape (B, H, NT * DHQK, DHHV).

  • matDeltaH (jax.Array) – Tensor containing the H gradients. Shape (B, H, T, V).

  • vecN_out (jax.Array) – States of the N vector. Shape (B, H, T).

  • matDeltaC_states (jax.Array) – Tensor containing the C gradients. Shape (B, H, (NT+1) * DHQK, DHHV).

  • qk_scale (float | None) – Scale factor for the QK matrix. Defaults to None.

  • chunk_size (int, optional) – Chunk size. Defaults to 64.

  • num_chunks (int, optional) – Number of chunks. Defaults to 1.

Returns:

Gradients for the query, key and value matrices. Shapes (B, H, T, K), (B, H, T, K), (B, H, T, V).

Return type:

tuple[jax.Array, jax.Array, jax.Array]
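matC_states and matDeltaC_states stack the per-chunk matrices along the third axis. Assuming the stacking is chunk-major (entry i occupies rows i * DHQK to (i + 1) * DHQK, which the documented shapes suggest but this page does not state), a per-chunk view can be obtained with a plain reshape. The sizes below are arbitrary.

```python
import jax.numpy as jnp

B, H, NT, DHQK, DHHV = 2, 4, 3, 16, 32

# Placeholder for the stacked gradient states, shape (B, H, (NT+1) * DHQK, DHHV).
matDeltaC_states = jnp.arange(B * H * (NT + 1) * DHQK * DHHV, dtype=jnp.float32)
matDeltaC_states = matDeltaC_states.reshape(B, H, (NT + 1) * DHQK, DHHV)

# Per-chunk view: one (DHQK, DHHV) matrix per chunk boundary.
per_chunk = matDeltaC_states.reshape(B, H, NT + 1, DHQK, DHHV)

# Entry 1 of the per-chunk view is the same data as rows DHQK .. 2*DHQK
# of the stacked array (under the chunk-major assumption).
block_1 = matDeltaC_states[:, :, DHQK:2 * DHQK, :]
```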

xlstm_jax.kernels.mlstm_chunkwise.triton_stablef._triton_bw._mlstm_chunkwise_bw(matQ, matK, matV, vecI, vecF, matC_initial=None, vecN_initial=None, scaM_initial=None, qk_scale=None, matC_states=None, scaM_states=None, vecN_out=None, vecM_out=None, matDeltaH=None, matDeltaC_last=None, chunk_size=64)

Computes the backward pass of the mLSTM chunkwise formulation.

Parameters:
  • matQ (jax.Array) – Tensor containing the query vectors. Shape (B, H, T, K).

  • matK (jax.Array) – Tensor containing the key vectors. Shape (B, H, T, K).

  • matV (jax.Array) – Tensor containing the value vectors. Shape (B, H, S, DHV).

  • vecI (jax.Array) – Tensor containing the input gate pre-activations. Shape (B, H, T).

  • vecF (jax.Array) – Tensor containing the forget gate pre-activations. Shape (B, H, T).

  • matC_initial (jax.Array | None) – Tensor containing the initial C states. Shape (B, H, DHQK, DHV). Defaults to None.

  • vecN_initial (jax.Array | None) – Tensor containing the initial N states. Shape (B, H, DHQK). Defaults to None.

  • scaM_initial (jax.Array | None) – Tensor containing the initial M states. Shape (B, NH). Defaults to None.

  • qk_scale (float | None) – Scale factor for the QK matrix. Defaults to None.

  • matC_states (jax.Array | None) – Tensor containing all C states. Shape (B, H, NT * DHQK, DHV). Defaults to None.

  • scaM_states (jax.Array | None) – Tensor containing all M states. Shape (B, H, NC). Defaults to None.

  • vecN_out (jax.Array | None) – Tensor containing the N states for the output. Shape (B, H, T). Defaults to None.

  • vecM_out (jax.Array | None) – Tensor containing the M states for the output. Shape (B, H, T). Defaults to None.

  • matDeltaH (jax.Array | None) – Tensor containing the H gradients. Shape (B, H, S, DHV). Defaults to None.

  • matDeltaC_last (jax.Array | None) – Tensor containing the last C gradients. Shape (B, H, DHQK, DHV). Defaults to None.

  • chunk_size (int) – Chunk size. Defaults to 64.

Returns:

Gradients for the query, key and value matrices and for the input and forget gate pre-activations (vecI and vecF). Shapes (B, H, T, K), (B, H, T, K), (B, H, S, DHV), (B, H, T), (B, H, T). If initial states are provided, the function also returns the gradients for the initial C, N and M states; otherwise the corresponding entries are None.

Return type:

tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array | None, jax.Array | None, jax.Array | None]
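qk_scale defaults to None in all three functions, but this page does not say what None resolves to. A common convention in attention-style kernels is to scale by the inverse square root of the query/key head dimension. The sketch below shows that convention purely as an assumption, using a hypothetical helper name, and should not be read as the library's documented behaviour.

```python
import jax.numpy as jnp


def resolve_qk_scale(qk_scale, matQ):
    """Hypothetical helper: fall back to K ** -0.5 when no scale is given."""
    if qk_scale is None:
        # Assumed default: inverse square root of the Q/K head dimension.
        return matQ.shape[-1] ** -0.5
    return qk_scale


matQ = jnp.zeros((2, 4, 128, 64))  # (B, H, T, K) with K = 64
default_scale = resolve_qk_scale(None, matQ)
explicit_scale = resolve_qk_scale(0.5, matQ)
```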