xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.fwbw#

Classes#

mLSTMBackendFwbwConfig

Configuration object for the mLSTM fwbw backend.

mLSTMBackendFwbw

Functions#

rev_cumsum_off(x)

Compute the reverse cumulative sum of a tensor with an offset.

rev_cumsum(x)

Compute the reverse cumulative sum of a tensor.

causal_forget_matrix(forget_gates)

Compute the causal forget matrix from the forget gates.

fwbw_forward(q, k, v, i, f, config[, initial_C, ...])

Forward pass of the mLSTM fwbw backend.

fwbw_backward(ctx, dh, config[, dc_last, dn_last, dm_last])

Backward pass of the mLSTM fwbw backend.

mlstm_fwbw_custom_grad(config)

Returns an autograd function that computes the gradient itself.

Module Contents#

class xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.fwbw.mLSTMBackendFwbwConfig#

Configuration object for the mLSTM fwbw backend.

chunk_size: int | None = 512#

Chunk size for the kernel computation.

return_state: bool = False#

Whether to return the last state. Useful for inference.

use_initial_state: bool = False#

Whether to start from an initial state or zeros.

keep_G: bool = False#

Whether to save the G matrix for the backward pass.

keep_gates: bool = True#

Whether to save the gates for the backward pass.

keep_M: bool = False#

Whether to save the M matrix for the backward pass.

keep_c: bool = False#

Whether to save the c matrix for the backward pass.

stabilize_correctly: bool = False#

Whether to stabilize the output correctly. This is only needed if no GroupNorm is applied after the mLSTM. If GroupNorm is applied, this can be set to False, as results after GroupNorm will be the same.
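The GroupNorm remark can be illustrated with a small sketch. It assumes (this is not stated above) that the two stabilization variants differ only by a positive rescaling of each output position, and it replaces GroupNorm with a plain zero-mean, unit-variance normalization without learnable parameters and with `eps = 0`. Under those assumptions the normalized result does not depend on the scale:

```python
import jax.numpy as jnp


def normalize(h, eps=0.0):
    # Zero-mean, unit-variance normalization over the last axis
    # (a simplified stand-in for GroupNorm, no learnable scale or bias).
    mean = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)
    return (h - mean) / jnp.sqrt(var + eps)


h = jnp.array([[1.0, 2.0, 4.0, 8.0], [0.5, -1.0, 3.0, 2.0]])
# Hypothetical positive factor per position, standing in for the
# difference between the two stabilization variants.
scale = jnp.array([[3.0], [0.25]])

same = jnp.allclose(normalize(h), normalize(scale * h), atol=1e-5)
```

With a nonzero `eps` the equality holds only approximately, which is one reason to treat this as an illustration rather than a guarantee.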

assign_model_config_params(model_config)#

xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.fwbw.rev_cumsum_off(x)#

Compute the reverse cumulative sum of a tensor with an offset.

Parameters:

x (jax.Array)

xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.fwbw.rev_cumsum(x)#

Compute the reverse cumulative sum of a tensor.

Parameters:

x (jax.Array)
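As a rough illustration of the two reverse cumulative sums, the following sketch works along the last axis. The function names are hypothetical, and the meaning of "offset" is an assumption here: the sketch shifts the input by one position so that each entry sums only the elements strictly after it. The library's own convention may differ.

```python
import jax.numpy as jnp


def rev_cumsum_sketch(x):
    # out[..., i] = sum of x[..., j] for all j >= i
    return jnp.flip(jnp.cumsum(jnp.flip(x, axis=-1), axis=-1), axis=-1)


def rev_cumsum_off_sketch(x):
    # Assumed offset: shift left by one and pad with zero, so that
    # out[..., i] = sum of x[..., j] for all j > i
    shifted = jnp.concatenate([x[..., 1:], jnp.zeros_like(x[..., :1])], axis=-1)
    return rev_cumsum_sketch(shifted)


x = jnp.array([1.0, 2.0, 3.0, 4.0])
inclusive = rev_cumsum_sketch(x)      # [10., 9., 7., 4.]
exclusive = rev_cumsum_off_sketch(x)  # [9., 7., 4., 0.]
```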

xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.fwbw.causal_forget_matrix(forget_gates)#

Compute the causal forget matrix from the forget gates.

Parameters:

forget_gates (jax.Array)
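One common way to build such a matrix is sketched below. It assumes the forget gates are given in log space with the sequence on the last axis, and that entry `[i, j]` holds the accumulated log forget gate from step `j + 1` up to step `i`, with `-inf` above the diagonal. These conventions and the function name are assumptions for illustration; the actual implementation may differ.

```python
import jax.numpy as jnp


def causal_forget_matrix_sketch(log_f):
    # log_f: (..., S) log forget gates; returns a (..., S, S) matrix.
    S = log_f.shape[-1]
    csum = jnp.cumsum(log_f, axis=-1)
    # diff[..., i, j] = csum[i] - csum[j] = sum of log_f[k] for k in j+1..i
    diff = csum[..., :, None] - csum[..., None, :]
    # Keep the lower triangle (j <= i); future positions get -inf.
    causal = jnp.tril(jnp.ones((S, S), dtype=bool))
    return jnp.where(causal, diff, -jnp.inf)


log_f = jnp.log(jnp.array([0.5, 0.5, 0.25]))
# Exponentiating gives products of forget gates, with zeros above the diagonal.
decay = jnp.exp(causal_forget_matrix_sketch(log_f))
```

In this toy case, `decay[2, 0]` is the product of the gates at steps 1 and 2, and the diagonal is 1 because no forgetting is applied within a single step.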

xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.fwbw.fwbw_forward(q, k, v, i, f, config, initial_C=None, initial_n=None, initial_m=None)#

Forward pass of the mLSTM fwbw backend.

Parameters:
Returns:

Output tensor and context for backward.

Return type:

tuple[jax.Array, collections.abc.Sequence[jax.Array]]

xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.fwbw.fwbw_backward(ctx, dh, config, dc_last=None, dn_last=None, dm_last=None)#

Backward pass of the mLSTM fwbw backend.

Parameters:
  • ctx (Sequence[jax.Array]) – context from forward pass.

  • dh (jax.Array) – gradient tensor.

  • config (mLSTMBackendFwbwConfig) – configuration object.

  • dc_last (jax.Array, optional) – gradient with respect to the last C state. Defaults to None.

  • dn_last (jax.Array, optional) – gradient with respect to the last n state. Defaults to None.

  • dm_last (jax.Array, optional) – gradient with respect to the last m state. Defaults to None.

Returns:

Gradients, as a tuple of five arrays followed by three optional arrays.

Return type:

tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array | None, jax.Array | None, jax.Array | None]

xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.fwbw.mlstm_fwbw_custom_grad(config)#

Returns an autograd function that computes the gradient itself.

Parameters:

config (mLSTMBackendFwbwConfig) – configuration object.

Returns:

autograd function.

Return type:

function
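In JAX, a function that supplies its own gradient is typically built with `jax.custom_vjp`. The sketch below shows that general pattern on a toy function. The factory name, the `scale` argument (a stand-in for a configuration object captured by closure), and the toy math are all hypothetical and say nothing about how `mlstm_fwbw_custom_grad` is implemented internally.

```python
import jax
import jax.numpy as jnp


def make_custom_grad_fn(scale):
    # `scale` is a hypothetical stand-in for a config captured by closure.
    @jax.custom_vjp
    def fn(x):
        return scale * jnp.sin(x)

    def fn_fwd(x):
        # Return the output plus the context saved for the backward pass.
        return fn(x), (x,)

    def fn_bwd(ctx, dh):
        (x,) = ctx
        # Hand-written gradient: one tuple entry per input of `fn`.
        return (dh * scale * jnp.cos(x),)

    fn.defvjp(fn_fwd, fn_bwd)
    return fn


f = make_custom_grad_fn(2.0)
g = jax.grad(f)(0.0)  # uses fn_bwd: 2.0 * cos(0.0)
```

The forward rule returns both the output and a context, and the backward rule consumes that context together with the incoming gradient, which parallels the roles of `fwbw_forward` and `fwbw_backward` described above.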

class xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.fwbw.mLSTMBackendFwbw#

Bases: xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.config.mLSTMBackend

config_class#

property can_vmap_over_heads: bool#

Whether the backend can be vmapped over the heads dimension.

The backend is written independently of the heads dimension, and thus can be vmapped.

Returns:

True

Return type:

bool
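To show what vmapping over heads means in practice, here is a minimal sketch with a toy per-head function. The function, the shapes, and the choice of axis 0 for the heads are assumptions made for the example, not the backend's actual layout.

```python
import jax
import jax.numpy as jnp


def single_head_op(q, k):
    # Toy head-independent computation on (S, D) inputs:
    # causal q @ k^T scores with future positions zeroed out.
    S = q.shape[0]
    scores = q @ k.T
    return jnp.where(jnp.tril(jnp.ones((S, S), dtype=bool)), scores, 0.0)


NH, S, D = 3, 4, 5  # hypothetical (heads, sequence, head dim) layout
q = jnp.arange(NH * S * D, dtype=jnp.float32).reshape(NH, S, D) / 10.0
k = jnp.ones((NH, S, D))

# Because the function never looks at the heads axis, it can be mapped over it.
per_head = jax.vmap(single_head_op, in_axes=0)(q, k)

# Reference: apply the function to each head in a Python loop.
looped = jnp.stack([single_head_op(q[h], k[h]) for h in range(NH)])
```

Both paths produce an array of shape `(NH, S, S)`; the vmapped version avoids the explicit loop.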