xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.attention
Classes
mLSTMBackendAttentionConfig
mLSTMBackendAttention
Functions
precompute_freqs | Compute the sine and cosine frequencies for the rotary embeddings.
apply_rotary_emb | Apply the rotary embeddings to the queries and keys.
attention | This is an attention backend that mimics the attention mechanism of the transformer.
Module Contents
- xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.attention.precompute_freqs(feat_dim, max_length, theta=10000.0)
Compute the sine and cosine frequencies for the rotary embeddings.
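A minimal sketch of how rotary frequencies are commonly computed, for orientation only. The helper name sketch_precompute_freqs, the output shape (max_length, feat_dim // 2), and the (sin, cos) return order are assumptions for illustration; the actual precompute_freqs may differ in any of these.

```python
import jax.numpy as jnp


def sketch_precompute_freqs(feat_dim, max_length, theta=10000.0):
    # Hypothetical helper, not the library function.
    # One inverse frequency per feature pair: theta ** (-2i / feat_dim).
    inv_freq = 1.0 / (theta ** (jnp.arange(0, feat_dim, 2, dtype=jnp.float32) / feat_dim))
    positions = jnp.arange(max_length, dtype=jnp.float32)
    # Rotation angle for every (position, feature pair) combination.
    angles = jnp.outer(positions, inv_freq)  # (max_length, feat_dim // 2)
    return jnp.sin(angles), jnp.cos(angles)


freqs_sin, freqs_cos = sketch_precompute_freqs(feat_dim=8, max_length=16)
```

Under these assumptions, position 0 has angle zero for every pair, so its sine row is all zeros and its cosine row is all ones.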
- xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.attention.apply_rotary_emb(xq, xk, freqs_sin=None, freqs_cos=None, theta=10000.0)
Apply the rotary embeddings to the queries and keys.
- Parameters:
xq (jax.Array) – Array containing the query features of shape (B, NH, S, DHQK).
xk (jax.Array) – Array containing the key features of shape (B, NH, S, DHQK).
freqs_sin (jax.Array | None) – Sine frequencies for the rotary embeddings. If None, computes them based on the shape of xq.
freqs_cos (jax.Array | None) – Cosine frequencies for the rotary embeddings. If None, computes them based on the shape of xq.
theta (float) – Theta parameter for calculating the frequencies.
- Returns:
Tuple of the query and key features with the rotary embeddings applied.
- Return type:
tuple[jax.Array, jax.Array]
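The rotation described above can be sketched as follows. This is an illustration under stated assumptions, not the library's code: sketch_apply_rotary is a hypothetical helper, the sine and cosine tables are assumed to have shape (S, DHQK // 2), and consecutive (even, odd) features are assumed to form a rotation pair. Other implementations pair the first and second halves of the feature dimension instead.

```python
import jax
import jax.numpy as jnp


def sketch_apply_rotary(x, freqs_sin, freqs_cos):
    # Hypothetical helper: rotate each (even, odd) feature pair of x,
    # shaped (B, NH, S, D), by the angle for its position and pair index.
    x_even, x_odd = x[..., 0::2], x[..., 1::2]  # each (B, NH, S, D // 2)
    rot_even = x_even * freqs_cos - x_odd * freqs_sin
    rot_odd = x_even * freqs_sin + x_odd * freqs_cos
    # Re-interleave the rotated pairs back into the original layout.
    return jnp.stack([rot_even, rot_odd], axis=-1).reshape(x.shape)


B, NH, S, DHQK = 2, 3, 5, 8
theta = 10000.0
inv_freq = 1.0 / (theta ** (jnp.arange(0, DHQK, 2, dtype=jnp.float32) / DHQK))
angles = jnp.outer(jnp.arange(S, dtype=jnp.float32), inv_freq)  # (S, DHQK // 2)
freqs_sin, freqs_cos = jnp.sin(angles), jnp.cos(angles)

key_q, key_k = jax.random.split(jax.random.PRNGKey(0))
xq = jax.random.normal(key_q, (B, NH, S, DHQK))
xk = jax.random.normal(key_k, (B, NH, S, DHQK))
xq_rot = sketch_apply_rotary(xq, freqs_sin, freqs_cos)
xk_rot = sketch_apply_rotary(xk, freqs_sin, freqs_cos)
```

Because each pair undergoes a pure rotation, the per-position feature norm is preserved, and position 0 (angle zero) is left unchanged.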
- xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.attention.attention(queries, keys, values, attention_mask=None, qkv_dtype=None, activation_function='softmax', qk_pre_activation_function='none', theta=10000.0)
This is an attention backend that mimics the attention mechanism of the transformer.
Note that no forget and input gates are applied here.
- Parameters:
queries (jax.Array) – Array containing the query features of shape (B, NH, S, DHQK).
keys (jax.Array) – Array containing the key features of shape (B, NH, S, DHQK).
values (jax.Array) – Array containing the value features of shape (B, NH, S, DHV).
attention_mask (jax.Array | None) – Boolean array of shape (S, S) denoting the attention mask, where False marks masked positions. If None, a causal (lower-triangular) mask is used.
qkv_dtype (jax.numpy.dtype | None) – Dtype of the queries, keys and values. If None, uses the dtype of queries.
activation_function (Literal['softmax', 'sigmoid', 'none']) – Activation function to apply on the attention logits. Softmax is applied over the key sequence, as in standard transformers. Sigmoid is applied to the logits with a bias of -log(S), where S is the sequence length.
qk_pre_activation_function (Literal['silu', 'swish', 'none']) – Activation function to apply on the queries and keys before computing the attention logits.
theta (float) – Theta parameter for the rotary embeddings.
- Returns:
The output features of the attention of shape (B, NH, S, DHV).
- Return type:
jax.Array
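The parameter descriptions above can be illustrated with a small standalone sketch. It is not the library's implementation: sketch_attention is a hypothetical function, the 1/sqrt(DHQK) scaling is an assumption, masked positions under 'sigmoid' and 'none' are assumed to be zeroed, and the rotary embeddings, qkv_dtype cast, and query/key pre-activation are omitted.

```python
import jax
import jax.numpy as jnp


def sketch_attention(queries, keys, values, attention_mask=None, activation_function="softmax"):
    # Hypothetical sketch; queries/keys are (B, NH, S, DHQK), values are (B, NH, S, DHV).
    S, dhqk = queries.shape[-2], queries.shape[-1]
    if attention_mask is None:
        # Default: causal mask, True on and below the diagonal.
        attention_mask = jnp.tril(jnp.ones((S, S), dtype=bool))
    # Assumed scaling by 1/sqrt(DHQK), as in standard transformers.
    logits = jnp.einsum("bhqd,bhkd->bhqk", queries, keys) / jnp.sqrt(dhqk)
    if activation_function == "softmax":
        # Masked logits get the most negative finite value, so their weight is 0.
        logits = jnp.where(attention_mask, logits, jnp.finfo(logits.dtype).min)
        weights = jax.nn.softmax(logits, axis=-1)  # normalize over the key sequence
    elif activation_function == "sigmoid":
        # Sigmoid with a bias of -log(S); masked positions are zeroed (assumption).
        weights = jnp.where(attention_mask, jax.nn.sigmoid(logits - jnp.log(S)), 0.0)
    elif activation_function == "none":
        weights = jnp.where(attention_mask, logits, 0.0)
    else:
        raise ValueError(f"Unknown activation function: {activation_function}")
    return jnp.einsum("bhqk,bhkd->bhqd", weights, values)


B, NH, S, DHQK, DHV = 2, 3, 5, 8, 7
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
queries = jax.random.normal(kq, (B, NH, S, DHQK))
keys = jax.random.normal(kk, (B, NH, S, DHQK))
values = jax.random.normal(kv, (B, NH, S, DHV))
out_softmax = sketch_attention(queries, keys, values)
out_sigmoid = sketch_attention(queries, keys, values, activation_function="sigmoid")
```

With the default causal mask and softmax, the first position can only attend to itself, so its output equals its own value vector.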
- class xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.attention.mLSTMBackendAttentionConfig
- activation_function: Literal['softmax', 'sigmoid', 'none'] = 'softmax'
Activation function to apply on the attention logits. Softmax is applied over the key sequence, as in standard transformers. Sigmoid is applied to the logits with a bias of -log(context_length).
- qk_pre_activation_function: Literal['swish', 'none'] = 'none'
Activation function to apply on the queries and keys before computing the attention logits.
- assign_model_config_params(model_config)
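One way to read the -log(context_length) bias used with the sigmoid activation: at zero logits, each weight becomes sigmoid(-log(S)) = 1 / (S + 1), so the weights over S keys sum to roughly one, similar in scale to softmax. The snippet below checks this arithmetic numerically; the value of context_length is arbitrary, and this interpretation is an assumption rather than something the source states.

```python
import jax
import jax.numpy as jnp

context_length = 1024  # illustrative value only
zero_logits = jnp.zeros((context_length,))
# Sigmoid with the bias of -log(context_length) applied to the logits.
weights = jax.nn.sigmoid(zero_logits - jnp.log(context_length))
# Each weight equals 1 / (context_length + 1); the total is just below 1.
total_weight = weights.sum()
```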
- class xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.attention.mLSTMBackendAttention
Bases: xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.config.mLSTMBackend
- config_class