xlstm_jax.models.xlstm_parallel.blocks.mlstm.cell

Classes

mLSTMCellConfig: Sub-model configuration.

mLSTMCell

Module Contents

class xlstm_jax.models.xlstm_parallel.blocks.mlstm.cell.mLSTMCellConfig

Bases: xlstm_jax.models.configs.SubModelConfig

Sub-model configuration.

This class is currently a quick fix to allow for post-init style model configs, like the xlstm-clean model we ported from the original xlstm codebase. Once the config system is more mature, this class should be removed and all configs should become subclasses of ModelConfig.

context_length: int = -1
embedding_dim: int = -1
num_heads: int = -1
backend: xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.mLSTMBackendNameAndKwargs
norm_eps: float = 1e-06

Epsilon value for numerical stability in layer norm.

norm_type: Literal['layernorm', 'rmsnorm'] = 'layernorm'

Type of normalization layer to use.
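
The two norm_type options differ only in whether features are centered before scaling, and norm_eps is the constant added under the square root. A minimal sketch of both, normalizing over the last axis; the learnable scale and bias of a real normalization layer are omitted, and the helper name `normalize` is hypothetical rather than part of xlstm_jax:

```python
import jax
import jax.numpy as jnp


def normalize(x: jax.Array, norm_type: str = "layernorm", eps: float = 1e-6) -> jax.Array:
    """Scale-free normalization over the last axis (sketch; no learnable params)."""
    if norm_type == "layernorm":
        # Center, then divide by the standard deviation; eps guards against var == 0.
        mean = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.var(x, axis=-1, keepdims=True)
        return (x - mean) / jnp.sqrt(var + eps)
    if norm_type == "rmsnorm":
        # No centering: divide by the root mean square only.
        rms = jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)
        return x / rms
    raise ValueError(f"Unknown norm_type: {norm_type!r}")


x = jax.random.normal(jax.random.PRNGKey(0), (3, 16))
ln = normalize(x, "layernorm")
rms = normalize(x, "rmsnorm")
```

RMS norm skips the mean subtraction, which makes it slightly cheaper; with an all-zero input, eps keeps both variants finite.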

norm_type_v1: Literal['layernorm', 'rmsnorm'] = 'layernorm'

Type of normalization layer to use. NOTE: this is only used in the 'mlstm_v1' layer_type. Due to a bug, 'norm_type' was not applied correctly in the v1 version; to preserve that behavior, v1 uses this separate parameter for its normalization layer.

dtype: str = 'bfloat16'
gate_dtype: str = 'float32'
gate_soft_cap: float | None = None

Soft cap for the gate pre-activations. If None, no cap is applied.

gate_linear_headwise: bool = False

If True, the gate pre-activations are computed with a linear headwise layer, similar to QKV. Otherwise, each gate head takes as input the full features across all heads.
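
The difference between the two modes is which input features each head's gate can read. A minimal sketch with one scalar gate per head, assuming the headwise layer maps each head's head_dim slice to a single value (biases omitted); the function names and weight shapes are illustrative assumptions, not the library's parameters:

```python
import jax
import jax.numpy as jnp


def gate_preacts_full(x: jax.Array, w: jax.Array) -> jax.Array:
    # w: (embedding_dim, num_heads); each head's gate reads every feature.
    return x @ w


def gate_preacts_headwise(x: jax.Array, w: jax.Array) -> jax.Array:
    # w: (num_heads, head_dim); head h reads only its own slice of x.
    num_heads, head_dim = w.shape
    x_heads = x.reshape(*x.shape[:-1], num_heads, head_dim)
    return jnp.einsum("...hd,hd->...h", x_heads, w)


num_heads, head_dim = 4, 8
x = jax.random.normal(jax.random.PRNGKey(0), (2, 5, num_heads * head_dim))
w_full = 0.02 * jax.random.normal(jax.random.PRNGKey(1), (num_heads * head_dim, num_heads))
w_head = 0.02 * jax.random.normal(jax.random.PRNGKey(2), (num_heads, head_dim))
gate_full = gate_preacts_full(x, w_full)          # (batch, seq, num_heads)
gate_headwise = gate_preacts_headwise(x, w_head)  # (batch, seq, num_heads)
```

Under these assumptions the headwise variant has num_heads times fewer gate weights (embedding_dim instead of embedding_dim * num_heads).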

igate_bias_init_range: tuple[float, float] | float | None = None

Input gate bias initialization. If a tuple, the bias is initialized with a linspace in the given range. If a float, the bias is initialized with the given value. If None, the bias is initialized with normal(0.1).

fgate_bias_init_range: tuple[float, float] | float | None = (3.0, 6.0)

Forget gate bias initialization. If a tuple, the bias is initialized with a linspace in the given range. If a float, the bias is initialized with the given value. If None, the bias is initialized with normal(0.1).
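
The three documented cases for igate_bias_init_range and fgate_bias_init_range can be sketched as follows, assuming one bias value per head and reading normal(0.1) as a normal distribution with standard deviation 0.1 (as in jax.nn.initializers.normal(stddev=0.1)). `init_gate_bias` is a hypothetical helper, not the library's initializer:

```python
import jax
import jax.numpy as jnp


def init_gate_bias(key: jax.Array, num_heads: int, init_range) -> jax.Array:
    """Sketch of the tuple / float / None cases (assumed one bias per head)."""
    if isinstance(init_range, tuple):
        low, high = init_range
        # Evenly spaced per-head biases from low to high, both ends inclusive.
        return jnp.linspace(low, high, num_heads)
    if isinstance(init_range, (int, float)):
        # The same constant bias for every head.
        return jnp.full((num_heads,), float(init_range))
    # None: normal samples with standard deviation 0.1 (assumed reading of normal(0.1)).
    return 0.1 * jax.random.normal(key, (num_heads,))


key = jax.random.PRNGKey(0)
print(init_gate_bias(key, 4, (3.0, 6.0)))  # [3. 4. 5. 6.]
```

With the default (3.0, 6.0), every forget gate starts strongly positive, so sigmoid-style gates begin close to 1 and the memory initially retains information; the linspace gives different heads different retention timescales.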

add_qk_norm: bool = False

If True, adds a normalization layer on the query and key vectors before the mLSTM cell.
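
One way to picture add_qk_norm is normalizing each query and key vector over its head dimension before they enter the cell, so their scale no longer depends on the projection weights. A minimal sketch, assuming a scale-free RMS normalization and a (batch, num_heads, seq, head_dim) layout; which normalization xlstm_jax actually applies is not specified here, and `qk_norm` is hypothetical:

```python
import jax
import jax.numpy as jnp


def qk_norm(q: jax.Array, k: jax.Array, add_qk_norm: bool, eps: float = 1e-6):
    """Optionally RMS-normalize queries and keys over head_dim (sketch)."""
    if not add_qk_norm:
        return q, k

    def rms_norm(t: jax.Array) -> jax.Array:
        # Divide each head_dim vector by its root mean square.
        return t * jax.lax.rsqrt(jnp.mean(t * t, axis=-1, keepdims=True) + eps)

    return rms_norm(q), rms_norm(k)


# Deliberately mismatched scales for q and k.
q = 3.0 * jax.random.normal(jax.random.PRNGKey(0), (2, 4, 5, 8))
k = 0.1 * jax.random.normal(jax.random.PRNGKey(1), (2, 4, 5, 8))
qn, kn = qk_norm(q, k, add_qk_norm=True)
```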

reset_at_document_boundaries: bool = False

If True, the memory is reset at the beginning of each document.

reset_fgate_value: float = -25.0

Value to set the forget gate to at document boundaries.
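
A minimal sketch of how reset_at_document_boundaries and reset_fgate_value could interact, assuming reset_fgate_value overwrites the forget gate pre-activation at the first token of each document and the gate is then applied through log-sigmoid, as in common parallel mLSTM formulations. The function `apply_document_reset` and the boolean `document_borders` mask are hypothetical names:

```python
import jax
import jax.numpy as jnp


def apply_document_reset(
    fgate_preact: jax.Array,      # (batch, seq, num_heads)
    document_borders: jax.Array,  # bool (batch, seq), True at each document's first token
    reset_fgate_value: float = -25.0,
) -> jax.Array:
    # Broadcast the (batch, seq) mask over heads and overwrite those positions.
    return jnp.where(document_borders[..., None], reset_fgate_value, fgate_preact)


fgate_preact = jnp.zeros((1, 4, 2))
document_borders = jnp.array([[False, False, True, False]])
out = apply_document_reset(fgate_preact, document_borders)
# Under a log-sigmoid forget gate, the reset position gets log f close to -25.
log_f = jax.nn.log_sigmoid(out)
```

Under these assumptions, a log forget gate of about -25 means f is about e^-25 (roughly 1.4e-11), so the memory carried over from the previous document is scaled to effectively zero.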

parallel: xlstm_jax.models.configs.ParallelConfig | None = None

Parallel configuration for the mLSTM cell.

property _dtype: jax.numpy.dtype

Returns the actual JAX dtype rather than the string stored in the config.

Returns:

The jnp dtype corresponding to the string value.

Return type:

jax.numpy.dtype

property _gate_dtype: jax.numpy.dtype
Return type:

jax.numpy.dtype
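
The dtype properties turn the string fields into dtype objects that JAX array constructors accept. A minimal sketch, assuming a dataclass config and a lookup via getattr(jnp, name); `ToyDtypeConfig` is a hypothetical stand-in, and the library may resolve the strings differently:

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass
class ToyDtypeConfig:
    # Hypothetical stand-in for the two string dtype fields.
    dtype: str = "bfloat16"
    gate_dtype: str = "float32"

    @property
    def _dtype(self) -> jnp.dtype:
        # "bfloat16" -> jnp.bfloat16 -> dtype(bfloat16)
        return jnp.dtype(getattr(jnp, self.dtype))

    @property
    def _gate_dtype(self) -> jnp.dtype:
        return jnp.dtype(getattr(jnp, self.gate_dtype))


cfg = ToyDtypeConfig()
gates = jnp.zeros((2, 4), dtype=cfg._gate_dtype)  # gates kept in float32
```

Keeping strings in the config keeps it easy to serialize, while the properties give model code a real dtype at the point of use.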

to_dict()

Converts the config to a dictionary.

Helpful for saving to disk or logging.

Return type:

dict
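
If a config is a dataclass (an assumption; the implementation is not shown here), dataclasses.asdict produces the kind of dictionary to_dict describes, including nested configs such as a ParallelConfig. `ToyParallelConfig`, its `axis_size` field, and `ToyCellConfig` are hypothetical stand-ins:

```python
import json
from dataclasses import asdict, dataclass, field


@dataclass
class ToyParallelConfig:
    # Hypothetical stand-in for a nested parallel config.
    axis_size: int = 1


@dataclass
class ToyCellConfig:
    # Hypothetical subset of cell-config-like fields.
    num_heads: int = 4
    fgate_bias_init_range: tuple[float, float] = (3.0, 6.0)
    parallel: ToyParallelConfig = field(default_factory=ToyParallelConfig)

    def to_dict(self) -> dict:
        # asdict recurses into nested dataclasses, so `parallel` becomes a dict too.
        return asdict(self)


cfg = ToyCellConfig()
print(json.dumps(cfg.to_dict()))
# {"num_heads": 4, "fgate_bias_init_range": [3.0, 6.0], "parallel": {"axis_size": 1}}
```

A JSON round trip turns tuples such as fgate_bias_init_range into lists, so code that reloads a saved config may need to convert them back.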

class xlstm_jax.models.xlstm_parallel.blocks.mlstm.cell.mLSTMCell

Bases: flax.linen.Module

config: mLSTMCellConfig