xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.simple#
Classes#
- mLSTMBackendParallelConfig
- mLSTMBackendParallel
Functions#
- parallel_stabilized_simple: This is the mLSTM cell in parallel form.
Module Contents#
- xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.simple.parallel_stabilized_simple(queries, keys, values, igate_preact, fgate_preact, lower_triangular_matrix=None, stabilize_rowwise=True, eps=1e-06, qkv_dtype=None, gate_dtype=None)#
This is the mLSTM cell in parallel form.
This version is numerically stabilized. The maximum is subtracted from the arguments of exp() before exponentiation, so they never exceed 0.0 and exp() cannot overflow.
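The stabilization described above can be sketched in a few lines of JAX. This is a minimal illustration under stated assumptions, not the library's implementation: the helper name `stabilized_gate_matrix` is hypothetical, and the construction of the log gate matrix (cumulative log-sigmoid forget gates plus input gate pre-activations, masked causally) is assumed from the standard parallel mLSTM formulation.

```python
import jax
import jax.numpy as jnp


def stabilized_gate_matrix(igate_preact, fgate_preact, stabilize_rowwise=True):
    """Hypothetical helper: returns exp(log_D - max) for gates of shape (B, NH, S, 1)."""
    S = igate_preact.shape[-2]
    # Log forget gates and their running sum along the sequence axis.
    log_fgates = jax.nn.log_sigmoid(fgate_preact)             # (B, NH, S, 1)
    csum = jnp.cumsum(log_fgates, axis=-2)                    # (B, NH, S, 1)
    # Entry (i, j) is the sum of log forget gates over steps j+1..i.
    log_fg = csum - jnp.swapaxes(csum, -2, -1)                # (B, NH, S, S)
    # Add the input gate of step j, then mask out the non-causal entries.
    causal = jnp.tril(jnp.ones((S, S), dtype=bool))
    log_d = jnp.where(causal, log_fg + jnp.swapaxes(igate_preact, -2, -1), -jnp.inf)
    if stabilize_rowwise:
        # One maximum per row of the (S, S) matrix.
        max_log_d = jnp.max(log_d, axis=-1, keepdims=True)    # (B, NH, S, 1)
    else:
        # A single maximum over all rows of each (S, S) matrix.
        max_log_d = jnp.max(log_d, axis=(-2, -1), keepdims=True)  # (B, NH, 1, 1)
    # Every exp() argument is now <= 0, so the result lies in [0, 1].
    return jnp.exp(log_d - max_log_d), max_log_d
```

Because the diagonal of the masked matrix is always finite, each row maximum is finite and the masked entries map to exactly zero after exponentiation.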
- Parameters:
queries (jax.Array) – (B, NH, S, DHQK)
keys (jax.Array) – (B, NH, S, DHQK)
values (jax.Array) – (B, NH, S, DHV)
igate_preact (jax.Array) – (B, NH, S, 1)
fgate_preact (jax.Array) – (B, NH, S, 1)
lower_triangular_matrix (jax.Array, optional) – (S,S). Defaults to None.
stabilize_rowwise (bool, optional) – Whether to stabilize the combination matrix C row-wise, i.e. subtract the maximum of each row. If False, the maximum over all rows is subtracted instead. Defaults to True.
eps (float, optional) – Small value to avoid division by zero. Defaults to 1e-6.
qkv_dtype (jnp.dtype, optional) – dtype of queries, keys and values. Defaults to None, which infers the dtype from the inputs.
gate_dtype (jnp.dtype, optional) – dtype of igate_preact and fgate_preact. Defaults to None, which infers the dtype from the inputs.
- Returns:
h_tilde_state of shape (B, NH, S, DHV)
- Return type:
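The documented input shapes can be illustrated with a short JAX sketch. The sizes below are arbitrary, and the boolean dtype of the causal mask is an assumption rather than something this page states. The call itself is left as a comment because it requires xlstm_jax to be installed.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes only; none of these values come from the library.
B, NH, S, DHQK, DHV = 2, 4, 8, 16, 32

rng_q, rng_k, rng_v, rng_i, rng_f = jax.random.split(jax.random.PRNGKey(0), 5)
queries = jax.random.normal(rng_q, (B, NH, S, DHQK))
keys = jax.random.normal(rng_k, (B, NH, S, DHQK))
values = jax.random.normal(rng_v, (B, NH, S, DHV))
igate_preact = jax.random.normal(rng_i, (B, NH, S, 1))
fgate_preact = jax.random.normal(rng_f, (B, NH, S, 1))

# Assumption: the (S, S) mask is boolean, True on and below the diagonal.
lower_triangular_matrix = jnp.tril(jnp.ones((S, S), dtype=bool))

# With xlstm_jax installed, the call would follow the documented signature:
# h_tilde_state = parallel_stabilized_simple(
#     queries, keys, values, igate_preact, fgate_preact,
#     lower_triangular_matrix=lower_triangular_matrix,
# )
```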
- class xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.simple.mLSTMBackendParallelConfig#
- assign_model_config_params(model_config)#
- class xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.simple.mLSTMBackendParallel#
Bases: xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.config.mLSTMBackend
- config_class#