xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.simple

Classes

mLSTMBackendParallelConfig

mLSTMBackendParallel

Functions

parallel_stabilized_simple(queries, keys, values, ...)

This is the mLSTM cell in parallel form.

Module Contents

xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.simple.parallel_stabilized_simple(queries, keys, values, igate_preact, fgate_preact, lower_triangular_matrix=None, stabilize_rowwise=True, eps=1e-06, qkv_dtype=None, gate_dtype=None)

This is the mLSTM cell in parallel form.

This version is stabilized: the range of the exp() arguments is controlled by subtracting their maximum, which guarantees that every argument is at most 0.0, so every exponential lies in [0, 1] and cannot overflow.
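Written out, the stabilized computation can be sketched as follows. This is a hedged reconstruction that assumes the parallel mLSTM formulation of the xLSTM paper; the JAX code may differ in detail. Here ĩ_j and f̃_k are the entries of igate_preact and fgate_preact at positions j and k, d = DHQK, ε = eps, and m_i is the row maximum used when stabilize_rowwise=True.

```latex
\begin{aligned}
\log D_{ij} &=
\begin{cases}
  \tilde{i}_j + \sum_{k=j+1}^{i} \log \sigma(\tilde{f}_k), & j \le i,\\
  -\infty, & j > i,
\end{cases}
\qquad m_i = \max_{j} \log D_{ij},\\
D'_{ij} &= \exp\!\left(\log D_{ij} - m_i\right) \in [0, 1],
\qquad C_{ij} = \frac{\mathbf{q}_i^{\top} \mathbf{k}_j}{\sqrt{d}}\, D'_{ij},\\
n_i &= \max\!\Big(\Big|\textstyle\sum_{j} C_{ij}\Big|,\; e^{-m_i}\Big),
\qquad \tilde{\mathbf{h}}_i = \sum_{j} \frac{C_{ij}}{n_i + \varepsilon}\, \mathbf{v}_j.
\end{aligned}
```

The floor e^{-m_i} in n_i is the stabilized counterpart of the lower bound 1 in the unstabilized normalizer max(|Σ_j C_ij|, 1), where C is computed with exp(log D_ij) instead of D'_ij. The common factor e^{-m_i} cancels, so both versions give the same h̃_i up to ε. With stabilize_rowwise=False, m_i is replaced by the single maximum over all i and j.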

Parameters:
  • queries (jax.Array) – (B, NH, S, DHQK)

  • keys (jax.Array) – (B, NH, S, DHQK)

  • values (jax.Array) – (B, NH, S, DHV)

  • igate_preact (jax.Array) – (B, NH, S, 1)

  • fgate_preact (jax.Array) – (B, NH, S, 1)

  • lower_triangular_matrix (jax.Array, optional) – (S, S) lower-triangular (causal) mask. Defaults to None.

  • stabilize_rowwise (bool, optional) – Whether to stabilize the combination matrix C row-wise, subtracting the maximum of each row. If False, the single maximum over all rows is subtracted instead. Defaults to True.

  • eps (float, optional) – Small value to avoid division by zero. Defaults to 1e-6.

  • qkv_dtype (jnp.dtype, optional) – dtype of queries, keys and values. Defaults to None, which infers the dtype from the inputs.

  • gate_dtype (jnp.dtype, optional) – dtype of igate_preact and fgate_preact. Defaults to None, which infers the dtype from the inputs.
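To illustrate the two stabilize_rowwise options, the sketch below row-normalizes a toy combination matrix after shifting the log values by either a per-row maximum or one global maximum. The helper `normalize`, the sizes, and the inputs are hypothetical. The normalizer max(|Σ_j C_ij|, e^{-m}) is an assumption based on the xLSTM reference formulation and is not taken from this docstring. Positive scores are used only to keep the toy row sums away from zero.

```python
import jax
import jax.numpy as jnp


def normalize(log_D, qk, rowwise):
    """Hypothetical helper: row-normalize C = qk * exp(log_D) after a max-shift."""
    if rowwise:
        m = jnp.max(log_D, axis=-1, keepdims=True)  # one maximum per row, shape (S, 1)
    else:
        m = jnp.max(log_D)  # one maximum over the whole matrix, a scalar
    C = qk * jnp.exp(log_D - m)  # every exponent is <= 0
    n = jnp.maximum(jnp.abs(C.sum(axis=-1, keepdims=True)), jnp.exp(-m))
    return C / n


S = 6
k1, k2 = jax.random.split(jax.random.PRNGKey(1))
mask = jnp.tril(jnp.ones((S, S), dtype=bool))
# Log values around 100: exp() of these overflows float32 without the shift.
log_D = jnp.where(mask, 100.0 + 2.0 * jax.random.normal(k1, (S, S)), -jnp.inf)
# Positive scores keep the row sums away from zero in this toy example.
qk = jax.random.uniform(k2, (S, S), minval=0.5, maxval=1.5)

rowwise = normalize(log_D, qk, rowwise=True)
global_max = normalize(log_D, qk, rowwise=False)
print(bool(jnp.isinf(jnp.exp(log_D)).any()))  # True: the unshifted exp() overflows
print(bool(jnp.allclose(rowwise, global_max, atol=1e-5)))  # True
```

Both options give the same normalized matrix because the shift cancels between C and its normalizer. They differ only numerically: a single global maximum can push rows that lie far below it toward underflow, while the row-wise maximum keeps the largest exponent of every row at exactly 0.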

Returns:

h_tilde_state, of shape (B, NH, S, DHV)

Return type:

jax.Array
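The following is a minimal JAX sketch of the computation this function describes. It is written for illustration under stated assumptions and is not the library's code. It follows the xLSTM reference formulation, ignores qkv_dtype and gate_dtype, treats lower_triangular_matrix as an (S, S) mask convertible to bool, and uses the hypothetical name parallel_stabilized_simple_sketch.

```python
import math

import jax
import jax.numpy as jnp


def parallel_stabilized_simple_sketch(
    queries, keys, values, igate_preact, fgate_preact,
    lower_triangular_matrix=None, stabilize_rowwise=True, eps=1e-6,
):
    """Hypothetical stand-in for parallel_stabilized_simple (no dtype handling)."""
    B, NH, S, DHQK = queries.shape

    # Causal mask: position i may only combine positions j <= i.
    if lower_triangular_matrix is None:
        ltr = jnp.tril(jnp.ones((S, S), dtype=bool))
    else:
        ltr = lower_triangular_matrix.astype(bool)  # assumed to be (S, S)

    # cs[t] = sum_{k<=t} log sigmoid(f_k), so cs[i] - cs[j] = sum_{k=j+1}^{i} log sigmoid(f_k).
    cs = jnp.cumsum(jax.nn.log_sigmoid(fgate_preact), axis=-2)  # (B, NH, S, 1)
    log_fg = jnp.where(ltr, cs - jnp.swapaxes(cs, -2, -1), -jnp.inf)  # (B, NH, S, S)

    # log D[i, j] = log_fg[i, j] + igate_preact[j]
    log_D = log_fg + jnp.swapaxes(igate_preact, -2, -1)

    # Subtract the maximum so that every exp() argument is <= 0.
    if stabilize_rowwise:
        max_log_D = jnp.max(log_D, axis=-1, keepdims=True)  # (B, NH, S, 1)
    else:
        max_log_D = jnp.max(log_D, axis=(-2, -1), keepdims=True)  # (B, NH, 1, 1)
    D = jnp.exp(log_D - max_log_D)

    # Combination matrix C and its row normalizer.
    qk = queries @ jnp.swapaxes(keys, -2, -1) / math.sqrt(DHQK)  # (B, NH, S, S)
    C = qk * D
    normalizer = jnp.maximum(
        jnp.abs(C.sum(axis=-1, keepdims=True)), jnp.exp(-max_log_D)
    )  # (B, NH, S, 1)
    return (C / (normalizer + eps)) @ values  # h_tilde_state: (B, NH, S, DHV)


# Usage with random inputs in the documented layout.
kq, kk, kv, ki, kf = jax.random.split(jax.random.PRNGKey(0), 5)
B, NH, S, DHQK, DHV = 2, 4, 16, 8, 12
q = jax.random.normal(kq, (B, NH, S, DHQK))
k = jax.random.normal(kk, (B, NH, S, DHQK))
v = jax.random.normal(kv, (B, NH, S, DHV))
ig = jax.random.normal(ki, (B, NH, S, 1))
fg = jax.random.normal(kf, (B, NH, S, 1))

h = parallel_stabilized_simple_sketch(q, k, v, ig, fg)
print(h.shape)  # (2, 4, 16, 12)
```

Because the sketch materializes full (S, S) matrices per batch element and head, its memory use grows quadratically with sequence length. It is meant for reading, not for long contexts.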

class xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.simple.mLSTMBackendParallelConfig

context_length: int = -1

assign_model_config_params(model_config)

class xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.simple.mLSTMBackendParallel

Bases: xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.config.mLSTMBackend

config_class
property can_vmap_over_heads: bool

Whether the backend can be vmapped over the heads dimension.

The backend is written independently of the heads dimension and can therefore be vmapped over it.

Returns:

True

Return type:

bool
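In JAX terms, being able to vmap over heads means that a function written for a single head, with inputs of shape (B, S, DH), can be lifted to inputs of shape (B, NH, S, DH) with jax.vmap. The sketch below shows this general pattern with a hypothetical per-head function. It does not call mLSTMBackendParallel, and how the library applies vmap internally is not shown here.

```python
import jax
import jax.numpy as jnp


def single_head_causal_scores(q, k):
    """Hypothetical per-head function: q and k are (B, S, DH), result is (B, S, S)."""
    S = q.shape[1]
    scores = q @ jnp.swapaxes(k, -2, -1) / q.shape[-1] ** 0.5
    mask = jnp.tril(jnp.ones((S, S), dtype=bool))
    return jnp.where(mask, scores, 0.0)


# Map over axis 1 (NH) of (B, NH, S, DH) inputs and put NH back at axis 1 of the output.
all_heads = jax.vmap(single_head_causal_scores, in_axes=1, out_axes=1)

kq, kk = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(kq, (2, 4, 8, 16))  # (B, NH, S, DH)
k = jax.random.normal(kk, (2, 4, 8, 16))
scores = all_heads(q, k)
print(scores.shape)  # (2, 4, 8, 8)
```

Mapping with in_axes=1 and out_axes=1 keeps the (B, NH, ...) layout on both sides, and the result matches a Python loop that applies the per-head function to each head separately.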