xlstm_jax.kernels.mlstm_chunkwise.triton_stablef._triton_fw
Triton backend for the forward pass of the mLSTM chunkwise formulation.
This file has been adapted to JAX from the original PyTorch Triton implementation. For the Triton kernels, see mlstm_kernels/mlstm_kernels/mlstm/chunkwise/triton_fwbw_stablef.py.
In this file, we use the following notation:
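- Dimensions:
  - B: batch size
  - H: number of heads
  - S: sequence length of the keys and values (K, V)
  - T: sequence length of the queries (Q)
  - K: hidden dimension of the queries and keys (Q, K)
  - V: hidden dimension of the output matrix H and the values (H, V)
  - NT: number of chunks
  - BT: chunk size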
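Because the kernels operate on chunks, tensors indexed by the sequence length S are viewed as NT chunks of BT steps, so NT = S / BT. A minimal sketch of that view, assuming S is a multiple of BT (the helper below raises otherwise) and using a hypothetical helper name, `split_into_chunks`, that is not part of this module:

```python
import jax.numpy as jnp


def split_into_chunks(x: jnp.ndarray, chunk_size: int) -> jnp.ndarray:
    """Hypothetical helper: view a (B, H, S, ...) tensor as (B, H, NT, BT, ...)."""
    B, H, S = x.shape[:3]
    if S % chunk_size != 0:
        raise ValueError(f"sequence length {S} is not divisible by chunk size {chunk_size}")
    num_chunks = S // chunk_size  # NT
    # Trailing feature axes (e.g. K or V) are kept unchanged.
    return x.reshape(B, H, num_chunks, chunk_size, *x.shape[3:])


# Usage: a gate tensor (B, H, S) and a key tensor (B, H, S, K) with BT = 4.
vecI = jnp.arange(2 * 3 * 8, dtype=jnp.float32).reshape(2, 3, 8)
matK = jnp.ones((2, 3, 8, 5))
vecI_chunks = split_into_chunks(vecI, chunk_size=4)  # (B, H, NT, BT) = (2, 3, 2, 4)
matK_chunks = split_into_chunks(matK, chunk_size=4)  # (B, H, NT, BT, K) = (2, 3, 2, 4, 5)
```

The gate tensors documented below with shape (B, H, NT, BT) appear to correspond to this chunked view of a (B, H, S) tensor.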
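Functions
| Function | Description |
|---|---|
| assert_equal | |
| _mlstm_chunkwise__recurrent_fw_C | Execute the recurrent forward kernel for the C computation in the mLSTM chunkwise formulation. |
| _mlstm_chunkwise__parallel_fw_H | Execute the parallel forward kernel for the H computation in the mLSTM chunkwise formulation. |
| _mlstm_chunkwise_fw | Execute the forward pass of the mLSTM chunkwise formulation. |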
Module Contents
- xlstm_jax.kernels.mlstm_chunkwise.triton_stablef._triton_fw.assert_equal(a, b)
- xlstm_jax.kernels.mlstm_chunkwise.triton_stablef._triton_fw._mlstm_chunkwise__recurrent_fw_C(matK, matV, vecF, vecI, matC_initial=None, vecN_initial=None, scaMinter_initial=None, chunk_size=64, num_chunks=1, store_final_state=False)
Execute the recurrent forward kernel for the C computation in the mLSTM chunkwise formulation.
This function defines the grid and block sizes for the kernel launch and calls the kernel. See the fwbw backend implementation and the triton kernels for more information.
- Parameters:
matK (jax.Array) – Tensor containing the keys. Shape (B, H, S, K).
matV (jax.Array) – Tensor containing the values. Shape (B, H, S, V).
vecF (jax.Array) – Tensor containing the forget gate, split into chunks. Shape (B, H, NT, BT).
vecI (jax.Array) – Tensor containing the input gate. Shape (B, H, NT, BT).
matC_states – Buffer for the states of the C matrix. Shape (B, H, NT * K, V). Defaults to None.
vecN_states – Buffer for the states of the N vector. Shape (B, H, NT * K). Defaults to None.
scaMinter_states – Buffer for the states of the M scalar. Shape (B, H, (NT + 1)). Defaults to None.
matC_initial (jax.Array | None) – Initial state of the C matrix. Shape (B, H, K, V). Defaults to None.
vecN_initial (jax.Array | None) – Initial state of the N vector. Shape (B, H, K). Defaults to None.
scaMinter_initial (jax.Array | None) – Initial state of the M scalar. Shape (B, H). Defaults to None.
qk_scale – Scaling factor for the QK matrix. Defaults to None and will be inferred.
chunk_size (int) – Chunk size for the kernel. Defaults to 64.
num_chunks (int) – Number of chunks. Defaults to 1.
store_final_state (bool) – Whether to also return the final states. Defaults to False.
- Returns:
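Tuple containing the states of the C matrix, the N vector and the M scalar. If store_final_state is True, the tuple additionally contains the final states of C, N and M (six arrays in total).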
- Return type:
tuple[jax.Array, jax.Array, jax.Array] | tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]
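To make the meaning of the C, N and M states concrete, here is a small pure-JAX sketch of the chunkwise recurrence we assume this kernel implements. It is not the Triton kernel and not part of the module's API: the function name is hypothetical, it assumes vecF holds forget-gate pre-activations (so log f = log_sigmoid(vecF)) and vecI holds input-gate pre-activations used in log space (an exponential input gate), it starts from zero initial states, and it keeps an explicit chunk axis instead of the flattened NT * K layout documented above.

```python
import jax
import jax.numpy as jnp


def recurrent_fw_C_reference(matK, matV, vecF, vecI, chunk_size):
    """Hypothetical pure-JAX reference for the inter-chunk C/N/M recurrence.

    Assumes log f = log_sigmoid(vecF), vecI is the input gate in log space,
    and zero initial states. Returns the stabilized states at the start of
    each chunk plus the final state, with an explicit chunk axis:
    C (B, H, NT + 1, K, V), N (B, H, NT + 1, K), M (B, H, NT + 1).
    """
    B, H, S, K = matK.shape
    V = matV.shape[-1]
    NT, BT = S // chunk_size, chunk_size
    # Put the chunk axis first so that lax.scan iterates over chunks.
    k = matK.reshape(B, H, NT, BT, K).transpose(2, 0, 1, 3, 4)
    v = matV.reshape(B, H, NT, BT, V).transpose(2, 0, 1, 3, 4)
    log_f = jax.nn.log_sigmoid(vecF).reshape(B, H, NT, BT).transpose(2, 0, 1, 3)
    log_i = vecI.reshape(B, H, NT, BT).transpose(2, 0, 1, 3)

    def step(carry, chunk):
        C, N, M = carry
        k_c, v_c, log_f_c, log_i_c = chunk
        b = jnp.cumsum(log_f_c, axis=-1)          # (B, H, BT) cumulative log decay
        b_last = b[..., -1]                       # log decay over the whole chunk
        log_w = b_last[..., None] - b + log_i_c   # log weight of each step at chunk end
        M_new = jnp.maximum(b_last + M, jnp.max(log_w, axis=-1))
        decay = jnp.exp(b_last + M - M_new)       # rescales the carried state
        w = jnp.exp(log_w - M_new[..., None])
        C_new = decay[..., None, None] * C + jnp.einsum("bht,bhtk,bhtv->bhkv", w, k_c, v_c)
        N_new = decay[..., None] * N + jnp.einsum("bht,bhtk->bhk", w, k_c)
        # Emit the state at the start of this chunk; carry the updated state.
        return (C_new, N_new, M_new), (C, N, M)

    init = (
        jnp.zeros((B, H, K, V), matK.dtype),
        jnp.zeros((B, H, K), matK.dtype),
        jnp.zeros((B, H), matK.dtype),
    )
    final, starts = jax.lax.scan(step, init, (k, v, log_f, log_i))
    # Append the final state to the chunk-start states and move the chunk axis to dim 2.
    C_all = jnp.concatenate([starts[0], final[0][None]]).transpose(1, 2, 0, 3, 4)
    N_all = jnp.concatenate([starts[1], final[1][None]]).transpose(1, 2, 0, 3)
    M_all = jnp.concatenate([starts[2], final[2][None]]).transpose(1, 2, 0)
    return C_all, N_all, M_all
```

In this sketch each chunk rescales the carried state by exp(b_last + M - M_new) so the stored values stay bounded; the unstabilized state is recovered as exp(M) * C.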
- xlstm_jax.kernels.mlstm_chunkwise.triton_stablef._triton_fw._mlstm_chunkwise__parallel_fw_H(matQ, matK, matV, matC_states, vecN_states, scaMinter_states, vecI, vecF, qk_scale=None, chunk_size=64, num_chunks=1, eps=1e-06, stabilize_correctly=True, norm_val=1.0)
Execute the parallel forward kernel for the H computation in the mLSTM chunkwise formulation.
This function defines the grid and block sizes for the kernel launch and calls the kernel. See the fwbw backend implementation and the triton kernels for more information.
- Parameters:
matQ (jax.Array) – Tensor containing the queries. Shape (B, H, S, K).
matK (jax.Array) – Tensor containing the keys. Shape (B, H, S, K).
matV (jax.Array) – Tensor containing the values. Shape (B, H, S, V).
matC_states (jax.Array) – States of the C matrix. Shape (B, H, NT * K, V). This and the following state arguments must contain all states up to the last chunk, i.e. the recurrent states sliced with :-1.
vecN_states (jax.Array) – States of the N vector. Shape (B, H, NT * K).
scaMinter_states (jax.Array) – States of the M scalar. Shape (B, H, NT + 1).
vecI (jax.Array) – Tensor containing the input gate. Shape (B, H, NT, BT).
vecF (jax.Array) – Tensor containing the forget gate, split into chunks. Shape (B, H, NT, BT).
qk_scale (float | None) – Scaling factor for the QK matrix. Defaults to None and will be inferred.
chunk_size (int) – Chunk size for the kernel. Defaults to 64.
num_chunks (int) – Number of chunks. Defaults to 1.
eps (float) – Small value to avoid division by zero. Defaults to 1e-6.
stabilize_correctly (bool) – Whether to stabilize with max(norm_val*e^-m, |nk|) instead of max(norm_val, |nk|). Defaults to True.
norm_val (float) – Norm scale in the max formula of stabilize_correctly. Defaults to 1.0.
- Returns:
Tuple containing the output matrix H (shape (B, H, S, V)) and the N vector (shape (B, H, S)).
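The H computation combines the state carried in from previous chunks with a causal, decayed query-key term inside the current chunk. A minimal sketch for one chunk and one (batch, head) pair follows. It is a readability aid, not the Triton kernel: the function name is hypothetical, q is assumed to be already multiplied by qk_scale, log_f is assumed to be the log forget gate (log_sigmoid of the forget-gate pre-activation), log_i the input gate in log space, and the denominator corresponds to stabilize_correctly=True with norm_val=1.0.

```python
import jax
import jax.numpy as jnp


def parallel_fw_H_chunk_reference(q, k, v, C_prev, N_prev, M_prev, log_f, log_i, eps=1e-6):
    """Hypothetical sketch of the H output for one chunk of one (batch, head) pair.

    q, k: (BT, K), q assumed pre-scaled by qk_scale; v: (BT, V).
    C_prev (K, V), N_prev (K,), M_prev (scalar): stabilized state at the chunk start.
    log_f, log_i: (BT,) log forget gates and input-gate pre-activations.
    Uses the denominator max(|q . n|, exp(-m)) + eps, i.e. norm_val = 1.
    """
    BT = q.shape[0]
    b = jnp.cumsum(log_f)                                   # cumulative log decay in the chunk
    causal = jnp.tril(jnp.ones((BT, BT), dtype=bool))
    log_D = jnp.where(causal, b[:, None] - b[None, :] + log_i[None, :], -jnp.inf)
    M = jnp.maximum(b + M_prev, jnp.max(log_D, axis=-1))    # per-step stabilizer
    inter = jnp.exp(b + M_prev - M)                         # weight of the carried state
    S_qk = (q @ k.T) * jnp.exp(log_D - M[:, None])          # decayed intra-chunk scores
    numerator = inter[:, None] * (q @ C_prev) + S_qk @ v    # (BT, V)
    qn = inter * (q @ N_prev) + S_qk.sum(axis=-1)           # (BT,)
    denominator = jnp.maximum(jnp.abs(qn), jnp.exp(-M)) + eps
    return numerator / denominator[:, None], qn, M


# Usage with random inputs (BT = 6, K = 3, V = 2) and a non-zero carried state.
keys = jax.random.split(jax.random.PRNGKey(1), 7)
q, k = jax.random.normal(keys[0], (6, 3)), jax.random.normal(keys[1], (6, 3))
v = jax.random.normal(keys[2], (6, 2))
C0, N0 = jax.random.normal(keys[3], (3, 2)), jax.random.normal(keys[4], (3,))
log_f = jax.nn.log_sigmoid(jax.random.normal(keys[5], (6,)) + 2.0)
log_i = jax.random.normal(keys[6], (6,))
h, qn, M = parallel_fw_H_chunk_reference(q, k, v, C0, N0, 0.5, log_f, log_i)
```

In this sketch the chunkwise result equals the step-by-step stabilized recurrence up to floating-point error; only the order of the summation changes.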
- xlstm_jax.kernels.mlstm_chunkwise.triton_stablef._triton_fw._mlstm_chunkwise_fw(matQ, matK, matV, vecI, vecF, matC_initial=None, vecN_initial=None, scaM_initial=None, qk_scale=None, return_last_states=False, return_all_states=False, chunk_size=64, stabilize_correctly=True, norm_val=1.0, eps=1e-06)
Execute the forward pass of the mLSTM chunkwise formulation.
- Parameters:
matQ (jax.Array) – Tensor containing the queries. Shape (B, H, S, K).
matK (jax.Array) – Tensor containing the keys. Shape (B, H, S, K).
matV (jax.Array) – Tensor containing the values. Shape (B, H, S, V).
vecI (jax.Array) – Tensor containing the input gate. Shape (B, H, S).
vecF (jax.Array) – Tensor containing the forget gate. Shape (B, H, S).
matC_initial (jax.Array | None) – Initial state of the C matrix. Shape (B, H, K, V). Defaults to None.
vecN_initial (jax.Array | None) – Initial state of the N vector. Shape (B, H, K). Defaults to None.
scaM_initial (jax.Array | None) – Initial state of the M scalar. Shape (B, H). Defaults to None.
qk_scale (float | None) – Scaling factor for the QK matrix. Defaults to None and will be inferred.
return_last_states (bool) – Whether to return the last states. Defaults to False.
return_all_states (bool) – Whether to return all states. Defaults to False.
chunk_size (int) – Chunk size for the kernel. Defaults to 64.
stabilize_correctly (bool) – Whether to stabilize with max(norm_val*e^-m, |nk|) instead of max(norm_val, |nk|). Defaults to True.
norm_val (float) – Norm scale in the max formula above. Defaults to 1.0.
eps (float) – Small value to avoid division by zero. Defaults to 1e-6.
- Returns:
Tuple containing the output matrix H (shape (B, H, S, V)), the N vector (shape (B, H, S)) and the M scalar (shape (B, H)), followed by the last states (matC_states, vecN_states, scaMinter_states) if return_last_states is True and all states (matC_states, vecN_states, scaMinter_states) if return_all_states is True. Each of these two entries is None when the corresponding flag is False.
- Return type:
tuple[jax.Array, jax.Array, jax.Array, None | tuple[jax.Array, jax.Array, jax.Array], None | tuple[jax.Array, jax.Array, jax.Array]]
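The two settings of stabilize_correctly differ only in the lower bound of the output denominator. A small sketch, interpreting |nk| in the formulas above as |q · ñ| (the query projected onto the stabilized normalizer state) and using a hypothetical helper name, `h_denominator`, that is not part of the module; adding eps after the max is also an assumption. Because the stored states are scaled by exp(-m), multiplying the floor norm_val by exp(-m) reproduces the unstabilized rule max(norm_val, |q · n|), whereas the plain floor does not.

```python
import jax.numpy as jnp


def h_denominator(qn, m, norm_val=1.0, stabilize_correctly=True, eps=1e-6):
    """Hypothetical denominator for h = (q C~) / denominator.

    qn is q . n~ computed with the stabilized normalizer state n~, and m is
    the stabilizer. Adding eps after the max is an assumption of this sketch.
    """
    lower = norm_val * jnp.exp(-m) if stabilize_correctly else norm_val
    return jnp.maximum(lower, jnp.abs(qn)) + eps


# Unstabilized quantities for one timestep (scalar stand-ins for illustration).
m = 2.0
qC_raw, qn_raw = 3.0, 0.5
h_raw = qC_raw / max(1.0, abs(qn_raw))  # rule without stabilization: 3.0
# The stored (stabilized) states are scaled by exp(-m).
qC, qn = jnp.exp(-m) * qC_raw, jnp.exp(-m) * qn_raw
h_correct = qC / h_denominator(qn, m, eps=0.0)  # matches h_raw (approximately 3.0)
h_naive = qC / h_denominator(qn, m, stabilize_correctly=False, eps=0.0)  # approximately 0.406
```

With stabilize_correctly=False the result shrinks by a factor exp(-m) in this example, because the floor of 1.0 is compared against a value that was already scaled down by exp(-m).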