xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.recurrent#
Classes#
Functions#
recurrent_step_fw | This is a single step of the mLSTM operation in recurrent form.
recurrent_sequence_fw | Forward pass of the mLSTM cell in recurrent form on a full sequence.
Module Contents#
- xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.recurrent.recurrent_step_fw(matC_state, vecN_state, scaM_state, vecQ, vecK, vecV, scaI, scaF, eps=1e-06)#
This is a single step of the mLSTM operation in recurrent form.
- Parameters:
matC_state (jax.Array) – Memory state tensor of shape (B, NH, DHQK, DHV).
vecN_state (jax.Array) – Normalizer state tensor of shape (B, NH, DHQK).
scaM_state (jax.Array) – Max state tensor of shape (B, NH, 1).
vecQ (jax.Array) – Queries tensor of shape (B, NH, DHQK).
vecK (jax.Array) – Keys tensor of shape (B, NH, DHQK).
vecV (jax.Array) – Values tensor of shape (B, NH, DHV).
scaI (jax.Array) – Input gate tensor of shape (B, NH, 1).
scaF (jax.Array) – Forget gate tensor of shape (B, NH, 1).
eps (float) – Used for building the forget gate matrix. Defaults to 1e-6.
- Returns:
The hidden state and a tuple of the new states (matC_state_new, vecN_state_new, scaM_state_new).
- Return type:
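To make the shapes in the parameter list concrete, here is a minimal sketch of a stabilized mLSTM step in JAX. It is not the library's code: the name `mlstm_step_sketch` is hypothetical, and the update rule (log-sigmoid forget gate, exponential input gate, running max for stabilization, queries scaled by DHQK^-0.5) is an assumption based on the usual mLSTM formulation. Only the argument order and tensor shapes are taken from the documentation above.

```python
import jax
import jax.numpy as jnp


def mlstm_step_sketch(matC, vecN, scaM, vecQ, vecK, vecV, scaI, scaF, eps=1e-6):
    """Hypothetical stabilized mLSTM step (assumed update rule, not the library's code)."""
    dhqk = vecQ.shape[-1]
    # Log of the sigmoid forget gate, shape (B, NH, 1).
    log_f = jax.nn.log_sigmoid(scaF)
    # The new max state keeps both exponentials below bounded by 1.
    scaM_new = jnp.maximum(log_f + scaM, scaI)
    f_act = jnp.exp(log_f + scaM - scaM_new)  # (B, NH, 1)
    i_act = jnp.exp(scaI - scaM_new)          # (B, NH, 1)
    # Memory update: decay the old state and add the outer product k v^T.
    outer_kv = vecK[..., :, None] * vecV[..., None, :]  # (B, NH, DHQK, DHV)
    matC_new = f_act[..., None] * matC + i_act[..., None] * outer_kv
    vecN_new = f_act * vecN + i_act * vecK              # (B, NH, DHQK)
    # Read-out: scaled query against memory, divided by a bounded normalizer.
    q_scaled = vecQ * dhqk ** -0.5
    h_num = jnp.einsum("bhk,bhkv->bhv", q_scaled, matC_new)    # (B, NH, DHV)
    qn = jnp.sum(q_scaled * vecN_new, axis=-1, keepdims=True)  # (B, NH, 1)
    h_den = jnp.maximum(jnp.abs(qn), jnp.exp(-scaM_new)) + eps
    return h_num / h_den, (matC_new, vecN_new, scaM_new)


# Example call with zero-initialized states and random inputs.
B, NH, DHQK, DHV = 2, 3, 4, 5
ks = jax.random.split(jax.random.PRNGKey(0), 5)
q = jax.random.normal(ks[0], (B, NH, DHQK))
k = jax.random.normal(ks[1], (B, NH, DHQK))
v = jax.random.normal(ks[2], (B, NH, DHV))
i_pre = jax.random.normal(ks[3], (B, NH, 1))
f_pre = jax.random.normal(ks[4], (B, NH, 1))
C0 = jnp.zeros((B, NH, DHQK, DHV))
n0 = jnp.zeros((B, NH, DHQK))
m0 = jnp.zeros((B, NH, 1))
h, (C1, n1, m1) = mlstm_step_sketch(C0, n0, m0, q, k, v, i_pre, f_pre)
```

The returned states have the same shapes as the incoming ones, so the output of one step can be fed directly into the next.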
- xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.recurrent.recurrent_sequence_fw(queries, keys, values, igate_preact, fgate_preact, c_initial=None, n_initial=None, m_initial=None, return_last_states=False, eps=1e-06, state_dtype=None, use_scan=False, mlstm_step_fn=recurrent_step_fw)#
Forward pass of the mLSTM cell in recurrent form on a full sequence.
- Parameters:
queries (jax.Array) – Queries tensor of shape (B, NH, S, DHQK).
keys (jax.Array) – Keys tensor of shape (B, NH, S, DHQK).
values (jax.Array) – Values tensor of shape (B, NH, S, DHV).
igate_preact (jax.Array) – Input gate pre-activation tensor of shape (B, NH, S, 1).
fgate_preact (jax.Array) – Forget gate pre-activation tensor of shape (B, NH, S, 1).
c_initial (jax.Array | None) – Initial memory state tensor of shape (B, NH, DHQK, DHV). If None, initialized to zeros.
n_initial (jax.Array | None) – Initial normalizer state tensor of shape (B, NH, DHQK). If None, initialized to zeros.
m_initial (jax.Array | None) – Initial max state tensor of shape (B, NH). If None, initialized to zeros.
return_last_states (bool) – Whether to return the last states. Defaults to False.
eps (float) – Epsilon value for numerical stability. Defaults to 1e-6.
state_dtype (jax.numpy.dtype | None) – Dtype of the states. If None, uses the dtype of the initial states if provided, otherwise the dtype of the pre-activations. If initial states are provided, the return dtype will be the same as the initial states. Defaults to None.
use_scan (bool) – Whether to use jax.lax.scan for the loop. The scan reduces compilation time, but may be slower for kernels without XLA compiler support and introduces memory copy overhead. Defaults to False.
mlstm_step_fn (collections.abc.Callable) – Function to compute a single mLSTM step. By default, set to recurrent_step_fw in this backend.
- Returns:
The hidden states tensor of shape (B, NH, S, DHV) if return_last_states is False. If return_last_states is True, a tuple of the hidden states tensor and a tuple of the last state tensors.
- Return type:
jax.Array | tuple[jax.Array, tuple[jax.Array, jax.Array, jax.Array]]
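The trade-off behind `use_scan` can be illustrated with a small, self-contained sketch. The step function `toy_step` below is a hypothetical stand-in (a gated running sum), not the mLSTM step, and `run_loop` / `run_scan` are illustrative names. The point is only the control flow: a Python loop is unrolled at trace time, while `jax.lax.scan` traces the step once and iterates over the leading axis, which is why the sequence axis has to be moved to the front and back.

```python
import jax
import jax.numpy as jnp


def toy_step(state, x_t, f_t):
    """Hypothetical stand-in for a recurrent step: a gated running sum."""
    new_state = jax.nn.sigmoid(f_t) * state + x_t
    return new_state, new_state  # (output, new state)


def run_loop(x, f, state):
    # Python loop: unrolled at trace time, one copy of the step per time step.
    outs = []
    for t in range(x.shape[2]):
        h_t, state = toy_step(state, x[:, :, t], f[:, :, t])
        outs.append(h_t)
    return jnp.stack(outs, axis=2), state


def run_scan(x, f, state):
    # jax.lax.scan iterates over the leading axis, so move S to the front.
    xs = (jnp.moveaxis(x, 2, 0), jnp.moveaxis(f, 2, 0))

    def body(carry, inputs):
        x_t, f_t = inputs
        h_t, carry = toy_step(carry, x_t, f_t)
        return carry, h_t

    state, hs = jax.lax.scan(body, state, xs)
    # Move the sequence axis back: (S, B, NH, D) -> (B, NH, S, D).
    return jnp.moveaxis(hs, 0, 2), state


B, NH, S, D = 2, 3, 6, 4
kx, kf = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (B, NH, S, D))
f = jax.random.normal(kf, (B, NH, S, 1))
init = jnp.zeros((B, NH, D))

out_loop, last_loop = run_loop(x, f, init)
out_scan, last_scan = run_scan(x, f, init)
```

Both variants compute the same values; they differ in how much code XLA has to compile and in the axis transposes that the scan version needs.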
- class xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.recurrent.mLSTMBackendRecurrentConfig#
- assign_model_config_params(model_config)#
- class xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.recurrent.mLSTMBackendRecurrent#
Bases:
xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.config.mLSTMBackend
- config_class#
- property can_vmap_over_heads: bool#
Whether the backend can be vmapped over the heads dimension.
The backend was not written to be independent of the heads dimension, and thus cannot be vmapped.
- Returns:
False
- Return type:
bool
- config: Any#
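For context on `can_vmap_over_heads`, the sketch below shows what vmapping over the heads dimension means in JAX. The two functions are hypothetical examples, not part of this backend: `per_head_readout` is written without a heads axis and gains one through `jax.vmap`, while `all_heads_readout` handles the heads axis explicitly. A backend that reports `False`, like this one, is in the second category, since its functions expect inputs that already include the NH axis.

```python
import jax
import jax.numpy as jnp


def per_head_readout(q, k, v):
    """Hypothetical head-agnostic function on (B, S, D) inputs."""
    scores = jnp.einsum("bsd,btd->bst", q, k)
    return jnp.einsum("bst,btd->bsd", scores, v)


def all_heads_readout(q, k, v):
    """Same computation with an explicit heads axis, inputs (B, NH, S, D)."""
    scores = jnp.einsum("bhsd,bhtd->bhst", q, k)
    return jnp.einsum("bhst,bhtd->bhsd", scores, v)


B, NH, S, D = 2, 3, 4, 5
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (B, NH, S, D))
k = jax.random.normal(kk, (B, NH, S, D))
v = jax.random.normal(kv, (B, NH, S, D))

# Map the head-agnostic function over axis 1 (the heads axis).
vmapped = jax.vmap(per_head_readout, in_axes=1, out_axes=1)
out_vmapped = vmapped(q, k, v)

# The explicit version consumes the heads axis directly, no vmap needed.
out_explicit = all_heads_readout(q, k, v)
```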