xlstm_jax.models.llama.attention

Attributes

AttentionBackend

Classes

SelfAttentionConfig

Configuration for the self attention module.

SelfAttention

Self attention module with support for rotary embeddings.

Functions

precompute_freqs(feat_dim[, pos_idx, max_length, ...])

Compute the sine and cosine frequencies for the rotary embeddings.

apply_rotary_emb(xq, xk[, freqs_sin, freqs_cos, theta])

Apply the rotary embeddings to the queries and keys.

segment_mask(segment_ids)

Create a mask for the self attention module based on the segment IDs.

multihead_attention(q, k, v[, segment_ids, causal, ...])

Compute multi-head self attention.

Module Contents

xlstm_jax.models.llama.attention.AttentionBackend

xlstm_jax.models.llama.attention.precompute_freqs(feat_dim, pos_idx=None, max_length=None, theta=10000.0, dtype=jnp.float32)

Compute the sine and cosine frequencies for the rotary embeddings.

Parameters:
  • feat_dim (int) – Feature dimension of the input.

  • pos_idx (jax.Array | None) – Positional indices of the tokens in the input sequence. If None, uses an arange up to max_length.

  • max_length (int | None) – Maximum length of the input sequence. Only used if pos_idx is None.

  • theta (float) – Theta parameter for the wavelength calculation.

  • dtype (jax.numpy.dtype) – Data type of the returned frequencies.

Returns:

Tuple of the sine and cosine frequencies, each of shape (B, S, feat_dim // 2). If pos_idx is None, the batch dimension is 1, i.e. the shape is (1, S, feat_dim // 2).

Return type:

tuple[jax.Array, jax.Array]
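For reference, the standard rotary-embedding (RoPE) frequencies can be sketched as below. This is a hypothetical re-implementation (`precompute_freqs_sketch` is not part of xlstm_jax), assuming the common inverse-wavelength schedule theta^(-2i / feat_dim) and an even feat_dim; the library may differ in details such as precision handling.

```python
import jax.numpy as jnp


def precompute_freqs_sketch(feat_dim, pos_idx=None, max_length=None, theta=10000.0, dtype=jnp.float32):
    """Hypothetical RoPE frequency computation (assumed schedule, not the library code)."""
    if pos_idx is None:
        if max_length is None:
            raise ValueError("max_length is required when pos_idx is None.")
        # Positions 0..S-1 with a leading batch axis of size 1.
        pos_idx = jnp.arange(max_length)[None, :]
    # One inverse wavelength per pair of feature channels: theta^(-2i / feat_dim).
    inv_freq = 1.0 / (theta ** (jnp.arange(0, feat_dim, 2, dtype=jnp.float32) / feat_dim))
    # Rotation angle for every (batch, position, channel pair): shape (B, S, feat_dim // 2).
    angles = pos_idx[..., None].astype(jnp.float32) * inv_freq
    return jnp.sin(angles).astype(dtype), jnp.cos(angles).astype(dtype)


sin, cos = precompute_freqs_sketch(feat_dim=64, max_length=16)
print(sin.shape, cos.shape)  # (1, 16, 32) (1, 16, 32)
```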

xlstm_jax.models.llama.attention.apply_rotary_emb(xq, xk, freqs_sin=None, freqs_cos=None, theta=10000.0)

Apply the rotary embeddings to the queries and keys.

Parameters:
  • xq (jax.Array) – Array containing the query features of shape (B, S, NH, DHQK).

  • xk (jax.Array) – Array containing the key features of shape (B, S, NH, DHQK).

  • freqs_sin (jax.Array | None) – Sine frequencies for the rotary embeddings. If None, computes them based on the shape of xq.

  • freqs_cos (jax.Array | None) – Cosine frequencies for the rotary embeddings. If None, computes them based on the shape of xq.

  • theta (float) – Theta parameter for calculating the frequencies.

Returns:

Tuple of the query and key features with the rotary embeddings applied.

Return type:

tuple[jax.Array, jax.Array]
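A minimal sketch of applying rotary embeddings to queries and keys of shape (B, S, NH, DHQK). It assumes the "half-split" channel pairing, where channel i is rotated together with channel i + DHQK // 2; the library may instead pair adjacent (interleaved) channels. The names `rotate_sketch` and `apply_rotary_emb_sketch` are hypothetical.

```python
import jax
import jax.numpy as jnp


def rotate_sketch(x, freqs_sin, freqs_cos):
    """Rotate channel pairs of x (B, S, NH, D) by per-position angles given as (B, S, D // 2).

    Assumes half-split pairing: channel i is paired with channel i + D // 2.
    """
    x1, x2 = jnp.split(x, 2, axis=-1)
    # Insert a head axis so (B, S, D // 2) broadcasts against (B, S, NH, D // 2).
    sin = freqs_sin[:, :, None, :]
    cos = freqs_cos[:, :, None, :]
    # 2D rotation of each (x1, x2) pair.
    out1 = x1 * cos - x2 * sin
    out2 = x1 * sin + x2 * cos
    return jnp.concatenate([out1, out2], axis=-1).astype(x.dtype)


def apply_rotary_emb_sketch(xq, xk, theta=10000.0):
    _, seq_len, _, head_dim = xq.shape
    pos = jnp.arange(seq_len)[None, :]
    inv_freq = 1.0 / (theta ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim))
    angles = pos[..., None] * inv_freq  # (1, S, D // 2)
    freqs_sin, freqs_cos = jnp.sin(angles), jnp.cos(angles)
    return rotate_sketch(xq, freqs_sin, freqs_cos), rotate_sketch(xk, freqs_sin, freqs_cos)


q = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 4, 16))
k = jax.random.normal(jax.random.PRNGKey(1), (2, 8, 4, 16))
q_rot, k_rot = apply_rotary_emb_sketch(q, k)
```

Because each channel pair is rotated by an angle proportional to its position, the rotated query-key dot product depends only on the offset between the two positions, which is the property rotary embeddings are designed for.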

xlstm_jax.models.llama.attention.segment_mask(segment_ids)

Create a mask for the self attention module based on the segment IDs.

Parameters:

segment_ids (jax.Array) – Segment IDs for the input tensor, shape (B, S). Attention weights between elements of different segments are set to zero.

Returns:

Boolean tensor of shape (B, 1, S, S).

Return type:

jax.Array
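Such a mask can be expressed as a broadcasted equality between the segment IDs of queries and keys. A minimal sketch, assuming True marks allowed query-key pairs and that no segment ID carries a special meaning such as padding; `segment_mask_sketch` is hypothetical.

```python
import jax.numpy as jnp


def segment_mask_sketch(segment_ids):
    """Hypothetical version of segment_mask: True where query and key share a segment."""
    # (B, S, 1) == (B, 1, S) -> (B, S, S), then add a singleton head axis -> (B, 1, S, S).
    mask = segment_ids[:, :, None] == segment_ids[:, None, :]
    return mask[:, None, :, :]


segment_ids = jnp.array([[0, 0, 1, 1, 1]])
mask = segment_mask_sketch(segment_ids)
print(mask.shape)  # (1, 1, 5, 5)
```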

class xlstm_jax.models.llama.attention.SelfAttentionConfig

Bases: xlstm_jax.models.configs.SubModelConfig

Configuration for the self attention module.

head_dim: int = 64

Dimension of each attention head. The number of heads is inferred from the head and embedding dimensions.
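As an illustration of how the number of heads follows from head_dim, assuming the embedding dimension must be an exact multiple of it (the helper `infer_num_heads` and its `embed_dim` argument are hypothetical and not part of the documented config):

```python
def infer_num_heads(embed_dim: int, head_dim: int = 64) -> int:
    """Hypothetical helper: number of heads implied by the embedding and head dimensions."""
    if embed_dim % head_dim != 0:
        raise ValueError(f"embed_dim={embed_dim} is not divisible by head_dim={head_dim}.")
    return embed_dim // head_dim


print(infer_num_heads(768))  # 12
```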

qk_norm: bool = False

Whether to apply RMSNorm to the query and key tensors.

use_bias: bool = False

Whether to use bias in the linear layers of the self attention module.

dropout_rate: float = 0.0

Dropout rate for the self attention module. Only applied during training.

num_layers: int = 12

Number of layers in the Llama model. Used for initialization.

dtype: str = 'float32'

Data type of the activations in the network.

attention_backend: AttentionBackend = 'xla'

Which backend to use for the attention module. If pallas_triton or cudnn, the respective Flash Attention kernels are used. cudnn is only supported on GPU, pallas_triton on both CPU and GPU, and xla on all backends.
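The support matrix described for attention_backend can be encoded as a small validation helper. This is a hypothetical sketch (`SUPPORTED_PLATFORMS` and `check_backend` are not library names), assuming "all backends" means CPU, GPU, and TPU, and using `jax.default_backend()` to detect the current platform.

```python
from typing import Optional

import jax

# Support matrix transcribed from the attention_backend description (assumed platform set for xla).
SUPPORTED_PLATFORMS = {
    "xla": {"cpu", "gpu", "tpu"},
    "cudnn": {"gpu"},
    "pallas_triton": {"cpu", "gpu"},
}


def check_backend(backend: str, platform: Optional[str] = None) -> None:
    """Hypothetical helper: raise ValueError if `backend` is not supported on `platform`."""
    if backend not in SUPPORTED_PLATFORMS:
        raise ValueError(f"Unknown attention backend: {backend!r}")
    # Default to the platform JAX is currently running on, e.g. "cpu", "gpu", or "tpu".
    platform = platform or jax.default_backend()
    # Treat vendor-specific GPU platform names as "gpu".
    if platform in ("cuda", "rocm"):
        platform = "gpu"
    if platform not in SUPPORTED_PLATFORMS[backend]:
        raise ValueError(f"Backend {backend!r} is not supported on platform {platform!r}.")


check_backend("xla")  # xla is accepted on every platform listed above
```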

causal: bool = True

Whether to use causal attention masking for the self attention module.

parallel: xlstm_jax.models.configs.ParallelConfig

Parallel configuration.

property _dtype: jax.numpy.dtype

Returns the jnp dtype corresponding to the dtype string stored in the config.

Returns:

The jnp dtype corresponding to the string value.

Return type:

jax.numpy.dtype

to_dict()

Converts the config to a dictionary.

Helpful for saving to disk or logging.

Return type:

dict
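To illustrate how the string dtype field, the _dtype property, and to_dict() fit together, here is a hypothetical dataclass mirroring the documented defaults. It assumes the config behaves like a standard Python dataclass and omits the `parallel` field, whose constructor is not documented here; `SelfAttentionConfigSketch` is not the library class.

```python
import dataclasses
import json

import jax.numpy as jnp


@dataclasses.dataclass
class SelfAttentionConfigSketch:
    """Hypothetical stand-in mirroring the documented fields (the real class also has `parallel`)."""

    head_dim: int = 64
    qk_norm: bool = False
    use_bias: bool = False
    dropout_rate: float = 0.0
    num_layers: int = 12
    dtype: str = "float32"
    attention_backend: str = "xla"
    causal: bool = True

    @property
    def _dtype(self):
        # Resolve a name such as "float32" or "bfloat16" to the corresponding jnp dtype.
        return jnp.dtype(getattr(jnp, self.dtype))

    def to_dict(self) -> dict:
        # Plain dict of all dataclass fields, e.g. for logging or saving to disk.
        return dataclasses.asdict(self)


config = SelfAttentionConfigSketch(dtype="bfloat16")
print(config._dtype)
print(json.dumps(config.to_dict()))
```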

class xlstm_jax.models.llama.attention.SelfAttention

Bases: flax.linen.Module

Self attention module with support for rotary embeddings.

Parameters:

config – Configuration for the self attention module.

config: SelfAttentionConfig
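The data flow of a self-attention module of this kind can be sketched in plain JAX: project the inputs into per-head queries, keys, and values, optionally RMS-normalize queries and keys (cf. qk_norm), attend with a causal mask, and merge the heads back. This is a hypothetical functional sketch, not the flax.linen module documented here: the parameter layout (`wq`, `wk`, `wv`, `wo`), the initialization, and the parameter-free RMSNorm are assumptions, and rotary embeddings and dropout are omitted.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, eps=1e-6):
    # Parameter-free RMSNorm over the last axis (a learned scale is omitted in this sketch).
    return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)


def init_params_sketch(key, embed_dim, head_dim):
    num_heads = embed_dim // head_dim
    kq, kk, kv, ko = jax.random.split(key, 4)
    std = embed_dim ** -0.5
    shape_in = (embed_dim, num_heads, head_dim)
    return {
        "wq": jax.random.normal(kq, shape_in) * std,
        "wk": jax.random.normal(kk, shape_in) * std,
        "wv": jax.random.normal(kv, shape_in) * std,
        "wo": jax.random.normal(ko, (num_heads, head_dim, embed_dim)) * std,
    }


def self_attention_sketch(params, x, qk_norm=False, causal=True):
    # Project (B, S, E) inputs into per-head tensors of shape (B, S, NH, DH).
    q = jnp.einsum("bse,ehd->bshd", x, params["wq"])
    k = jnp.einsum("bse,ehd->bshd", x, params["wk"])
    v = jnp.einsum("bse,ehd->bshd", x, params["wv"])
    if qk_norm:
        q, k = rms_norm(q), rms_norm(k)
    # Rotary embeddings would be applied to q and k at this point (omitted here).
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    if causal:
        seq_len = x.shape[1]
        # Each query may only attend to itself and earlier positions.
        logits = jnp.where(jnp.tril(jnp.ones((seq_len, seq_len), bool)), logits, -jnp.inf)
    out = jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(logits, axis=-1), v)
    # Merge the heads and project back to the embedding dimension.
    return jnp.einsum("bshd,hde->bse", out, params["wo"])


params = init_params_sketch(jax.random.PRNGKey(0), embed_dim=128, head_dim=64)
x = jax.random.normal(jax.random.PRNGKey(1), (2, 10, 128))
y = self_attention_sketch(params, x, qk_norm=True)
print(y.shape)  # (2, 10, 128)
```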

xlstm_jax.models.llama.attention.multihead_attention(q, k, v, segment_ids=None, causal=True, qk_scale=None, backend='xla')

Compute multi-head self attention.

Parameters:
  • q (jax.Array) – Query tensor of shape (B, S, NH, DHQK).

  • k (jax.Array) – Key tensor of shape (B, S, NH, DHQK).

  • v (jax.Array) – Value tensor of shape (B, S, NH, DHV).

  • segment_ids (jax.Array | None) – Segment IDs for the input tensor. Attention weights between elements of different segments are set to zero. If None, all elements are treated as belonging to the same segment, i.e. no segment masking is applied.

  • causal (bool) – Whether to use causal attention masking for the self attention module.

  • qk_scale (float | None) – Scaling factor for the query-key logits. If None, defaults to 1/sqrt(DHQK). The scaling factor is applied to the query tensor before the dot product.

  • backend (AttentionBackend) – Which backend to use for the attention module. If pallas_triton or cudnn, the respective Flash Attention kernels are used. cudnn is only supported on GPU, pallas_triton on both CPU and GPU, and xla on all backends.

Returns:

Output tensor of shape (B, S, NH, DHV).

Return type:

jax.Array
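The masking and scaling semantics described for multihead_attention can be summarized by a plain-JAX reference computation. This is a hypothetical sketch of what an XLA-style path computes, not the library implementation; it assumes masked logits are replaced by the most negative finite value before the softmax, and `multihead_attention_sketch` is not a library name.

```python
import jax
import jax.numpy as jnp


def multihead_attention_sketch(q, k, v, segment_ids=None, causal=True, qk_scale=None):
    """Reference attention for q, k of shape (B, S, NH, DHQK) and v of shape (B, S, NH, DHV)."""
    _, seq_len, _, head_dim = q.shape
    if qk_scale is None:
        qk_scale = head_dim ** -0.5
    # The scale is applied to the queries before the dot product.
    q = q * qk_scale
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k)  # (B, NH, S, S)
    mask = jnp.ones((1, 1, seq_len, seq_len), dtype=bool)
    if segment_ids is not None:
        # Only attend within the same segment: broadcastable shape (B, 1, S, S).
        mask = mask & (segment_ids[:, None, :, None] == segment_ids[:, None, None, :])
    if causal:
        mask = mask & jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)  # (B, S, NH, DHV)


q = jax.random.normal(jax.random.PRNGKey(0), (2, 6, 3, 8))
k = jax.random.normal(jax.random.PRNGKey(1), (2, 6, 3, 8))
v = jax.random.normal(jax.random.PRNGKey(2), (2, 6, 3, 5))
segment_ids = jnp.array([[0, 0, 0, 1, 1, 1], [0, 0, 1, 1, 2, 2]])
out = multihead_attention_sketch(q, k, v, segment_ids=segment_ids)
print(out.shape)  # (2, 6, 3, 5)
```

With causal masking, the first token of each segment can only attend to itself, so its output equals its own value vector; this is a convenient sanity check when comparing backends.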