xlstm_jax.kernels.mlstm_chunkwise.triton_stablef.triton_fwbw
Triton backend for the forward and backward pass of the mLSTM chunkwise formulation.
This file has been adapted from the original PyTorch Triton implementation to JAX. For the Triton kernels, see mlstm_kernels/mlstm_kernels/mlstm/chunkwise/triton_fwbw_stablef.py.
In this file, we use the following notation:
- Dimensions:
B: batch size
H: number of heads
S: sequence length (K, V)
T: sequence length (Q)
K: hidden dimension (Q, K)
V: hidden dimension (H, V)
NT: number of chunks
BT: chunk size
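As a rough illustration of how these dimensions relate, the sketch below computes the number of chunks and the tensor shapes listed in the function signatures of this module (which name the dimensions NH, DHQK and DHV). It assumes that the sequence length is an exact multiple of the chunk size, so that NT = S / BT; that requirement and the helper name `expected_shapes` are assumptions made for illustration and are not part of the module.

```python
def expected_shapes(B: int, NH: int, S: int, DHQK: int, DHV: int, BT: int = 64) -> dict:
    """Hypothetical helper: shapes implied by the notation, assuming S % BT == 0."""
    if S % BT != 0:
        # Assumption: the chunkwise formulation splits S into equally sized chunks.
        raise ValueError(f"S={S} is not a multiple of the chunk size BT={BT}")
    NT = S // BT  # number of chunks
    return {
        "NT": NT,
        "matQ": (B, NH, S, DHQK),
        "matK": (B, NH, S, DHQK),
        "matV": (B, NH, S, DHV),
        "vecI": (B, NH, S),
        "vecF": (B, NH, S),
        "matC_initial": (B, NH, DHQK, DHV),
        "vecN_initial": (B, NH, DHQK),
        "scaM_initial": (B, NH),
    }


shapes = expected_shapes(B=2, NH=4, S=256, DHQK=64, DHV=128, BT=64)
print(shapes["NT"])  # 4
```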
Functions
_mlstm_chunkwise_fwbw_generator | Generate a forward and backward pass function for the mLSTM kernels with chunkwise formulation.
_get_chunkwise_fwbw_kernel | Get the forward and backward pass function for the mLSTM kernels with chunkwise formulation.
mlstm_chunkwise_triton_stablef | Apply the mLSTM chunkwise formulation with Triton kernels.
Module Contents
- xlstm_jax.kernels.mlstm_chunkwise.triton_stablef.triton_fwbw._mlstm_chunkwise_fwbw_generator(autocast_kernel_dtype=jnp.bfloat16, return_last_states=False, recompute_states_in_bw=True, chunk_size=64, eps=1e-06, stabilize_correctly=True, norm_val=1.0)
Generate a forward and backward pass function for the mLSTM kernels with chunkwise formulation.
- Parameters:
autocast_kernel_dtype (jax.numpy.dtype) – The dtype to use for the kernel computation. All input arguments up to vecF are cast to this dtype. vecF is automatically cast to float32 in the kernels.
return_last_states (bool) – Whether to return the last states of the mLSTM.
recompute_states_in_bw (bool) – Whether to recompute the mLSTM states in the backward pass.
chunk_size (int) – The chunk size to use for the mLSTM computation.
eps (float) – The epsilon value to use for numerical stability.
stabilize_correctly (bool) – Whether to stabilize with max(norm_val*e^-m, |nk|) instead of max(norm_val, |nk|)
norm_val (float) – Norm scale in the max formula above
- Returns:
A function that computes the forward pass of the mLSTM chunkwise formulation, with custom gradients for the backward pass. The function signature is:
- forward(
matQ: jax.Array, # (B, NH, S, DHQK)
matK: jax.Array, # (B, NH, S, DHQK)
matV: jax.Array, # (B, NH, S, DHV)
vecI: jax.Array, # (B, NH, S)
vecF: jax.Array, # (B, NH, S)
matC_initial: jax.Array | None = None, # (B, NH, DHQK, DHV)
vecN_initial: jax.Array | None = None, # (B, NH, DHQK)
scaM_initial: jax.Array | None = None, # (B, NH)
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
The function returns the output of the mLSTM computation and the last internal states C, N and M.
- Return type:
collections.abc.Callable[[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array], tuple[jax.Array, jax.Array, jax.Array, jax.Array]]
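The two lower bounds named in the description of `stabilize_correctly` can be compared with a small scalar sketch. It only evaluates the two `max` expressions quoted above for a single value of `nk` and `m`; how the kernels combine this bound with `eps` or apply it per time step is not shown here, and the function name `denominator_lower_bound` is hypothetical.

```python
import math


def denominator_lower_bound(nk: float, m: float, norm_val: float = 1.0,
                            stabilize_correctly: bool = True) -> float:
    """Scalar sketch of max(norm_val * e^-m, |nk|) versus max(norm_val, |nk|)."""
    # With stabilize_correctly, the constant is rescaled by exp(-m);
    # otherwise norm_val is used as-is.
    bound = norm_val * math.exp(-m) if stabilize_correctly else norm_val
    return max(bound, abs(nk))


nk, m = 0.05, 2.0
scaled = denominator_lower_bound(nk, m, stabilize_correctly=True)
unscaled = denominator_lower_bound(nk, m, stabilize_correctly=False)
print(scaled)    # equals exp(-2), roughly 0.135
print(unscaled)  # 1.0
```

For a positive stabilizer value `m`, the rescaled bound is smaller than `norm_val`, so the two settings can produce different results whenever `|nk|` lies below `norm_val`.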
- xlstm_jax.kernels.mlstm_chunkwise.triton_stablef.triton_fwbw._get_chunkwise_fwbw_kernel(autocast_kernel_dtype, **kwargs)
Get the forward and backward pass function for the mLSTM kernels with chunkwise formulation.
- Parameters:
autocast_kernel_dtype (jax.numpy.dtype) – The dtype to use for the kernel computation. All input arguments up to vecF and vecI are cast to this dtype. vecF is automatically cast to float32 in the kernels.
**kwargs – Additional keyword arguments to pass to the kernel function.
- Returns:
A function that computes the forward pass of the mLSTM chunkwise formulation, with custom gradients for the backward pass. See _mlstm_chunkwise_fwbw_generator for the function signature.
- Return type:
- xlstm_jax.kernels.mlstm_chunkwise.triton_stablef.triton_fwbw.mlstm_chunkwise_triton_stablef(q, k, v, i, f, c_initial=None, n_initial=None, m_initial=None, return_last_states=False, eps=1e-06, chunk_size=64, autocast_kernel_dtype=jnp.float32, stabilize_correctly=True, norm_val=1.0)
Apply the mLSTM chunkwise formulation with Triton kernels.
Supports autograd application.
- Parameters:
q (jax.Array) – The query tensor of shape (B, NH, S, DHQK).
k (jax.Array) – The key tensor of shape (B, NH, S, DHQK).
v (jax.Array) – The value tensor of shape (B, NH, S, DHV).
i (jax.Array) – The input gate preactivation tensor of shape (B, NH, S).
f (jax.Array) – The forget gate preactivation tensor of shape (B, NH, S).
c_initial (jax.Array | None) – The initial C state tensor of shape (B, NH, DHQK, DHV).
n_initial (jax.Array | None) – The initial N state tensor of shape (B, NH, DHQK).
m_initial (jax.Array | None) – The initial M state tensor of shape (B, NH).
return_last_states (bool) – Whether to return the last states of the mLSTM.
eps (float) – The epsilon value to use for numerical stability.
chunk_size (int) – The chunk size to use for the mLSTM computation.
autocast_kernel_dtype (jax.numpy.dtype) – The dtype to use for the kernel computation. All input arguments up to vecF are cast to this dtype. vecF is automatically cast to float32 in the kernels.
stabilize_correctly (bool) – Whether to stabilize with max(norm_val*e^-m, |nk|) instead of max(norm_val, |nk|)
norm_val (float) – Norm scale in the max formula above
- Returns:
The output of the mLSTM computation. If return_last_states is True, the last states of the mLSTM are also returned.
- Return type:
jax.Array | tuple[jax.Array, tuple[jax.Array, jax.Array, jax.Array]]
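Because the return type depends on `return_last_states`, calling code may want to normalize the result. The sketch below shows one way to do this, together with an example call using the argument shapes documented above. The helper `split_output` and the wrapper `run_example` are hypothetical and not part of the module; `run_example` is only defined, not called, since running it presumably requires JAX, xlstm_jax and a GPU with Triton support. The input values are random placeholders.

```python
from typing import Any, Optional, Tuple


def split_output(result: Any, return_last_states: bool) -> Tuple[Any, Optional[tuple]]:
    """Hypothetical helper: normalize the documented return value to (h, states)."""
    if return_last_states:
        # Documented form: (output, (C_last, N_last, M_last)).
        h, (c_last, n_last, m_last) = result
        return h, (c_last, n_last, m_last)
    # Documented form: the output array only.
    return result, None


def run_example():
    """Sketch of a call; imports are local so that defining it has no dependencies."""
    import jax
    import jax.numpy as jnp
    from xlstm_jax.kernels.mlstm_chunkwise.triton_stablef.triton_fwbw import (
        mlstm_chunkwise_triton_stablef,
    )

    B, NH, S, DHQK, DHV = 1, 2, 128, 64, 64  # illustrative sizes, S = 2 chunks of 64
    keys = jax.random.split(jax.random.PRNGKey(0), 5)
    q = jax.random.normal(keys[0], (B, NH, S, DHQK), dtype=jnp.float32)
    k = jax.random.normal(keys[1], (B, NH, S, DHQK), dtype=jnp.float32)
    v = jax.random.normal(keys[2], (B, NH, S, DHV), dtype=jnp.float32)
    i = jax.random.normal(keys[3], (B, NH, S), dtype=jnp.float32)  # input gate preactivations
    f = jax.random.normal(keys[4], (B, NH, S), dtype=jnp.float32)  # forget gate preactivations

    result = mlstm_chunkwise_triton_stablef(
        q, k, v, i, f, return_last_states=True, chunk_size=64
    )
    return split_output(result, return_last_states=True)
```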