xlstm_jax.models.xlstm_parallel.blocks.mlstm.block

Classes

mLSTMBlockConfig

Sub-model configuration.

Functions

get_partial_mLSTMBlock(config, *args, **kwargs)

mLSTMBlock(config, *args, **kwargs)

Module Contents

class xlstm_jax.models.xlstm_parallel.blocks.mlstm.block.mLSTMBlockConfig

Bases: xlstm_jax.models.configs.SubModelConfig

Sub-model configuration.

This class is currently a quick fix to allow for post-init style model configs, such as the xlstm-clean configs we ported from the original xlstm codebase. Once the config system is more mature, this class should be removed and all configs should become subclasses of ModelConfig.
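To illustrate what "post-init style" means: with plain Python dataclasses, such a config computes derived fields in `__post_init__`, which runs immediately after the generated `__init__`. The sketch below is a minimal illustration under that assumption; `LayerConfigStub` and its fields (`embedding_dim`, `proj_factor`, `_inner_embedding_dim`) are hypothetical and are not the actual xlstm_jax config.

```python
from dataclasses import dataclass, field


@dataclass
class LayerConfigStub:
    """Hypothetical post-init style config: some fields are derived, not passed."""

    embedding_dim: int = 64
    proj_factor: float = 2.0
    # Derived field: excluded from the constructor and filled in __post_init__.
    _inner_embedding_dim: int = field(init=False, default=0)

    def __post_init__(self) -> None:
        # Runs right after the generated __init__, so the derived value is
        # always consistent with the user-supplied fields.
        self._inner_embedding_dim = int(self.proj_factor * self.embedding_dim)


cfg = LayerConfigStub(embedding_dim=128)
```

Because derived values are computed at construction time, a config system that expects every field to be passed explicitly does not fit this pattern directly, which is the gap the quick fix above works around.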

mlstm: xlstm_jax.models.xlstm_parallel.blocks.mlstm.layer.mLSTMLayerConfig
feedforward: xlstm_jax.models.xlstm_parallel.components.feedforward.FeedForwardConfig | None = None
add_post_norm: bool = False

If True, adds a normalization layer after the mLSTM layer and the feedforward layer.

parallel: xlstm_jax.models.configs.ParallelConfig | None = None

Parallel configuration for the model.

_num_blocks: int | None = None
_block_idx: int | None = None
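The `add_post_norm` flag can be pictured with the framework-free sketch below. It assumes the block uses the common pre-norm residual layout (`x + layer(norm(x))`) and that `add_post_norm` normalizes each sub-layer's output before the residual addition; both are assumptions, not confirmed by this page. `layer_norm`, `add`, and `block_forward` are hypothetical helpers; the real block works on JAX arrays with learnable normalization layers.

```python
import math
from typing import Callable, List, Optional

Vector = List[float]


def layer_norm(x: Vector, eps: float = 1e-6) -> Vector:
    """Parameter-free layer norm over a single feature vector."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def add(a: Vector, b: Vector) -> Vector:
    """Element-wise residual addition."""
    return [u + v for u, v in zip(a, b)]


def block_forward(
    x: Vector,
    mlstm: Callable[[Vector], Vector],
    feedforward: Optional[Callable[[Vector], Vector]] = None,
    add_post_norm: bool = False,
) -> Vector:
    """Hypothetical pre-norm residual block with an optional post-norm."""
    # mLSTM sub-layer: pre-norm, layer, optional post-norm, residual add.
    h = mlstm(layer_norm(x))
    if add_post_norm:
        h = layer_norm(h)
    x = add(x, h)
    # Feedforward sub-layer, skipped when feedforward is None.
    if feedforward is not None:
        h = feedforward(layer_norm(x))
        if add_post_norm:
            h = layer_norm(h)
        x = add(x, h)
    return x


x = [1.0, 2.0, 3.0]
scale = lambda v: [10.0 * t for t in v]  # stand-in sub-layer with a large gain
out_plain = block_forward(x, scale)
out_post = block_forward(x, scale, add_post_norm=True)
```

In this sketch the post-norm bounds the scale of each sub-layer's contribution to the residual stream, so the large gain of `scale` is visible in `out_plain` but not in `out_post`.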
to_dict()

Converts the config to a dictionary.

Useful for saving the config to disk or for logging.

Return type:

dict
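A minimal sketch of what a `to_dict` on a dataclass-based config can look like, assuming the config behaves like a standard dataclass. `BlockConfigStub` and `LayerConfigStub` are hypothetical stand-ins that mirror the documented field names; the stand-in field `num_heads` is invented for illustration, and whether the real method uses `dataclasses.asdict` or keeps the private `_`-prefixed fields is not stated on this page.

```python
import json
from dataclasses import asdict, dataclass, field
from typing import Optional


@dataclass
class LayerConfigStub:
    """Hypothetical stand-in for mLSTMLayerConfig (field is illustrative)."""

    num_heads: int = 4


@dataclass
class BlockConfigStub:
    """Hypothetical stand-in mirroring the documented mLSTMBlockConfig fields."""

    mlstm: LayerConfigStub = field(default_factory=LayerConfigStub)
    feedforward: Optional[dict] = None  # stand-in for FeedForwardConfig
    add_post_norm: bool = False
    parallel: Optional[dict] = None  # stand-in for ParallelConfig
    _num_blocks: Optional[int] = None
    _block_idx: Optional[int] = None

    def to_dict(self) -> dict:
        # asdict recurses into nested dataclasses and returns plain dicts,
        # which serialize directly to JSON for saving or logging.
        return asdict(self)


cfg = BlockConfigStub(add_post_norm=True)
print(json.dumps(cfg.to_dict(), sort_keys=True))
```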

xlstm_jax.models.xlstm_parallel.blocks.mlstm.block.get_partial_mLSTMBlock(config, *args, **kwargs)
Parameters:

config (mLSTMBlockConfig)

Return type:

callable
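Since `get_partial_mLSTMBlock` returns a callable rather than a module, one plausible reading is that it binds the config and any extra arguments now and defers instantiation. The sketch below assumes a `functools.partial` over a module class; `BlockConfigStub`, `BlockStub`, and `get_partial_block` are hypothetical stand-ins, not the library's implementation.

```python
import functools
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class BlockConfigStub:
    """Hypothetical stand-in for mLSTMBlockConfig."""

    add_post_norm: bool = False


class BlockStub:
    """Hypothetical stand-in for the module class the partial would build."""

    def __init__(self, config: BlockConfigStub, name: str = "block") -> None:
        self.config = config
        self.name = name


def get_partial_block(
    config: BlockConfigStub, *args: Any, **kwargs: Any
) -> Callable[..., BlockStub]:
    # Bind the config (and any extra constructor arguments) now;
    # the caller instantiates the block later.
    return functools.partial(BlockStub, config, *args, **kwargs)


make_block = get_partial_block(BlockConfigStub(add_post_norm=True))
blocks = [make_block(name=f"block_{i}") for i in range(3)]
```

Deferring instantiation this way lets the caller create several block instances from one shared config, each with its own remaining arguments such as a name.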

xlstm_jax.models.xlstm_parallel.blocks.mlstm.block.mLSTMBlock(config, *args, **kwargs)
Parameters:

config (mLSTMBlockConfig)

Return type:

flax.linen.Module