xlstm_jax.models.xlstm_parallel.blocks.mlstm.layer#
Classes#
mLSTMLayerConfig | Sub-model configuration.
mLSTMLayer | The mLSTM layer with Mamba block style.
mLSTMInnerLayer | The inner mLSTM layer with Mamba block style.
Module Contents#
- class xlstm_jax.models.xlstm_parallel.blocks.mlstm.layer.mLSTMLayerConfig#
Bases: xlstm_jax.models.xlstm_parallel.utils.UpProjConfigMixin
Sub-model configuration.
This class is currently a quick fix to allow for post-init style model configs, such as xlstm-clean, which we ported from the original xlstm codebase. Once the config system is more mature, this class should be removed and every config should become a subclass of ModelConfig.
- init_distribution: xlstm_jax.models.shared.InitDistribution = 'normal'#
Distribution type from which to sample the weights.
- output_init_fn: xlstm_jax.models.shared.InitFnName = 'wang'#
Initialization function for the output projection layer.
- layer_type: Literal['mlstm', 'mlstm_v1'] = 'mlstm'#
- norm_type: Literal['layernorm', 'rmsnorm'] = 'layernorm'#
Type of normalization layer to use. NOTE: this is only used in the ‘mlstm’ layer_type.
- qk_dim_factor: float = 1.0#
Factor to scale the qk projection dimension by. By default, the qk projection dimension is the same as the inner embedding dimension, split into num_heads. This factor is applied to this default size.
- v_dim_factor: float = 1.0#
Factor to scale the v projection dimension by. By default, the v projection dimension is the same as the inner embedding dimension, split into num_heads. This factor is applied to this default size.
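The effect of the two factors can be illustrated with a small calculation. This is only a sketch under stated assumptions: the helper `projection_dims`, its argument names, the truncation with `int`, and the divisibility check are all hypothetical and are not taken from the xlstm_jax source, which may round or validate differently.

```python
def projection_dims(inner_embedding_dim, num_heads, qk_dim_factor=1.0, v_dim_factor=1.0):
    """Hypothetical helper: derive total and per-head projection sizes.

    The default size is the inner embedding dimension; each factor
    scales that default before it is split across the heads.
    """
    # Scale the default size by the respective factor (truncation is an assumption).
    qk_dim = int(inner_embedding_dim * qk_dim_factor)
    v_dim = int(inner_embedding_dim * v_dim_factor)
    # Assumption: the scaled sizes must split evenly into num_heads.
    if qk_dim % num_heads != 0 or v_dim % num_heads != 0:
        raise ValueError("scaled projection dims must be divisible by num_heads")
    return {
        "qk_dim": qk_dim,
        "qk_head_dim": qk_dim // num_heads,
        "v_dim": v_dim,
        "v_head_dim": v_dim // num_heads,
    }


# Halving the qk dimension while keeping the v dimension at its default size.
dims = projection_dims(inner_embedding_dim=1024, num_heads=4, qk_dim_factor=0.5)
print(dims)
```

With these illustrative values, the query/key heads are half as wide as the value heads, which reduces the cost of the query-key product without shrinking the value features.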
- parallel: xlstm_jax.models.configs.ParallelConfig | None = None#
- gate_input: Literal['qkv', 'x_mlstm', 'x_mlstm_conv', 'x_mlstm_conv_act'] = 'qkv'#
Which input to use for the mLSTM cell gates. Options are:
  - “qkv”: use the query, key and value vectors concatenated as input. Default, as in the paper version.
  - “x_mlstm”: use the output of the mLSTM up projection layer. These are the same features that go into the V projection.
  - “x_mlstm_conv”: use the output of the convolution on the mLSTM up projection features.
  - “x_mlstm_conv_act”: use the output of the activation function on the convolution on the mLSTM up projection features. These are the same features that go into the QK projection.
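The selection among the gate_input options can be sketched as a simple dispatch. This is a minimal illustration, not the library's implementation: the function `select_gate_input`, the `features` dictionary, and its keys are hypothetical, and plain Python lists stand in for the feature arrays that the real layer would handle.

```python
GATE_INPUT_OPTIONS = ("qkv", "x_mlstm", "x_mlstm_conv", "x_mlstm_conv_act")


def select_gate_input(gate_input, features):
    """Hypothetical helper: pick the features that feed the mLSTM cell gates.

    `features` maps a stream name to a list of floats, standing in for
    a feature vector at one position.
    """
    if gate_input not in GATE_INPUT_OPTIONS:
        raise ValueError(f"unknown gate_input: {gate_input!r}")
    if gate_input == "qkv":
        # Default: concatenate query, key and value along the feature axis.
        return features["q"] + features["k"] + features["v"]
    # The remaining options each name a single intermediate stream.
    return features[gate_input]


features = {
    "q": [0.1, 0.2],
    "k": [0.3, 0.4],
    "v": [0.5, 0.6],
    "x_mlstm": [1.0, 1.1],           # output of the up projection
    "x_mlstm_conv": [2.0, 2.1],      # after the convolution
    "x_mlstm_conv_act": [3.0, 3.1],  # after the activation
}
gate_features = select_gate_input("x_mlstm_conv", features)
```

Note that with “qkv” the gate input is three times as wide as any single stream in this toy setup, while the other options reuse an intermediate stream as-is.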
- property _dtype: jax.numpy.dtype#
Returns the real dtype instead of the str from configs.
- Returns:
The jnp dtype corresponding to the string value.
- Return type:
jax.numpy.dtype
- class xlstm_jax.models.xlstm_parallel.blocks.mlstm.layer.mLSTMLayer#
Bases: flax.linen.Module
The mLSTM layer with Mamba block style.
- config: mLSTMLayerConfig#
- class xlstm_jax.models.xlstm_parallel.blocks.mlstm.layer.mLSTMInnerLayer#
Bases: flax.linen.Module
The inner mLSTM layer with Mamba block style.
Applies a convolutional layer followed by an mLSTM cell.
- config: mLSTMLayerConfig#