xlstm_jax.models.xlstm_parallel.blocks.xlstm_block

Classes

xLSTMBlockConfig

ResidualBlock

A residual block that applies a list of modules in sequence and adds the input to the output.

xLSTMBlock

An xLSTM block can be either an sLSTM Block or an mLSTM Block.

Module Contents

class xlstm_jax.models.xlstm_parallel.blocks.xlstm_block.xLSTMBlockConfig
mlstm: xlstm_jax.models.xlstm_parallel.blocks.mlstm.layer.mLSTMLayerConfig | None = None
slstm: None = None
parallel: xlstm_jax.models.configs.ParallelConfig | None = None
feedforward: xlstm_jax.models.xlstm_parallel.components.feedforward.FeedForwardConfig | None = None
dtype: str = 'bfloat16'
norm_eps: float = 1e-06

Epsilon value for numerical stability in the normalization layers.

norm_type: Literal['layernorm', 'rmsnorm'] = 'layernorm'

Type of normalization layer to use.
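The two `norm_type` options differ in whether the mean is subtracted before scaling. A minimal sketch in `jax.numpy`, assuming normalization over the last (feature) axis and omitting the learnable scale and bias that a real layer would add; `layer_norm`, `rms_norm`, and `normalize` are hypothetical helpers, not part of `xlstm_jax`:

```python
import jax
import jax.numpy as jnp


def layer_norm(x: jax.Array, eps: float = 1e-6) -> jax.Array:
    # Subtract the mean and divide by the standard deviation over the last axis.
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def rms_norm(x: jax.Array, eps: float = 1e-6) -> jax.Array:
    # Divide by the root mean square over the last axis; no mean subtraction.
    mean_square = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x / jnp.sqrt(mean_square + eps)


def normalize(x: jax.Array, norm_type: str = "layernorm", eps: float = 1e-6) -> jax.Array:
    # Hypothetical dispatch on the two values allowed by `norm_type`.
    if norm_type == "layernorm":
        return layer_norm(x, eps)
    if norm_type == "rmsnorm":
        return rms_norm(x, eps)
    raise ValueError(f"Unknown norm_type: {norm_type!r}")
```

`flax.linen.LayerNorm` and `flax.linen.RMSNorm` provide the learnable versions; both accept an `epsilon` argument that plays the role of `norm_eps`.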

add_post_norm: bool = False

If True, adds a normalization layer after the mLSTM/sLSTM layer and after the feedforward layer. This is not a post-norm on the residual connection: the normalization is applied to each layer's output before that output is added back through the residual connection, as in, e.g., Gemma-2.
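To make the placement described for `add_post_norm` concrete, the sketch below contrasts a pre-norm branch with the optional extra norm against a classic post-norm on the residual sum. It is an illustration under assumptions, not the library's code: RMS normalization stands in for whichever `norm_type` is configured, learnable parameters are omitted, and `residual_branch` and `classic_post_norm_branch` are hypothetical names.

```python
import jax.numpy as jnp


def rms_norm(x, eps=1e-6):
    # Stand-in for the configured normalization layer (no learnable scale).
    return x / jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)


def residual_branch(x, layer_fn, add_post_norm=False, eps=1e-6):
    # Pre-norm: normalize the input before the layer.
    h = layer_fn(rms_norm(x, eps))
    # Optional extra norm on the layer output, still inside the residual branch.
    if add_post_norm:
        h = rms_norm(h, eps)
    # Skip connection: the raw, un-normalized input is added back.
    return x + h


def classic_post_norm_branch(x, layer_fn, eps=1e-6):
    # Original-Transformer style for contrast: the residual sum itself is
    # normalized. This is NOT what add_post_norm does.
    return rms_norm(x + layer_fn(x), eps)
```

With `x` of all ones and a layer that multiplies by 10, the pre-norm branch returns about 11, enabling the extra norm gives about 2 (the raw input plus a unit-RMS update), and the classic post-norm gives about 1, because there the residual sum is normalized.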

_num_blocks: int | None = None
_block_idx: int | None = None
property _dtype: jax.numpy.dtype

Returns the actual JAX dtype instead of the string stored in the config.

Returns:

The jnp dtype corresponding to the string value.

Return type:

jax.numpy.dtype
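One way such a property could map the config string to a dtype, as a hedged sketch: `DtypeConfig` is a hypothetical stand-in for `xLSTMBlockConfig`, and the explicit lookup table is an assumption (the real property may resolve names differently).

```python
from dataclasses import dataclass

import jax.numpy as jnp

# Hypothetical allow-list of supported dtype names.
_DTYPES = {
    "float32": jnp.float32,
    "float16": jnp.float16,
    "bfloat16": jnp.bfloat16,
}


@dataclass
class DtypeConfig:  # hypothetical stand-in for xLSTMBlockConfig
    dtype: str = "bfloat16"

    @property
    def _dtype(self) -> jnp.dtype:
        # Convert the configured string into a concrete jnp dtype object.
        return jnp.dtype(_DTYPES[self.dtype])
```

An explicit table rejects typos with a `KeyError` instead of silently resolving to an unrelated `jnp` attribute.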

class xlstm_jax.models.xlstm_parallel.blocks.xlstm_block.ResidualBlock#

Bases: flax.linen.Module

A residual block that applies a list of modules in sequence and adds the input to the output.

The modules are instantiated inside this block so that they are registered as its child modules.

module_fns: list[collections.abc.Callable[..., flax.linen.Module]]
class xlstm_jax.models.xlstm_parallel.blocks.xlstm_block.xLSTMBlock

Bases: flax.linen.Module

An xLSTM block can be either an sLSTM Block or an mLSTM Block.

It contains the pre-LayerNorms and the skip connections.

config: xLSTMBlockConfig