xlstm_jax.models.xlstm_parallel.blocks.mlstm.cell#
Classes#
mLSTMCellConfig | Sub-model configuration. |
mLSTMCell | |
Module Contents#
- class xlstm_jax.models.xlstm_parallel.blocks.mlstm.cell.mLSTMCellConfig#
Bases: xlstm_jax.models.configs.SubModelConfig
Sub-model configuration.
This class is currently a quick fix to allow for post-init style model configs, such as the xlstm-clean config we ported from the original xlstm codebase. Once the config system is more mature, we should remove this class and make every config a subclass of ModelConfig.
- norm_type: Literal['layernorm', 'rmsnorm'] = 'layernorm'#
Type of normalization layer to use.
- norm_type_v1: Literal['layernorm', 'rmsnorm'] = 'layernorm'#
Type of normalization layer to use. NOTE: this is only used in the ‘mlstm_v1’ layer_type. Due to a bug, the ‘norm_type’ was not used correctly in the v1 version. To keep the same behavior, we introduce a separate parameter for the normalization layer.
- gate_soft_cap: float | None = None#
Soft cap for the gate pre-activations. If None, no cap is applied.
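As an illustration of what a soft cap does, here is a minimal sketch that assumes the common `cap * tanh(x / cap)` form. The documentation above only states that a soft cap is applied, so both the formula and the helper name `soft_cap` are assumptions, not the library's confirmed implementation.

```python
import jax.numpy as jnp


def soft_cap(x, cap=None):
    """Smoothly bound x to (-cap, cap); return x unchanged if cap is None.

    Hypothetical helper: the tanh form is an assumption for illustration.
    """
    if cap is None:
        return x
    return cap * jnp.tanh(x / cap)


# Toy gate pre-activations with two extreme values.
preacts = jnp.array([-100.0, -1.0, 0.0, 1.0, 100.0])
capped = soft_cap(preacts, cap=15.0)   # large values are squashed towards +/-15
uncapped = soft_cap(preacts, cap=None)  # no cap: values pass through unchanged
```

Small pre-activations stay nearly linear under this form, while large ones saturate smoothly instead of being clipped hard.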
- gate_linear_headwise: bool = False#
If True, the gate pre-activations are computed with a linear headwise layer, similar to QKV. Otherwise, each gate head takes as input the full features across all heads.
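The difference between the two modes can be sketched with plain einsums. This is only a shape-level illustration under assumed conventions: the dimension names, the weight layouts, and the choice of one gate value per head are hypothetical and not taken from the library.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes: batch, sequence length, number of heads, features per head.
B, S, NH, DH = 2, 5, 4, 8
D = NH * DH

key_x, key_full, key_head = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (B, S, D))

# gate_linear_headwise=False: every gate head sees the full feature vector.
w_full = jax.random.normal(key_full, (D, NH))
gate_full = jnp.einsum("bsd,dn->bsn", x, w_full)

# gate_linear_headwise=True: every gate head only sees its own feature slice.
w_head = jax.random.normal(key_head, (NH, DH))
x_heads = x.reshape(B, S, NH, DH)
gate_headwise = jnp.einsum("bsnd,nd->bsn", x_heads, w_head)
```

Both variants produce one pre-activation per head, but in this sketch the headwise weights hold `NH * DH` values instead of `D * NH`, a factor of `NH` fewer.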
- igate_bias_init_range: tuple[float, float] | float | None = None#
Input gate bias initialization. If a tuple, the bias is initialized with a linspace in the given range. If a float, the bias is initialized with the given value. If None, the bias is initialized with normal(0.1).
- fgate_bias_init_range: tuple[float, float] | float | None = (3.0, 6.0)#
Forget gate bias initialization. If a tuple, the bias is initialized with a linspace in the given range. If a float, the bias is initialized with the given value. If None, the bias is initialized with normal(0.1).
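The three cases described for `igate_bias_init_range` and `fgate_bias_init_range` could be handled roughly as follows. This is a sketch, not the library's initializer: the helper name `init_gate_bias`, the assumption of one bias entry per head, and the reading of "normal(0.1)" as a normal distribution with standard deviation 0.1 are all assumptions.

```python
import jax
import jax.numpy as jnp


def init_gate_bias(key, num_heads, init_range):
    """Hypothetical helper mirroring the three documented init cases."""
    if init_range is None:
        # Assumption: "normal(0.1)" means a standard deviation of 0.1.
        return 0.1 * jax.random.normal(key, (num_heads,))
    if isinstance(init_range, tuple):
        # Tuple: evenly spaced values across the range, one per head.
        low, high = init_range
        return jnp.linspace(low, high, num_heads)
    # Float: the same constant for every head.
    return jnp.full((num_heads,), float(init_range))


key = jax.random.PRNGKey(0)
fgate_bias = init_gate_bias(key, 4, (3.0, 6.0))  # default forget gate range
igate_const = init_gate_bias(key, 4, -10.0)      # constant initialization
igate_random = init_gate_bias(key, 4, None)      # random initialization
```

With the default forget gate range of `(3.0, 6.0)`, every head starts with a large positive bias, so the sigmoid of the forget gate is close to 1 and the memory decays slowly at initialization.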
- add_qk_norm: bool = False#
If True, adds a normalization layer on the query and key vectors before the mLSTM cell.
- reset_at_document_boundaries: bool = False#
If True, the memory is reset at the beginning of each document.
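To make the effect of a reset concrete, the following toy example runs a scalar recurrence `c_t = f_t * c_{t-1} + i_t * v_t` and zeroes the forget gate wherever a document starts. It is a simplified stand-in, not the mLSTM cell itself; the function name `run_memory` and the `doc_start` mask are hypothetical.

```python
import jax
import jax.numpy as jnp


def run_memory(f, i, v, doc_start):
    """Toy scalar memory; the state is dropped at every document start."""
    # A forget gate of zero discards everything accumulated so far.
    f = jnp.where(doc_start, 0.0, f)

    def step(c_prev, inputs):
        f_t, i_t, v_t = inputs
        c_t = f_t * c_prev + i_t * v_t
        return c_t, c_t

    _, states = jax.lax.scan(step, jnp.zeros(()), (f, i, v))
    return states


f = jnp.full((6,), 0.9)        # forget gate values
i = jnp.ones((6,))             # input gate values
v = jnp.arange(1.0, 7.0)       # values written to the memory: 1, 2, ..., 6
doc_start = jnp.array([True, False, False, True, False, False])
states = run_memory(f, i, v, doc_start)
```

At the fourth position a new document begins, so the state there depends only on the value written at that step and carries nothing over from the first document.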
- parallel: xlstm_jax.models.configs.ParallelConfig | None = None#
Parallel configuration for the mLSTM cell.
- property _dtype: jax.numpy.dtype#
Returns the real dtype instead of the str from configs.
- Returns:
The jnp dtype corresponding to the string value.
- Return type:
jax.numpy.dtype
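A string-to-dtype conversion of this kind can be sketched as below. The helper name `str_to_dtype` is hypothetical, and the lookup via `getattr` on `jax.numpy` is just one possible approach; the property's actual implementation may differ.

```python
import jax.numpy as jnp


def str_to_dtype(name: str) -> jnp.dtype:
    """Hypothetical helper: map a config string such as "bfloat16" to a dtype."""
    return jnp.dtype(getattr(jnp, name))


compute_dtype = str_to_dtype("bfloat16")
zeros = jnp.zeros((2, 3), dtype=compute_dtype)  # usable wherever a dtype is expected
```

Storing the dtype as a string keeps the config easy to serialize, while the property hands real dtype objects to the module code.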
- property _gate_dtype: jax.numpy.dtype#
- Return type:
jax.numpy.dtype
- class xlstm_jax.models.xlstm_parallel.blocks.mlstm.cell.mLSTMCell#
Bases: flax.linen.Module
- config: mLSTMCellConfig#