xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.recurrent

Classes

mLSTMBackendRecurrentConfig

mLSTMBackendRecurrent

Functions

recurrent_step_fw(matC_state, vecN_state, scaM_state, ...)

This is a single step of the mLSTM operation in recurrent form.

recurrent_sequence_fw(queries, keys, values, ...[, ...])

Forward pass of the mLSTM cell in recurrent form on a full sequence.

Module Contents

xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.recurrent.recurrent_step_fw(matC_state, vecN_state, scaM_state, vecQ, vecK, vecV, scaI, scaF, eps=1e-06)

This is a single step of the mLSTM operation in recurrent form.

Parameters:
  • matC_state (jax.Array) – Memory state tensor of shape (B, NH, DHQK, DHV).

  • vecN_state (jax.Array) – Normalizer state tensor of shape (B, NH, DHQK).

  • scaM_state (jax.Array) – Max state tensor of shape (B, NH, 1).

  • vecQ (jax.Array) – Queries tensor of shape (B, NH, DHQK).

  • vecK (jax.Array) – Keys tensor of shape (B, NH, DHQK).

  • vecV (jax.Array) – Values tensor of shape (B, NH, DHV).

  • scaI (jax.Array) – Input gate tensor of shape (B, NH, 1).

  • scaF (jax.Array) – Forget gate tensor of shape (B, NH, 1).

  • eps (float) – Epsilon value for numerical stability. Defaults to 1e-6.

Returns:

The hidden state of shape (B, NH, DHV) and the new states (matC_state_new, vecN_state_new, vecM_state_new).

Return type:

tuple[jax.Array, tuple[jax.Array, jax.Array, jax.Array]]
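To make the state update concrete, here is a minimal sketch of one stabilized mLSTM step using the shapes listed above. It is not this module's implementation: the log-sigmoid forget gate, the exponential input gate, the max-state stabilization, the 1/sqrt(DHQK) query scaling, and eps entering the denominator are assumptions taken from the stabilized mLSTM formulation of the xLSTM paper, and mlstm_step_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def mlstm_step_sketch(matC_state, vecN_state, scaM_state, vecQ, vecK, vecV, scaI, scaF, eps=1e-6):
    """One stabilized mLSTM step; a sketch, not the library implementation.

    matC_state: (B, NH, DHQK, DHV), vecN_state: (B, NH, DHQK), scaM_state: (B, NH, 1),
    vecQ, vecK: (B, NH, DHQK), vecV: (B, NH, DHV),
    scaI, scaF: (B, NH, 1) gate pre-activations (assumed).
    """
    dhqk = vecQ.shape[-1]
    # Assumed gating: sigmoid forget gate (kept in log space), exponential input gate.
    scaF_log = jax.nn.log_sigmoid(scaF)
    # New max state; subtracting it keeps both gate exponentials <= 1.
    scaM_new = jnp.maximum(scaF_log + scaM_state, scaI)
    scaF_act = jnp.exp(scaF_log + scaM_state - scaM_new)  # (B, NH, 1)
    scaI_act = jnp.exp(scaI - scaM_new)  # (B, NH, 1)

    # Memory update with the outer product k v^T -> (B, NH, DHQK, DHV).
    matC_new = scaF_act[..., None] * matC_state + scaI_act[..., None] * (
        vecK[..., :, None] * vecV[..., None, :]
    )
    # Normalizer update -> (B, NH, DHQK).
    vecN_new = scaF_act * vecN_state + scaI_act * vecK

    # Read out with the scaled query (1/sqrt(DHQK) scaling is an assumption).
    vecQ_scaled = vecQ * dhqk**-0.5
    h_num = jnp.einsum("bhd,bhdv->bhv", vecQ_scaled, matC_new)  # (B, NH, DHV)
    qn = jnp.sum(vecQ_scaled * vecN_new, axis=-1, keepdims=True)  # (B, NH, 1)
    # exp(-m) is the stabilized counterpart of the lower bound 1 on the normalizer.
    h_denom = jnp.maximum(jnp.abs(qn), jnp.exp(-scaM_new)) + eps
    return h_num / h_denom, (matC_new, vecN_new, scaM_new)
```

Because every state is scaled by exp(-m), the sketch yields the same hidden state as the unstabilized recurrence (up to eps) while keeping the exponentials bounded.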

xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.recurrent.recurrent_sequence_fw(queries, keys, values, igate_preact, fgate_preact, c_initial=None, n_initial=None, m_initial=None, return_last_states=False, eps=1e-06, state_dtype=None, use_scan=False, mlstm_step_fn=recurrent_step_fw)

Forward pass of the mLSTM cell in recurrent form on a full sequence.

Parameters:
  • queries (jax.Array) – Queries tensor of shape (B, NH, S, DHQK).

  • keys (jax.Array) – Keys tensor of shape (B, NH, S, DHQK).

  • values (jax.Array) – Values tensor of shape (B, NH, S, DHV).

  • igate_preact (jax.Array) – Input gate pre-activation tensor of shape (B, NH, S, 1).

  • fgate_preact (jax.Array) – Forget gate pre-activation tensor of shape (B, NH, S, 1).

  • c_initial (jax.Array | None) – Initial memory state tensor of shape (B, NH, DHQK, DHV). If None, initialized to zeros.

  • n_initial (jax.Array | None) – Initial normalizer state tensor of shape (B, NH, DHQK). If None, initialized to zeros.

  • m_initial (jax.Array | None) – Initial max state tensor of shape (B, NH). If None, initialized to zeros.

  • return_last_states (bool) – Whether to return the last states. Defaults to False.

  • eps (float) – Epsilon value for numerical stability. Defaults to 1e-6.

  • state_dtype (jax.numpy.dtype | None) – Dtype of the states. If None, uses the dtype of the initial states if they are provided, otherwise the dtype of the pre-activations. If initial states are provided, the returned states keep the dtype of the initial states. Defaults to None.

  • use_scan (bool) – Whether to use jax.lax.scan for the loop over the sequence. The scan reduces compilation time, but may be slower for kernels without XLA compiler support and introduces memory copy overhead. Defaults to False.

  • mlstm_step_fn (collections.abc.Callable) – Function that computes a single mLSTM step. Defaults to recurrent_step_fw.

Returns:

If return_last_states is False, the hidden states tensor of shape (B, NH, S, DHV). If return_last_states is True, a tuple of the hidden states tensor and a tuple of the last states (memory, normalizer, max).

Return type:

jax.Array | tuple[jax.Array, tuple[jax.Array, jax.Array, jax.Array]]
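The sequence loop can be sketched as follows, assuming the step function has the recurrent_step_fw signature. The hypothetical recurrent_sequence_sketch contrasts the two loop strategies that use_scan presumably selects between: a Python loop, which is unrolled into S step computations when traced, and jax.lax.scan, which traces the step once. The sketch zero-initializes missing states, keeps the max state as (B, NH, 1) like scaM_state (m_initial is documented above as (B, NH)), and omits state_dtype handling. Its compact step repeats an assumed stabilized mLSTM update and is not this module's code.

```python
import jax
import jax.numpy as jnp


def step(C, n, m, q, k, v, i, f, eps=1e-6):
    """Compact stabilized mLSTM step (assumed formulation, not the library's)."""
    f_log = jax.nn.log_sigmoid(f)
    m_new = jnp.maximum(f_log + m, i)
    f_act, i_act = jnp.exp(f_log + m - m_new), jnp.exp(i - m_new)
    C_new = f_act[..., None] * C + i_act[..., None] * (k[..., :, None] * v[..., None, :])
    n_new = f_act * n + i_act * k
    q = q * q.shape[-1] ** -0.5
    num = jnp.einsum("bhd,bhdv->bhv", q, C_new)
    qn = jnp.sum(q * n_new, axis=-1, keepdims=True)
    return num / (jnp.maximum(jnp.abs(qn), jnp.exp(-m_new)) + eps), (C_new, n_new, m_new)


def recurrent_sequence_sketch(queries, keys, values, igate_preact, fgate_preact,
                              c_initial=None, n_initial=None, m_initial=None,
                              return_last_states=False, eps=1e-6, use_scan=False,
                              mlstm_step_fn=step):
    B, NH, S, DHQK = queries.shape
    DHV = values.shape[-1]
    if c_initial is None:
        # Missing initial states default to zeros.
        c_initial = jnp.zeros((B, NH, DHQK, DHV), queries.dtype)
        n_initial = jnp.zeros((B, NH, DHQK), queries.dtype)
        m_initial = jnp.zeros((B, NH, 1), queries.dtype)
    state = (c_initial, n_initial, m_initial)
    inputs = (queries, keys, values, igate_preact, fgate_preact)

    if use_scan:
        # lax.scan iterates over the leading axis, so move S (axis 2) to the front.
        xs = tuple(jnp.moveaxis(x, 2, 0) for x in inputs)

        def body(carry, x_t):
            h_t, new_carry = mlstm_step_fn(*carry, *x_t, eps=eps)
            return new_carry, h_t

        state, hs = jax.lax.scan(body, state, xs)
        hs = jnp.moveaxis(hs, 0, 2)  # (S, B, NH, DHV) -> (B, NH, S, DHV)
    else:
        # Python loop: unrolled into S step computations when traced under jit.
        hs = []
        for t in range(S):
            h_t, state = mlstm_step_fn(*state, *(x[:, :, t] for x in inputs), eps=eps)
            hs.append(h_t)
        hs = jnp.stack(hs, axis=2)

    return (hs, state) if return_last_states else hs
```

In the sketch, passing the last states of one chunk as the initial states of the next reproduces the full-sequence output, which is the use case return_last_states supports.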

class xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.recurrent.mLSTMBackendRecurrentConfig

context_length: int = -1

eps: float = 1e-06

state_dtype: str | None = None

use_scan: bool = False

assign_model_config_params(model_config)
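The config stores state_dtype as a string (or None), whereas recurrent_sequence_fw documents a jax.numpy.dtype. How the backend converts one into the other is not documented here; assuming the string is a dtype name, jnp.dtype performs the conversion. The hypothetical helper resolve_state_dtype sketches the fallback order documented for recurrent_sequence_fw: an explicit state_dtype first, then the dtype of the initial states, then the dtype of the pre-activations.

```python
import jax.numpy as jnp


def resolve_state_dtype(state_dtype, c_initial, preact):
    """Hypothetical helper: pick the state dtype in the documented order.

    state_dtype: a dtype name such as "float32" (as stored in the config) or None.
    c_initial: the initial memory state, or None.
    preact: a gate pre-activation array, used as the final fallback.
    """
    if state_dtype is not None:
        # jnp.dtype converts a name like "float32" into a dtype object.
        return jnp.dtype(state_dtype)
    if c_initial is not None:
        # Initial states given: keep their dtype.
        return c_initial.dtype
    # Otherwise follow the pre-activations.
    return preact.dtype
```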
class xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.recurrent.mLSTMBackendRecurrent#

Bases: xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.config.mLSTMBackend

config_class

property can_vmap_over_heads: bool

Whether the backend can be vmapped over the heads dimension.

The backend was not written to be independent of the heads dimension and therefore cannot be vmapped over it.

Returns:

False

Return type:

bool
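For context, vmapping over the heads dimension means writing a kernel for a single head and letting jax.vmap add the NH axis. The minimal sketch below illustrates the idea with a hypothetical per-head function (per_head_scores is not part of this module); it does not show how the surrounding model consults can_vmap_over_heads. By contrast, recurrent_step_fw and recurrent_sequence_fw take inputs that already carry the NH axis.

```python
import jax
import jax.numpy as jnp


def per_head_scores(q, k):
    """Hypothetical single-head kernel: q, k of shape (B, S, D) -> scores (B, S, S)."""
    return jnp.einsum("bsd,btd->bst", q, k)


# jax.vmap adds the heads axis: map over axis 1 of (B, NH, S, D) inputs and
# put the mapped axis back at position 1 of the output, giving (B, NH, S, S).
multi_head_scores = jax.vmap(per_head_scores, in_axes=1, out_axes=1)
```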

config: Any