xlstm_jax.models.xlstm_parallel.blocks.mlstm.layer#

Classes#

mLSTMLayerConfig

Sub-model configuration.

mLSTMLayer

The mLSTM layer with Mamba block style.

mLSTMInnerLayer

The inner mLSTM layer with Mamba block style.

Module Contents#

class xlstm_jax.models.xlstm_parallel.blocks.mlstm.layer.mLSTMLayerConfig#

Bases: xlstm_jax.models.xlstm_parallel.utils.UpProjConfigMixin

Sub-model configuration.

This class is currently a quick fix to allow for post-init style model configs, such as the xlstm-clean config we ported from the original xlstm codebase. Once the config system is more mature, we should remove this and make every config a subclass of ModelConfig.

conv1d_kernel_size: int = 4#
qkv_proj_blocksize: int = 4#
num_heads: int = 4#
proj_factor: float = 2.0#
vmap_qk: bool = False#
init_distribution: xlstm_jax.models.shared.InitDistribution = 'normal'#

Distribution type from which to sample the weights.

output_init_fn: xlstm_jax.models.shared.InitFnName = 'wang'#

Initialization function for the output projection layer.

layer_type: Literal['mlstm', 'mlstm_v1'] = 'mlstm'#
norm_type: Literal['layernorm', 'rmsnorm'] = 'layernorm'#

Type of normalization layer to use. NOTE: this is only used in the ‘mlstm’ layer_type.

qk_dim_factor: float = 1.0#

Factor to scale the qk projection dimension by. By default, the qk projection dimension is the same as the inner embedding dimension, split into num_heads. This factor is applied to this default size.

v_dim_factor: float = 1.0#

Factor to scale the v projection dimension by. By default, the v projection dimension is the same as the inner embedding dimension, split into num_heads. This factor is applied to this default size.
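
As a rough illustration of how the two factors might translate into per-head sizes, the sketch below computes them from an inner embedding dimension. `head_dims` is a hypothetical helper, not part of xlstm_jax, and it assumes each factor multiplies the full inner embedding dimension before the split into heads; the library's exact rounding and validation may differ.

```python
def head_dims(
    inner_embedding_dim: int,
    num_heads: int,
    qk_dim_factor: float = 1.0,
    v_dim_factor: float = 1.0,
) -> tuple[int, int]:
    """Hypothetical helper: per-head (qk, v) sizes implied by the two factors."""
    # Assumption: each factor scales the full inner embedding dimension,
    # and the scaled size is then divided evenly among the heads.
    qk_dim = int(inner_embedding_dim * qk_dim_factor)
    v_dim = int(inner_embedding_dim * v_dim_factor)
    if qk_dim % num_heads != 0 or v_dim % num_heads != 0:
        raise ValueError("scaled dimensions must be divisible by num_heads")
    return qk_dim // num_heads, v_dim // num_heads


# Halving the qk dimension while keeping the v dimension at its default size.
qk_head_dim, v_head_dim = head_dims(1024, num_heads=4, qk_dim_factor=0.5)
```

Under these assumptions, an inner dimension of 1024 with four heads and `qk_dim_factor=0.5` gives 128 features per head for queries and keys and 256 per head for values.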

embedding_dim: int = -1#
bias: bool = False#
dropout: float = 0.0#
context_length: int = -1#
dtype: str = 'bfloat16'#
parallel: xlstm_jax.models.configs.ParallelConfig | None = None#
debug_cell: bool = False#
gate_input: Literal['qkv', 'x_mlstm', 'x_mlstm_conv', 'x_mlstm_conv_act'] = 'qkv'#

Which input to use for the mLSTM cell gates. Options are:

  • “qkv”: use the query, key and value vectors concatenated as input. Default, as in paper version.

  • “x_mlstm”: use the output of the mLSTM up projection layer. These are the same features that go into the V projection.

  • “x_mlstm_conv”: use the output of the convolution on the mLSTM up projection features.

  • “x_mlstm_conv_act”: use the output of the activation function on the convolution on the mLSTM up projection features. These are the same features that go into the QK projection.
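
The four options can be read as a simple dispatch over intermediate features. The sketch below is only an illustration: `select_gate_input` is a hypothetical function, and plain Python lists stand in for the feature vectors that the real layer would hold as arrays.

```python
def select_gate_input(
    gate_input: str,
    *,
    q: list[float],
    k: list[float],
    v: list[float],
    x_mlstm: list[float],
    x_mlstm_conv: list[float],
    x_mlstm_conv_act: list[float],
) -> list[float]:
    """Hypothetical dispatch: pick the features that feed the cell gates."""
    if gate_input == "qkv":
        # Concatenate query, key and value along the feature axis.
        return [*q, *k, *v]
    candidates = {
        "x_mlstm": x_mlstm,                    # up projection output
        "x_mlstm_conv": x_mlstm_conv,          # after the convolution
        "x_mlstm_conv_act": x_mlstm_conv_act,  # after conv and activation
    }
    if gate_input not in candidates:
        raise ValueError(f"unknown gate_input: {gate_input!r}")
    return list(candidates[gate_input])


# Toy one-element feature vectors, chosen only to make each option visible.
features = dict(
    q=[1.0], k=[2.0], v=[3.0],
    x_mlstm=[4.0], x_mlstm_conv=[5.0], x_mlstm_conv_act=[6.0],
)
gates_in = select_gate_input("qkv", **features)
```

Note that only the “qkv” option changes the width of the gate input, since it concatenates three vectors; the other options pass a single feature vector through unchanged.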

_num_blocks: int = 1#
_inner_embedding_dim: int = None#
mlstm_cell: xlstm_jax.models.xlstm_parallel.blocks.mlstm.cell.mLSTMCellConfig#
property _dtype: jax.numpy.dtype#

Returns the real dtype instead of the str from configs.

Returns:

The jnp dtype corresponding to the string value.

Return type:

jax.numpy.dtype

round_proj_up_dim_up: bool = True#
round_proj_up_to_multiple_of: int = 64#
_proj_up_dim: int | None = None#
_set_proj_up_dim(embedding_dim)#

Parameters:

embedding_dim (int)

Return type:

None
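
The page does not describe how `_proj_up_dim` is derived. Going only by the field names (`proj_factor`, `round_proj_up_dim_up`, `round_proj_up_to_multiple_of`), one plausible reading is that the up-projection size is `proj_factor * embedding_dim`, rounded to a multiple of the given value. The sketch below encodes that assumption; `proj_up_dim` is a hypothetical standalone function, not the library's method.

```python
import math


def proj_up_dim(
    embedding_dim: int,
    proj_factor: float = 2.0,
    round_up: bool = True,
    multiple_of: int = 64,
) -> int:
    """Hypothetical sketch of rounding the up-projection size to a multiple."""
    # Assumption: the unrounded size is proj_factor * embedding_dim.
    raw = proj_factor * embedding_dim
    steps = raw / multiple_of
    # Round the number of `multiple_of`-sized steps up or down.
    steps = math.ceil(steps) if round_up else math.floor(steps)
    return int(steps * multiple_of)


rounded_up = proj_up_dim(1000)                    # 2000 -> next multiple of 64
rounded_down = proj_up_dim(1000, round_up=False)  # 2000 -> previous multiple
```

With the documented defaults, an `embedding_dim` of 1000 would become 2048 when rounding up and 1984 when rounding down, while a size that is already a multiple of 64 (for example 512, giving 1024) is left unchanged.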

to_dict()#

Converts the config to a dictionary.

Helpful for saving to disk or logging.

Return type:

dict
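
To illustrate the saving and logging use case, the sketch below serializes a config-like object to JSON. `ToyLayerConfig` is a hypothetical stand-in that copies a few documented field names and defaults; it assumes `to_dict()` returns a plain mapping from field names to values, which the real method may refine (for example for nested configs).

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class ToyLayerConfig:
    """Hypothetical stand-in with a few of the documented fields."""

    num_heads: int = 4
    proj_factor: float = 2.0
    dtype: str = "bfloat16"

    def to_dict(self) -> dict:
        # Assumption: a plain field-name to value mapping.
        return asdict(self)


# The dictionary form is JSON-serializable, so it can be logged or written to disk.
config_json = json.dumps(ToyLayerConfig(num_heads=8).to_dict(), sort_keys=True)
```

Because the dtype is stored as a string rather than a `jax.numpy.dtype`, a dictionary of this shape needs no custom encoder when written as JSON.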

class xlstm_jax.models.xlstm_parallel.blocks.mlstm.layer.mLSTMLayer#

Bases: flax.linen.Module

The mLSTM layer with Mamba block style.

config: mLSTMLayerConfig#
class xlstm_jax.models.xlstm_parallel.blocks.mlstm.layer.mLSTMInnerLayer#

Bases: flax.linen.Module

The inner mLSTM layer with Mamba block style.

Applies a convolutional layer followed by an mLSTM cell.

config: mLSTMLayerConfig#