xlstm_jax.kernels.mlstm_chunkwise.triton_stablef.triton_fwbw

Triton backend for the forward and backward passes of the mLSTM chunkwise formulation.

This file adapts the original PyTorch Triton implementation to JAX. For the Triton kernels, see mlstm_kernels/mlstm_kernels/mlstm/chunkwise/triton_fwbw_stablef.py.

In this file, we use the following notation:

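Dimensions:

  • B: batch size
  • H: number of heads
  • S: sequence length (K, V)
  • T: sequence length (Q)
  • K: hidden dimension (Q, K)
  • V: hidden dimension (H, V)
  • NT: number of chunks
  • BT: chunk size

In the signatures below, the number of heads, the query/key hidden dimension and the value hidden dimension appear as NH, DHQK and DHV.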
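To illustrate how S, NT and BT relate, the following sketch splits a (B, H, S, K) array into NT chunks of BT steps each. The helper split_into_chunks is hypothetical and not part of this module; it assumes S is a multiple of BT, and whether the kernels pad or require divisibility is not stated here.

```python
import jax.numpy as jnp


def split_into_chunks(x: jnp.ndarray, chunk_size: int) -> jnp.ndarray:
    """Hypothetical helper: reshape (B, H, S, D) into (B, H, NT, BT, D).

    Assumes S is a multiple of chunk_size (BT), so NT = S // BT.
    """
    B, H, S, D = x.shape
    if S % chunk_size != 0:
        raise ValueError(f"S={S} is not a multiple of chunk_size={chunk_size}")
    num_chunks = S // chunk_size  # NT
    # Time step s lands in chunk s // BT at position s % BT.
    return x.reshape(B, H, num_chunks, chunk_size, D)


q = jnp.zeros((2, 4, 256, 64))  # B=2, H=4, S=256, K=64
q_chunks = split_into_chunks(q, chunk_size=64)
print(q_chunks.shape)  # (2, 4, 4, 64, 64)
```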

Functions

_mlstm_chunkwise_fwbw_generator([...])

Generate a forward and backward pass function for the mLSTM kernels with chunkwise formulation.

_get_chunkwise_fwbw_kernel(autocast_kernel_dtype, **kwargs)

Get the forward and backward pass function for the mLSTM kernels with chunkwise formulation.

mlstm_chunkwise_triton_stablef(q, k, v, i, f[, ...])

Apply the mLSTM chunkwise formulation with Triton kernels.

Module Contents

xlstm_jax.kernels.mlstm_chunkwise.triton_stablef.triton_fwbw._mlstm_chunkwise_fwbw_generator(autocast_kernel_dtype=jnp.bfloat16, return_last_states=False, recompute_states_in_bw=True, chunk_size=64, eps=1e-06, stabilize_correctly=True, norm_val=1.0)

Generate a forward and backward pass function for the mLSTM kernels with chunkwise formulation.

Parameters:
  • autocast_kernel_dtype (jax.numpy.dtype) – The dtype to use for the kernel computation. All input arguments up to vecF are cast to this dtype. vecF is automatically cast to float32 in the kernels.

  • return_last_states (bool) – Whether to return the last states of the mLSTM.

  • recompute_states_in_bw (bool) – Whether to recompute the mLSTM states in the backward pass.

  • chunk_size (int) – The chunk size to use for the mLSTM computation.

  • eps (float) – The epsilon value to use for numerical stability.

  • stabilize_correctly (bool) – Whether to stabilize with max(norm_val*e^-m, |nk|) instead of max(norm_val, |nk|).

  • norm_val (float) – The norm scale used in the max formula above.
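The two stabilization variants can be written out directly. This is a sketch of the max formula only, under the assumption that nk is the normalizer product the formula refers to and m is the max (stabilizer) state; the helper stabilized_denominator is hypothetical and not this module's code.

```python
import jax.numpy as jnp


def stabilized_denominator(nk, m, norm_val=1.0, stabilize_correctly=True):
    """Hypothetical helper evaluating the max formula from the parameter docs.

    nk: the normalizer product written |nk| in the docs (assumed elementwise).
    m:  the max (stabilizer) state, same shape as nk.
    """
    if stabilize_correctly:
        # max(norm_val * e^{-m}, |nk|): the lower bound follows the stabilizer m.
        lower_bound = norm_val * jnp.exp(-m)
    else:
        # max(norm_val, |nk|): constant lower bound that ignores m.
        lower_bound = jnp.full_like(nk, norm_val)
    return jnp.maximum(lower_bound, jnp.abs(nk))
```

A plausible reading of the design choice: if the stored states are the unstabilized ones scaled by e^{-m}, then a constant lower bound norm_val on the unstabilized scale becomes norm_val*e^-m on the stabilized scale, which the plain max(norm_val, |nk|) variant does not account for.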

Returns:

A function that computes the forward pass of the mLSTM chunkwise formulation, with custom gradients for the backward pass. The input signature of the function is:

forward(
    matQ: jax.Array,  # (B, NH, S, DHQK)
    matK: jax.Array,  # (B, NH, S, DHQK)
    matV: jax.Array,  # (B, NH, S, DHV)
    vecI: jax.Array,  # (B, NH, S)
    vecF: jax.Array,  # (B, NH, S)
    matC_initial: jax.Array | None = None,  # (B, NH, DHQK, DHV)
    vecN_initial: jax.Array | None = None,  # (B, NH, DHQK)
    scaM_initial: jax.Array | None = None,  # (B, NH)
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]

The function returns the output of the mLSTM computation and the last internal states C, N and M.

Return type:

collections.abc.Callable[[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array], tuple[jax.Array, jax.Array, jax.Array, jax.Array]]
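The generated function carries custom gradients. As a toy sketch of one common way to build such a function in JAX, the example below uses jax.custom_vjp and a flag that, like recompute_states_in_bw, chooses between saving an intermediate for the backward pass and recomputing it there. make_fwbw and its math are hypothetical; this does not show how the module actually wires its Triton kernels.

```python
import jax
import jax.numpy as jnp


def make_fwbw(recompute_in_bw: bool):
    """Hypothetical factory for f(x, w) = sum(exp(x) * w) with a custom backward."""

    @jax.custom_vjp
    def forward(x, w):
        return jnp.sum(jnp.exp(x) * w)

    def forward_fwd(x, w):
        state = jnp.exp(x)  # intermediate that the backward pass needs
        out = jnp.sum(state * w)
        # Either keep the intermediate, or keep only what is needed to rebuild it.
        residuals = (x, w, None) if recompute_in_bw else (x, w, state)
        return out, residuals

    def forward_bwd(residuals, g):
        x, w, state = residuals
        if state is None:
            state = jnp.exp(x)  # recompute instead of reading a saved copy
        # Cotangents for (x, w): d/dx = exp(x) * w, d/dw = exp(x).
        return g * state * w, g * state

    forward.defvjp(forward_fwd, forward_bwd)
    return forward
```

Saving the intermediate costs memory proportional to its size, while recomputing it costs extra work in the backward pass. For the mLSTM, the corresponding intermediates would presumably be the per-chunk states, but that mapping is an assumption.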

xlstm_jax.kernels.mlstm_chunkwise.triton_stablef.triton_fwbw._get_chunkwise_fwbw_kernel(autocast_kernel_dtype, **kwargs)

Get the forward and backward pass function for the mLSTM kernels with chunkwise formulation.

Parameters:
  • autocast_kernel_dtype (jax.numpy.dtype) – The dtype to use for the kernel computation. All input arguments up to vecF, including vecI, are cast to this dtype. vecF is automatically cast to float32 in the kernels.

  • **kwargs – Additional keyword arguments to pass to the kernel function.
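One reading of the casting rule is sketched below: q, k, v and i take the kernel dtype, and the forget gate ends up in float32. The helper autocast_inputs is hypothetical, and whether vecF first passes through the kernel dtype before the kernels upcast it is not specified here.

```python
import jax.numpy as jnp


def autocast_inputs(q, k, v, i, f, kernel_dtype=jnp.bfloat16):
    """Hypothetical helper mirroring one reading of the documented casting rule."""
    # Inputs before vecF (q, k, v and the input gate i) use the kernel dtype.
    q, k, v, i = (x.astype(kernel_dtype) for x in (q, k, v, i))
    # The forget gate preactivation is handled in float32.
    f = f.astype(jnp.float32)
    return q, k, v, i, f
```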

Returns:

A function that computes the forward pass of the mLSTM chunkwise formulation, with custom gradients for the backward pass. See _mlstm_chunkwise_fwbw_generator for the function signature.

Return type:

collections.abc.Callable

xlstm_jax.kernels.mlstm_chunkwise.triton_stablef.triton_fwbw.mlstm_chunkwise_triton_stablef(q, k, v, i, f, c_initial=None, n_initial=None, m_initial=None, return_last_states=False, eps=1e-06, chunk_size=64, autocast_kernel_dtype=jnp.float32, stabilize_correctly=True, norm_val=1.0)

Apply the mLSTM chunkwise formulation with Triton kernels.

Supports automatic differentiation (for example with jax.grad) through custom gradients.

Parameters:
  • q (jax.Array) – The query tensor of shape (B, NH, S, DHQK).

  • k (jax.Array) – The key tensor of shape (B, NH, S, DHQK).

  • v (jax.Array) – The value tensor of shape (B, NH, S, DHV).

  • i (jax.Array) – The input gate preactivation tensor of shape (B, NH, S).

  • f (jax.Array) – The forget gate preactivation tensor of shape (B, NH, S).

  • c_initial (jax.Array | None) – The initial matrix memory state C of shape (B, NH, DHQK, DHV).

  • n_initial (jax.Array | None) – The initial normalizer state N of shape (B, NH, DHQK).

  • m_initial (jax.Array | None) – The initial max (stabilizer) state M of shape (B, NH).

  • return_last_states (bool) – Whether to return the last states of the mLSTM.

  • eps (float) – The epsilon value to use for numerical stability.

  • chunk_size (int) – The chunk size to use for the mLSTM computation.

  • autocast_kernel_dtype (jax.numpy.dtype) – The dtype to use for the kernel computation. All input arguments up to vecF are cast to this dtype. vecF is automatically cast to float32 in the kernels.

  • stabilize_correctly (bool) – Whether to stabilize with max(norm_val*e^-m, |nk|) instead of max(norm_val, |nk|).

  • norm_val (float) – The norm scale used in the max formula above.
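The shape contract in the parameter list can be checked up front. The validator below is a hypothetical sketch that derives B, NH, S, DHQK and DHV from q and v and compares every other argument against the documented shapes; it is not part of this module.

```python
import jax.numpy as jnp


def check_mlstm_shapes(q, k, v, i, f, c_initial=None, n_initial=None, m_initial=None):
    """Hypothetical validator for the documented input shapes.

    Returns (B, NH, S, DHQK, DHV) if every given array matches, else raises ValueError.
    """
    B, NH, S, DHQK = q.shape
    DHV = v.shape[-1]
    expected = {
        "k": (k, (B, NH, S, DHQK)),
        "v": (v, (B, NH, S, DHV)),
        "i": (i, (B, NH, S)),
        "f": (f, (B, NH, S)),
        "c_initial": (c_initial, (B, NH, DHQK, DHV)),
        "n_initial": (n_initial, (B, NH, DHQK)),
        "m_initial": (m_initial, (B, NH)),
    }
    for name, (arr, shape) in expected.items():
        # Optional initial states may be None and are then skipped.
        if arr is not None and arr.shape != shape:
            raise ValueError(f"{name} has shape {arr.shape}, expected {shape}")
    return B, NH, S, DHQK, DHV
```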

Returns:

The output of the mLSTM computation. If return_last_states is True, the output is returned together with a tuple of the last states (C, N, M).

Return type:

jax.Array | tuple[jax.Array, tuple[jax.Array, jax.Array, jax.Array]]
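To make the output and the last states (C, N, M) concrete, here is a sequential (non-chunked) reference for a single batch element and head. It is written from the stabilized mLSTM recurrence described in the xLSTM paper, not from this module's kernels, and the helper mlstm_recurrent_reference is hypothetical. Several details are assumptions: queries scaled by 1/sqrt(DHQK), zero initial states, eps added to the denominator, and the stabilize_correctly=True denominator with norm_val=1.0.

```python
import jax
import jax.numpy as jnp


def mlstm_recurrent_reference(q, k, v, i, f, eps=1e-6):
    """Hypothetical sequential mLSTM for one head.

    q, k: (S, DHQK); v: (S, DHV); i, f: (S,) gate preactivations.
    Returns h: (S, DHV) and the last states C: (DHQK, DHV), n: (DHQK,), m: ().
    """
    dqk, dv = q.shape[-1], v.shape[-1]
    q = q / jnp.sqrt(dqk)  # assumption: 1/sqrt(DHQK) query scaling

    def step(carry, inputs):
        C, n, m = carry
        q_t, k_t, v_t, i_t, f_t = inputs
        log_f = jax.nn.log_sigmoid(f_t)          # forget gate in log space
        m_new = jnp.maximum(log_f + m, i_t)      # max (stabilizer) state
        f_gate = jnp.exp(log_f + m - m_new)      # stabilized forget gate
        i_gate = jnp.exp(i_t - m_new)            # stabilized exponential input gate
        C = f_gate * C + i_gate * jnp.outer(k_t, v_t)
        n = f_gate * n + i_gate * k_t
        denom = jnp.maximum(jnp.abs(q_t @ n), jnp.exp(-m_new)) + eps
        h_t = (q_t @ C) / denom
        return (C, n, m_new), h_t

    init = (jnp.zeros((dqk, dv)), jnp.zeros((dqk,)), jnp.zeros(()))
    (C, n, m), h = jax.lax.scan(step, init, (q, k, v, i, f))
    return h, (C, n, m)
```

This loop is strictly sequential over S. A chunkwise formulation in general processes the BT steps inside a chunk in parallel and only carries C, N and M across the NT chunk boundaries, so results may differ from this sketch at the level of floating-point rounding.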