xlstm_jax.models.xlstm_clean.blocks.mlstm.backend.simple
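Classes
- mLSTMBackendJaxConfig
- mLSTMBackendJax

Functions
- parallel_stabilized_simple – This is the mLSTM cell in parallel form.
- recurrent_step_stabilized_simple – This is a single step of the mLSTM operation in recurrent form.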
Module Contents
- xlstm_jax.models.xlstm_clean.blocks.mlstm.backend.simple.parallel_stabilized_simple(queries, keys, values, igate_preact, fgate_preact, lower_triangular_matrix=None, stabilize_rowwise=True, eps=1e-06)
This is the mLSTM cell in parallel form.
This version is stabilized: the range of the exp() arguments is controlled by subtracting their maximum, so every argument is at most 0.0 and every exp() value lies in (0, 1].
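To illustrate the stabilization idea, the following minimal JAX sketch uses made-up log-space values (an assumption for illustration, not data from the library). It shows that subtracting the maximum keeps every exp() argument at most 0.0, which avoids float32 overflow while leaving the ratios between terms unchanged.

```python
import jax.numpy as jnp

# Hypothetical log-space values; in float32, exp(x) overflows to inf for x above ~88.7.
log_d = jnp.array([100.0, 99.0, 90.0], dtype=jnp.float32)

naive = jnp.exp(log_d)        # every entry overflows to inf
m = jnp.max(log_d)            # stabilizer: the maximum argument
stable = jnp.exp(log_d - m)   # all arguments are <= 0.0, so values lie in (0, 1]

# Ratios are unchanged by the shift because the common factor exp(-m) cancels.
weights = stable / jnp.sum(stable)
```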
- Parameters:
queries (jax.Array) – (B, NH, S, DH)
keys (jax.Array) – (B, NH, S, DH)
values (jax.Array) – (B, NH, S, DH)
igate_preact (jax.Array) – (B, NH, S, 1)
fgate_preact (jax.Array) – (B, NH, S, 1)
lower_triangular_matrix (jax.Array) – (S, S) lower-triangular (causal) mask. Defaults to None.
stabilize_rowwise (bool) – If True, stabilize the combination matrix C row-wise by subtracting the maximum of each row; if False, subtract the single maximum over all rows. Defaults to True.
eps (float) – Epsilon value. Defaults to 1e-6.
- Returns:
h_tilde_state, shape (B, NH, S, DH)
- Return type:
jax.Array
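The parallel computation can be sketched in JAX as follows. This is a minimal illustration, not the source of `parallel_stabilized_simple`; the name `parallel_mlstm_sketch` is hypothetical. It assumes the standard mLSTM formulation: log-sigmoid forget gates, exponential input gates, keys scaled by 1/sqrt(DH), a causal mask built with `jnp.tril` in place of `lower_triangular_matrix`, and a normalizer max(|row sum of C|, exp(-max)).

```python
import math

import jax
import jax.numpy as jnp


def parallel_mlstm_sketch(q, k, v, igate_preact, fgate_preact,
                          stabilize_rowwise=True, eps=1e-6):
    """Hypothetical sketch of a stabilized parallel mLSTM (not the library source).

    q, k, v: (B, NH, S, DH); igate_preact, fgate_preact: (B, NH, S, 1).
    Returns h_tilde of shape (B, NH, S, DH).
    """
    S, DH = q.shape[-2], q.shape[-1]
    # Assumed gate activations: log-sigmoid forget gate, exponential input gate.
    log_f = jax.nn.log_sigmoid(fgate_preact)                     # (B, NH, S, 1)
    cum = jnp.cumsum(log_f, axis=-2)                              # (B, NH, S, 1)
    # Entry (i, j) = sum of log forget gates over timesteps j+1..i (valid for j <= i).
    log_fg = cum - jnp.swapaxes(cum, -1, -2)                      # (B, NH, S, S)
    causal = jnp.tril(jnp.ones((S, S), dtype=bool))               # stands in for lower_triangular_matrix
    log_fg = jnp.where(causal, log_fg, -jnp.inf)
    # Log gate-decay matrix D: add the input-gate pre-activation of source step j.
    log_d = log_fg + jnp.swapaxes(igate_preact, -1, -2)           # (B, NH, S, S)
    if stabilize_rowwise:
        max_log_d = jnp.max(log_d, axis=-1, keepdims=True)        # (B, NH, S, 1)
    else:
        max_log_d = jnp.max(log_d, axis=(-2, -1), keepdims=True)  # (B, NH, 1, 1)
    d = jnp.exp(log_d - max_log_d)                                # arguments <= 0.0
    # Combination matrix C = (q k^T / sqrt(DH)) * D.
    c = (q @ jnp.swapaxes(k, -1, -2)) / math.sqrt(DH) * d         # (B, NH, S, S)
    # exp(-max_log_d) is the stabilized version of the lower bound 1 on the normalizer.
    normalizer = jnp.maximum(jnp.abs(c.sum(axis=-1, keepdims=True)), jnp.exp(-max_log_d))
    return (c / (normalizer + eps)) @ v                           # (B, NH, S, DH)
```

Because the shift by the maximum is compensated by the exp(-max) term in the normalizer, row-wise and global stabilization give the same result up to floating-point error and the eps term; row-wise stabilization additionally guarantees that each row's largest exp() argument is exactly 0.0.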
- class xlstm_jax.models.xlstm_clean.blocks.mlstm.backend.simple.mLSTMBackendJaxConfig
- assign_model_config_params(model_config)
- class xlstm_jax.models.xlstm_clean.blocks.mlstm.backend.simple.mLSTMBackendJax
Bases: xlstm_jax.models.xlstm_clean.blocks.mlstm.backend.config.mLSTMBackend
- config_class
- xlstm_jax.models.xlstm_clean.blocks.mlstm.backend.simple.recurrent_step_stabilized_simple(c_state, n_state, m_state, q, k, v, igate_preact, fgate_preact, eps=1e-06)
This is a single step of the mLSTM operation in recurrent form.
- Parameters:
c_state (jax.Array) – (B, NH, DH, DH)
n_state (jax.Array) – (B, NH, DH, 1)
m_state (jax.Array) – (B, NH, 1, 1)
q (jax.Array) – (B, NH, 1, DH)
k (jax.Array) – (B, NH, 1, DH)
v (jax.Array) – (B, NH, 1, DH)
igate_preact (jax.Array) – (B, NH, 1, 1)
fgate_preact (jax.Array) – (B, NH, 1, 1)
eps (float) – Epsilon value. Defaults to 1e-6.
- Returns:
(hidden_state [B, NH, DH], (c_state_new [B, NH, DH, DH], n_state_new [B, NH, DH, 1], m_state_new [B, NH, 1, 1]))
- Return type:
tuple[jax.Array, tuple[jax.Array, jax.Array, jax.Array]]
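A single recurrent step can be sketched in JAX as follows. This is a minimal illustration, not the source of `recurrent_step_stabilized_simple`; the function name `recurrent_step_sketch` and the usage loop are hypothetical. It assumes a log-sigmoid forget gate, an exponential input gate, and keys scaled by 1/sqrt(DH). The stabilizer state m tracks the running maximum of the log-space gate contributions, so both exp() calls receive arguments at most 0.0.

```python
import math

import jax
import jax.numpy as jnp


def recurrent_step_sketch(c_state, n_state, m_state, q, k, v,
                          igate_preact, fgate_preact, eps=1e-6):
    """Hypothetical sketch of one stabilized mLSTM step (not the library source).

    Shapes follow the docs above; returns (h [B, NH, DH], (c_new, n_new, m_new)).
    """
    DH = q.shape[-1]
    # (B, NH, 1, DH) -> column vectors (B, NH, DH, 1)
    q, k, v = (jnp.swapaxes(x, -1, -2) for x in (q, k, v))
    log_f = jax.nn.log_sigmoid(fgate_preact)                   # (B, NH, 1, 1)
    # New stabilizer: running maximum of the log-space gate contributions.
    m_new = jnp.maximum(log_f + m_state, igate_preact)         # (B, NH, 1, 1)
    f_act = jnp.exp(log_f + m_state - m_new)                   # exp argument <= 0.0
    i_act = jnp.exp(igate_preact - m_new)                      # exp argument <= 0.0
    k = k / math.sqrt(DH)
    c_new = f_act * c_state + i_act * (k @ jnp.swapaxes(v, -1, -2))  # (B, NH, DH, DH)
    n_new = f_act * n_state + i_act * k                        # (B, NH, DH, 1)
    q_row = jnp.swapaxes(q, -1, -2)                            # (B, NH, 1, DH)
    h_num = q_row @ c_new                                      # (B, NH, 1, DH)
    # exp(-m_new) is the stabilized version of the lower bound 1 on the denominator.
    h_den = jnp.maximum(jnp.abs(q_row @ n_new), jnp.exp(-m_new)) + eps  # (B, NH, 1, 1)
    h = (h_num / h_den)[:, :, 0, :]                            # (B, NH, DH)
    return h, (c_new, n_new, m_new)


# Usage: unroll the step over a short random sequence, starting from zero states.
B, NH, S, DH = 1, 2, 4, 8
rk = jax.random.split(jax.random.PRNGKey(0), 5)
qs, ks, vs = (jax.random.normal(key, (B, NH, S, DH)) for key in rk[:3])
igs, fgs = (jax.random.normal(key, (B, NH, S, 1)) for key in rk[3:])
state = (jnp.zeros((B, NH, DH, DH)), jnp.zeros((B, NH, DH, 1)), jnp.zeros((B, NH, 1, 1)))
hs = []
for t in range(S):
    step = slice(t, t + 1)  # keep the length-1 time axis
    h, state = recurrent_step_sketch(*state, qs[:, :, step], ks[:, :, step],
                                     vs[:, :, step], igs[:, :, step], fgs[:, :, step])
    hs.append(h)
hs = jnp.stack(hs, axis=2)  # (B, NH, S, DH)
```

In this sketch, m rescales c and n by a common factor exp(-m), and the exp(-m_new) term in the denominator compensates for it. Starting from zero c and n states (and with eps set to 0), the returned hidden state therefore does not depend on the initial value of m.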