xlstm_jax.models.xlstm_parallel.blocks.xlstm_block
Classes
ResidualBlock | A residual block that applies a list of modules in sequence and adds the input to the output.
xLSTMBlockConfig |
xLSTMBlock | An xLSTM block can be either an sLSTM Block or an mLSTM Block.
Module Contents
- class xlstm_jax.models.xlstm_parallel.blocks.xlstm_block.xLSTMBlockConfig
- parallel: xlstm_jax.models.configs.ParallelConfig | None = None
- feedforward: xlstm_jax.models.xlstm_parallel.components.feedforward.FeedForwardConfig | None = None
- norm_type: Literal['layernorm', 'rmsnorm'] = 'layernorm'
Type of normalization layer to use.
- add_post_norm: bool = False
If True, adds a normalization layer after the mLSTM/sLSTM layer and the feedforward layer. Note that this is not the post-norm on the residual connection, but is applied to the output of the layers before the residual connection, following e.g. Gemma-2.
- property _dtype: jax.numpy.dtype
Returns the actual JAX dtype rather than the dtype string stored in the config.
- Returns:
The jnp dtype corresponding to the string value.
- Return type:
jax.numpy.dtype
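The string-to-dtype conversion described for `_dtype` can be sketched as follows. This is only an illustration, assuming the config stores the dtype as a name such as "bfloat16" (that field is not shown on this page). `resolve_dtype` is a hypothetical helper, not the library's implementation.

```python
import jax.numpy as jnp


def resolve_dtype(name: str) -> jnp.dtype:
    """Hypothetical helper: map a dtype name such as 'bfloat16' to a jnp dtype."""
    # Look up the scalar type on jax.numpy, then normalize it to a dtype object.
    return jnp.dtype(getattr(jnp, name))


compute_dtype = resolve_dtype("bfloat16")
activations = jnp.zeros((2, 4), dtype=compute_dtype)
```

Keeping the string in the config and resolving it on access keeps the config itself easy to serialize, which is one plausible reason for exposing the dtype as a property.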
- class xlstm_jax.models.xlstm_parallel.blocks.xlstm_block.ResidualBlock
Bases: flax.linen.Module
A residual block that applies a list of modules in sequence and adds the input to the output.
The modules are instantiated inside this block so that they are registered as its child modules.
- module_fns: list[collections.abc.Callable[..., flax.linen.Module]]
- class xlstm_jax.models.xlstm_parallel.blocks.xlstm_block.xLSTMBlock
Bases: flax.linen.Module
An xLSTM block can be either an sLSTM Block or an mLSTM Block.
It contains the pre-LayerNorms and the skip connections.
- config: xLSTMBlockConfig
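The interaction between the pre-norm, the skip connection, and the `add_post_norm` option can be illustrated with a framework-free sketch of the data flow. It is a simplification under stated assumptions: the normalization is an RMS norm without learned scale, `toy_layer` is a hypothetical placeholder for the mLSTM/sLSTM or feedforward layer, and `block_step` is not part of the library.

```python
import math


def rms_norm(x, eps=1e-6):
    """RMS normalization of a list of floats (no learned scale, for illustration)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def block_step(x, layer, add_post_norm=False):
    """One pre-norm residual step: x + [post_norm](layer(pre_norm(x)))."""
    out = layer(rms_norm(x))  # pre-norm is applied to the layer input
    if add_post_norm:
        # Normalizes the layer output before it is added to the residual stream;
        # the residual stream itself is left untouched.
        out = rms_norm(out)
    return [xi + oi for xi, oi in zip(x, out)]


def toy_layer(h):
    """Hypothetical placeholder layer that simply scales its input by 10."""
    return [10.0 * v for v in h]


x = [3.0, 4.0]
without_post = block_step(x, toy_layer, add_post_norm=False)
with_post = block_step(x, toy_layer, add_post_norm=True)
```

With `add_post_norm=True`, the contribution added to the residual stream has unit RMS regardless of how large the layer output is, while the skip path still carries `x` unchanged.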