xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend_utils
Functions
run_backend | Execute the mLSTM backend for the given input tensors.
Module Contents
- xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend_utils.run_backend(parent, cell_config, q, k, v, igate_preact, fgate_preact)
Execute the mLSTM backend for the given input tensors.
This function handles the caching of intermediate states, if enabled, and the vmap over the heads dimension. Caching follows the cache setup of the Attention module in google/flax. During decoding, if the cache is not yet initialized, it is created with zeros and is not updated in that call, as in the Attention module. If a cache is already present, it is updated with the new states.
- Parameters:
parent (flax.linen.Module) – The parent module.
cell_config (Any) – The mLSTM cell configuration.
q (jax.Array) – The query tensor, shape (batch_size, seq_len, num_heads, qk_dim).
k (jax.Array) – The key tensor, shape (batch_size, seq_len, num_heads, qk_dim).
v (jax.Array) – The value tensor, shape (batch_size, seq_len, num_heads, v_dim).
igate_preact (jax.Array) – The input gate preactivation, shape (batch_size, seq_len, num_heads, 1).
fgate_preact (jax.Array) – The forget gate preactivation, shape (batch_size, seq_len, num_heads, 1).
- Returns:
The output tensor, shape (batch_size, seq_len, num_heads, v_dim).
- Return type: