xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.fwbw#
Classes#
- mLSTMBackendFwbwConfig: Configuration object for the mLSTM fwbw backend.
- mLSTMBackendFwbw
Functions#
- rev_cumsum_off: Compute the reverse cumulative sum of a tensor with an offset.
- rev_cumsum: Compute the reverse cumulative sum of a tensor.
- causal_forget_matrix: Compute the causal forget matrix from the forget gates.
- fwbw_forward: Forward pass of the mLSTM fwbw backend.
- fwbw_backward: Backward pass of the mLSTM fwbw backend.
- mlstm_fwbw_custom_grad: Returns an autograd function that computes the gradient itself.
Module Contents#
- class xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.fwbw.mLSTMBackendFwbwConfig#
Configuration object for the mLSTM fwbw backend.
- stabilize_correctly: bool = False#
Whether to stabilize the output correctly. This is only needed if no GroupNorm is applied after the mLSTM. If GroupNorm is applied, this can be set to False, as results after GroupNorm will be the same.
- assign_model_config_params(model_config)#
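The note on stabilize_correctly relies on normalization being invariant to a positive rescaling of its input. The sketch below illustrates that property only. It assumes that the two stabilization variants differ by a positive scale factor per time step, and it uses a hypothetical parameter-free `normalize` helper as a stand-in for GroupNorm; neither is taken from the library. With a non-zero epsilon, as in a real GroupNorm layer, the outputs would agree only approximately.

```python
import jax
import jax.numpy as jnp


def normalize(h, eps=0.0):
    """Per-row normalization without learnable parameters (stand-in for GroupNorm)."""
    mean = jnp.mean(h, axis=-1, keepdims=True)
    var = jnp.var(h, axis=-1, keepdims=True)
    return (h - mean) / jnp.sqrt(var + eps)


# Toy "mLSTM output": 4 time steps, 8 features.
h = jax.random.normal(jax.random.PRNGKey(0), (4, 8))

# Hypothetical positive scale per time step, standing in for the difference
# between two stabilization choices (an assumption, not the library's values).
scale = jnp.array([[0.5], [2.0], [10.0], [0.1]])

# Normalizing removes the per-row scale, so both inputs give the same result.
same = jnp.allclose(normalize(h), normalize(scale * h), atol=1e-4)
```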
- xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.fwbw.rev_cumsum_off(x)#
Compute the reverse cumulative sum of a tensor with an offset.
- Parameters:
x (jax.Array)
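As an illustration, one plausible reading of "with an offset" is an exclusive reverse cumulative sum, where each position sums only the entries strictly after it. The sketch below assumes that reading and assumes the sum runs along the last axis; `rev_cumsum_off_sketch` is a hypothetical name, not the library function.

```python
import jax.numpy as jnp


def rev_cumsum_off_sketch(x):
    """Exclusive reverse cumulative sum along the last axis (assumed semantics)."""
    # Shift left by one and pad the end with zeros, so that position t
    # only sees entries strictly after t.
    shifted = jnp.concatenate([x[..., 1:], jnp.zeros_like(x[..., :1])], axis=-1)
    # Reverse, accumulate, reverse back.
    return jnp.flip(jnp.cumsum(jnp.flip(shifted, axis=-1), axis=-1), axis=-1)


x = jnp.array([1.0, 2.0, 3.0, 4.0])
y_off = rev_cumsum_off_sketch(x)  # [9., 7., 4., 0.]
```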
- xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.fwbw.rev_cumsum(x)#
Compute the reverse cumulative sum of a tensor.
- Parameters:
x (jax.Array)
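A reverse cumulative sum can be written as flip, cumulative sum, flip. The sketch below assumes the operation runs along the last axis; `rev_cumsum_sketch` is a hypothetical name used for illustration, not the library function.

```python
import jax.numpy as jnp


def rev_cumsum_sketch(x):
    """Reverse cumulative sum along the last axis: y[t] = sum of x[t:]."""
    return jnp.flip(jnp.cumsum(jnp.flip(x, axis=-1), axis=-1), axis=-1)


x = jnp.array([1.0, 2.0, 3.0, 4.0])
y = rev_cumsum_sketch(x)  # [10., 9., 7., 4.]
```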
- xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.fwbw.causal_forget_matrix(forget_gates)#
Compute the causal forget matrix from the forget gates.
- Parameters:
forget_gates (jax.Array)
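For orientation, the parallel mLSTM formulation uses a lower-triangular matrix whose entry (t, s) accumulates the forget gates between steps s and t. The sketch below builds that matrix in log space from a cumulative sum. It assumes log forget gates as input and a single sequence of shape (S,); the name `causal_log_forget_matrix_sketch` and these conventions are illustrative and may differ from what causal_forget_matrix actually does.

```python
import jax
import jax.numpy as jnp


def causal_log_forget_matrix_sketch(log_f):
    """Return D with D[t, s] = sum_{j=s+1..t} log_f[j] for s <= t, else -inf."""
    S = log_f.shape[-1]
    cs = jnp.cumsum(log_f, axis=-1)
    # cs[t] - cs[s] is the accumulated log forget gate from step s+1 to t.
    diff = cs[..., :, None] - cs[..., None, :]
    # Keep the causal (lower-triangular) part, mask out the future.
    mask = jnp.tril(jnp.ones((S, S), dtype=bool))
    return jnp.where(mask, diff, -jnp.inf)


f_pre = jnp.array([0.5, -1.0, 2.0])   # toy forget gate pre-activations
log_f = jax.nn.log_sigmoid(f_pre)     # log of sigmoid forget gates
D = causal_log_forget_matrix_sketch(log_f)
```

Working in log space keeps the products of many forget gates numerically stable; exponentiating D gives the decay factors, with zeros above the diagonal.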
- xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.fwbw.fwbw_forward(q, k, v, i, f, config, initial_C=None, initial_n=None, initial_m=None)#
Forward pass of the mLSTM fwbw backend.
- Parameters:
q (jax.Array) – query tensor
k (jax.Array) – key tensor
v (jax.Array) – value tensor
i (jax.Array) – input gate tensor
f (jax.Array) – forget gate tensor
config (mLSTMBackendFwbwConfig) – configuration object
initial_C (jax.Array | None) – initial C (cell state) tensor. Defaults to None.
initial_n (jax.Array | None) – initial n (normalizer state) tensor. Defaults to None.
initial_m (jax.Array | None) – initial m (stabilizer state) tensor. Defaults to None.
- Returns:
Output tensor and context for backward.
- Return type:
- xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.fwbw.fwbw_backward(ctx, dh, config, dc_last=None, dn_last=None, dm_last=None)#
Backward pass of the mLSTM fwbw backend.
- Parameters:
ctx (Sequence[jax.Array]) – context from forward pass.
dh (jax.Array) – gradient tensor.
config (mLSTMBackendFwbwConfig) – configuration object.
dc_last (jax.Array, optional) – gradient for the last C (cell state) tensor. Defaults to None.
dn_last (jax.Array, optional) – gradient for the last n (normalizer state) tensor. Defaults to None.
dm_last (jax.Array, optional) – gradient for the last m (stabilizer state) tensor. Defaults to None.
- Returns:
Gradients.
- Return type:
tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array | None, jax.Array | None, jax.Array | None]
- xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.fwbw.mlstm_fwbw_custom_grad(config)#
Returns an autograd function that computes the gradient itself.
- Parameters:
config (mLSTMBackendFwbwConfig) – configuration object.
- Returns:
autograd function.
- Return type:
function
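In JAX, a function that "computes the gradient itself" is typically built with jax.custom_vjp, pairing a forward rule that saves residuals with a hand-written backward rule. The toy factory below sketches that pattern only. `make_scaled_square` and its `scale` argument are hypothetical and unrelated to the mLSTM computation; whether mlstm_fwbw_custom_grad is implemented this way is an assumption.

```python
import jax
import jax.numpy as jnp


def make_scaled_square(scale):
    """Hypothetical factory: returns f(x) = scale * sum(x**2) with a manual gradient."""

    @jax.custom_vjp
    def f(x):
        return scale * jnp.sum(x ** 2)

    def f_fwd(x):
        # Forward rule: return the output plus residuals for the backward rule.
        return f(x), (x,)

    def f_bwd(residuals, g):
        (x,) = residuals
        # Backward rule: one cotangent per primal input, written by hand.
        return (g * 2.0 * scale * x,)

    f.defvjp(f_fwd, f_bwd)
    return f


fn = make_scaled_square(3.0)
x = jnp.array([1.0, -2.0])
value = fn(x)            # 3 * (1 + 4) = 15
grad = jax.grad(fn)(x)   # 6 * x = [6., -12.]
```

The factory closes over its configuration, so the returned function takes only array arguments and can be passed directly to jax.grad or jax.jit.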
- class xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.fwbw.mLSTMBackendFwbw#
Bases:
xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.config.mLSTMBackend
- config_class#