xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.attention

Classes

mLSTMBackendAttentionConfig

mLSTMBackendAttention

Functions

precompute_freqs(feat_dim, max_length[, theta])

Compute the sine and cosine frequencies for the rotary embeddings.

apply_rotary_emb(xq, xk[, freqs_sin, freqs_cos, theta])

Apply the rotary embeddings to the queries and keys.

attention(queries, keys, values[, attention_mask, ...])

Attention backend that mimics the attention mechanism of a standard Transformer.

Module Contents

xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.attention.precompute_freqs(feat_dim, max_length, theta=10000.0)

Compute the sine and cosine frequencies for the rotary embeddings.

Parameters:
  • feat_dim (int) – Feature dimension of the input.

  • max_length (int) – Maximum length of the input sequence.

  • theta (float) – Base (theta) from which the wavelengths of the rotary frequencies are computed.

Returns:

Tuple of the sine and cosine frequencies.
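A minimal sketch of how such frequencies are commonly computed for rotary embeddings (RoPE), assuming the standard formulation with one inverse wavelength theta^(-2i/feat_dim) per pair of feature channels. The name precompute_freqs_sketch is hypothetical, and the output shape (max_length, feat_dim // 2) is an assumption; the exact shapes returned by precompute_freqs are not documented here.

```python
import jax.numpy as jnp


def precompute_freqs_sketch(feat_dim: int, max_length: int, theta: float = 10000.0):
    """Hypothetical re-implementation assuming the standard RoPE formulation.

    Returns (sin, cos), each of shape (max_length, feat_dim // 2).
    """
    # One inverse wavelength per pair of feature channels: theta^(-2i / feat_dim).
    inv_freq = 1.0 / (theta ** (jnp.arange(0, feat_dim, 2, dtype=jnp.float32) / feat_dim))
    # Rotation angle for every (position, channel pair) combination.
    positions = jnp.arange(max_length, dtype=jnp.float32)
    angles = jnp.outer(positions, inv_freq)
    return jnp.sin(angles), jnp.cos(angles)


freqs_sin, freqs_cos = precompute_freqs_sketch(feat_dim=8, max_length=16)
print(freqs_sin.shape, freqs_cos.shape)  # (16, 4) (16, 4)
```

Larger theta values stretch the wavelengths of the low-frequency channel pairs, which is why theta is exposed as a parameter.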

xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.attention.apply_rotary_emb(xq, xk, freqs_sin=None, freqs_cos=None, theta=10000.0)

Apply the rotary embeddings to the queries and keys.

Parameters:
  • xq (jax.Array) – Array containing the query features of shape (B, NH, S, DHQK).

  • xk (jax.Array) – Array containing the key features of shape (B, NH, S, DHQK).

  • freqs_sin (jax.Array | None) – Sine frequencies for the rotary embeddings. If None, they are computed from the shape of xq.

  • freqs_cos (jax.Array | None) – Cosine frequencies for the rotary embeddings. If None, they are computed from the shape of xq.

  • theta (float) – Theta parameter for calculating the frequencies.

Returns:

Tuple of the query and key features with the rotary embeddings applied.

Return type:

tuple[jax.Array, jax.Array]
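A minimal sketch of applying rotary embeddings to arrays of shape (B, NH, S, DHQK). It assumes the common convention of rotating adjacent channel pairs (x[2i], x[2i+1]); whether apply_rotary_emb pairs channels this way or uses the split-halves convention is not stated here. The helpers rope_freqs, rotate and apply_rotary_emb_sketch are hypothetical.

```python
import jax
import jax.numpy as jnp


def rope_freqs(feat_dim: int, max_length: int, theta: float = 10000.0):
    # Assumed standard RoPE frequencies, each of shape (S, D // 2).
    inv_freq = 1.0 / (theta ** (jnp.arange(0, feat_dim, 2, dtype=jnp.float32) / feat_dim))
    angles = jnp.outer(jnp.arange(max_length, dtype=jnp.float32), inv_freq)
    return jnp.sin(angles), jnp.cos(angles)


def rotate(x: jax.Array, freqs_sin: jax.Array, freqs_cos: jax.Array) -> jax.Array:
    """Rotate adjacent channel pairs of x with shape (B, NH, S, D) (assumed pairing)."""
    x_pairs = x.reshape(x.shape[:-1] + (-1, 2))
    x_even, x_odd = x_pairs[..., 0], x_pairs[..., 1]
    # 2D rotation of each pair by its position-dependent angle; (S, D // 2) broadcasts.
    out_even = x_even * freqs_cos - x_odd * freqs_sin
    out_odd = x_even * freqs_sin + x_odd * freqs_cos
    return jnp.stack([out_even, out_odd], axis=-1).reshape(x.shape)


def apply_rotary_emb_sketch(xq, xk, freqs_sin=None, freqs_cos=None, theta=10000.0):
    # Mirror the documented fallback: derive the frequencies from the shape of xq.
    if freqs_sin is None or freqs_cos is None:
        freqs_sin, freqs_cos = rope_freqs(xq.shape[-1], xq.shape[-2], theta)
    return rotate(xq, freqs_sin, freqs_cos), rotate(xk, freqs_sin, freqs_cos)


kq, kk = jax.random.split(jax.random.PRNGKey(0))
xq = jax.random.normal(kq, (2, 4, 16, 8))  # (B, NH, S, DHQK)
xk = jax.random.normal(kk, (2, 4, 16, 8))
q_rot, k_rot = apply_rotary_emb_sketch(xq, xk)
```

Because every pair is rotated by an angle proportional to its position, the dot product between a rotated query at position m and a rotated key at position n depends only on the offset m − n, not on the absolute positions.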

xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.attention.attention(queries, keys, values, attention_mask=None, qkv_dtype=None, activation_function='softmax', qk_pre_activation_function='none', theta=10000.0)

Attention backend that mimics the attention mechanism of a standard Transformer.

Note that no forget or input gates are applied here.

Parameters:
  • queries (jax.Array) – Array containing the query features of shape (B, NH, S, DHQK).

  • keys (jax.Array) – Array containing the key features of shape (B, NH, S, DHQK).

  • values (jax.Array) – Array containing the value features of shape (B, NH, S, DHV).

  • attention_mask (jax.Array | None) – Boolean array of shape (S, S) denoting the attention mask, where False marks masked positions. If None, a causal (lower-triangular) mask is used.

  • qkv_dtype (jax.numpy.dtype | None) – Dtype of the queries, keys and values. If None, uses the dtype of queries.

  • activation_function (Literal['softmax', 'sigmoid', 'none']) – Activation function to apply to the attention logits. Softmax is normalized over the key sequence, as in standard transformers. Sigmoid is applied with a bias of -log(S).

  • qk_pre_activation_function (Literal['silu', 'swish', 'none']) – Activation function to apply to the queries and keys before computing the attention logits. SiLU and swish denote the same function, x * sigmoid(x).

  • theta (float) – Theta parameter for the rotary embeddings.

Returns:

The output features of the attention of shape (B, NH, S, DHV).

Return type:

jax.Array
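A minimal sketch of gate-free attention over inputs of shape (B, NH, S, DH), covering the three documented activation functions and the default causal mask. Assumptions not stated in the source: logits are scaled by 1/sqrt(DHQK), masked entries are zeroed after the sigmoid and 'none' activations, and rotary embeddings, qkv_dtype casting and the query/key pre-activation are omitted. attention_sketch is a hypothetical name, not the library function.

```python
import math

import jax
import jax.numpy as jnp


def attention_sketch(queries, keys, values, attention_mask=None, activation_function="softmax"):
    """Hypothetical gate-free attention; inputs of shape (B, NH, S, DH)."""
    S, dhqk = queries.shape[-2], queries.shape[-1]
    if attention_mask is None:
        # Causal mask: True (attend) on and below the diagonal, False above it.
        attention_mask = jnp.tril(jnp.ones((S, S), dtype=bool))
    # Assumption: 1/sqrt(DHQK) scaling as in standard transformers.
    logits = jnp.einsum("bhqd,bhkd->bhqk", queries, keys) / math.sqrt(dhqk)
    if activation_function == "softmax":
        # Masked logits become -inf, so they get zero weight after normalization over keys.
        weights = jax.nn.softmax(jnp.where(attention_mask, logits, -jnp.inf), axis=-1)
    elif activation_function == "sigmoid":
        # Element-wise sigmoid with a bias of -log(S); masked entries are zeroed.
        weights = jnp.where(attention_mask, jax.nn.sigmoid(logits - math.log(S)), 0.0)
    elif activation_function == "none":
        # Raw logits used as weights; masked entries are zeroed.
        weights = jnp.where(attention_mask, logits, 0.0)
    else:
        raise ValueError(f"Unknown activation function: {activation_function}")
    return jnp.einsum("bhqk,bhkd->bhqd", weights, values)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 4, 16, 8))   # (B, NH, S, DHQK)
k = jax.random.normal(kk, (2, 4, 16, 8))
v = jax.random.normal(kv, (2, 4, 16, 32))  # (B, NH, S, DHV); DHV may differ from DHQK
out = attention_sketch(q, k, v)
print(out.shape)  # (2, 4, 16, 32)
```

Under the causal softmax, the first position can only attend to itself, so its output equals the first value vector. This is a quick sanity check for any causal implementation.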

class xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.attention.mLSTMBackendAttentionConfig
context_length: int = -1
activation_function: Literal['softmax', 'sigmoid', 'none'] = 'softmax'

Activation function to apply to the attention logits. Softmax is normalized over the key sequence, as in standard transformers. Sigmoid is applied with a bias of -log(context_length).
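The source does not state why the sigmoid bias is -log(context_length). One way to read it, offered as an interpretation rather than the authors' stated rationale: with logit $\ell_{ij}$ for query $i$ and key $j$, and $S$ = context_length, the biased sigmoid behaves like a softmax whose normalizer has been fixed to $S$.

```latex
w_{ij} = \sigma\left(\ell_{ij} - \log S\right)
       = \frac{1}{1 + S\, e^{-\ell_{ij}}}
       = \frac{e^{\ell_{ij}}}{e^{\ell_{ij}} + S}

% Logits near zero, all S keys visible in the row:
\ell_{ij} \approx 0 \;\Rightarrow\; w_{ij} \approx \frac{1}{1 + S},
\qquad \sum_{j=1}^{S} w_{ij} \approx \frac{S}{1 + S} \approx 1

% Moderate logits relative to the context length:
S \gg e^{\ell_{ij}} \;\Rightarrow\; w_{ij} \approx \frac{e^{\ell_{ij}}}{S}
```

In this reading, the bias keeps the total attention weight per row on the order of one, as softmax would, without requiring a normalization over the keys.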

qk_pre_activation_function: Literal['swish', 'none'] = 'none'

Activation function to apply on the queries and keys before computing the attention logits.

theta: float = 10000.0

Theta parameter for the rotary embeddings.

assign_model_config_params(model_config)

class xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.attention.mLSTMBackendAttention

Bases: xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.config.mLSTMBackend

config_class
property can_vmap_over_heads: bool

Whether the backend can be vmapped over the heads dimension.

The backend is written independently of the heads dimension and can therefore be vmapped over it.

Returns:

True

Return type:

bool
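A minimal sketch of what a head-independent computation permits: a function written for a single head can be lifted to all heads with jax.vmap over axis 1 of (B, NH, S, DH) inputs. How mLSTMBackend or the calling block actually uses can_vmap_over_heads is not documented here; single_head_attention is a hypothetical helper, and the 1/sqrt(DHQK) scaling is an assumption.

```python
import math

import jax
import jax.numpy as jnp


def single_head_attention(q, k, v):
    """Hypothetical causal softmax attention for ONE head: q, k (B, S, DHQK), v (B, S, DHV)."""
    S = q.shape[-2]
    logits = jnp.einsum("bqd,bkd->bqk", q, k) / math.sqrt(q.shape[-1])
    mask = jnp.tril(jnp.ones((S, S), dtype=bool))
    weights = jax.nn.softmax(jnp.where(mask, logits, -jnp.inf), axis=-1)
    return jnp.einsum("bqk,bkd->bqd", weights, v)


# Map over axis 1 (NH) of every input and place the heads back at axis 1 of the output.
multi_head_attention = jax.vmap(single_head_attention, in_axes=1, out_axes=1)

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 4, 16, 8))   # (B, NH, S, DHQK)
k = jax.random.normal(kk, (2, 4, 16, 8))
v = jax.random.normal(kv, (2, 4, 16, 32))  # (B, NH, S, DHV)
out = multi_head_attention(q, k, v)
print(out.shape)  # (2, 4, 16, 32)
```

The vmapped result matches running the single-head function in a Python loop over heads and stacking the outputs. This equivalence holds only because no computation mixes information across heads.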