xlstm_jax.models.xlstm_parallel.blocks.mlstm.block
Classes
mLSTMBlockConfig | Sub-model configuration.
Module Contents
- class xlstm_jax.models.xlstm_parallel.blocks.mlstm.block.mLSTMBlockConfig
Bases: xlstm_jax.models.configs.SubModelConfig
Sub-model configuration.
This class is currently a quick fix that allows post-init style model configs, like those of xlstm-clean, which we ported from the original xlstm codebase. Once the config system is more mature, this class should be removed and all configs should become subclasses of ModelConfig.
- feedforward: xlstm_jax.models.xlstm_parallel.components.feedforward.FeedForwardConfig | None = None
- add_post_norm: bool = False
If True, adds a normalization layer after the mLSTM layer and the feedforward layer.
- parallel: xlstm_jax.models.configs.ParallelConfig | None = None
Parallel configuration for the model.
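A minimal sketch of the post-init style config pattern mentioned above, using a plain dataclass that mirrors the three documented fields. `ParallelConfig`, `FeedForwardConfig`, their fields (`model_axis_name`, `proj_factor`) and `mLSTMBlockConfigSketch` are hypothetical stand-ins, and the `__post_init__` behavior (pushing the block's parallel config down into the feedforward config) is an assumption chosen only to illustrate the pattern, not the library's actual logic.

```python
from dataclasses import dataclass
from typing import Optional


# Hypothetical stand-ins; the real xlstm_jax classes have different fields.
@dataclass
class ParallelConfig:
    model_axis_name: str = "tp"


@dataclass
class FeedForwardConfig:
    proj_factor: float = 1.3
    parallel: Optional[ParallelConfig] = None


@dataclass
class mLSTMBlockConfigSketch:
    """Mirrors the documented fields of mLSTMBlockConfig (illustration only)."""

    feedforward: Optional[FeedForwardConfig] = None
    add_post_norm: bool = False
    parallel: Optional[ParallelConfig] = None

    def __post_init__(self) -> None:
        # Post-init style: derive sub-config values from the parent after
        # construction. Assumed behavior: share the block's parallel config
        # with the feedforward config when the latter has none.
        if self.feedforward is not None and self.feedforward.parallel is None:
            self.feedforward.parallel = self.parallel


cfg = mLSTMBlockConfigSketch(
    feedforward=FeedForwardConfig(),
    add_post_norm=True,
    parallel=ParallelConfig(model_axis_name="model"),
)
print(cfg.feedforward.parallel.model_axis_name)  # model
```

The drawback this illustrates is that derived values only exist after construction, which is why the docstring calls the class a stopgap until everything is a ModelConfig subclass.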
- xlstm_jax.models.xlstm_parallel.blocks.mlstm.block.get_partial_mLSTMBlock(config, *args, **kwargs)
- Parameters:
config (mLSTMBlockConfig)
- Return type:
callable
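The function name and its `callable` return type suggest that `get_partial_mLSTMBlock` binds the config (and any extra arguments) to the block constructor ahead of time. A minimal sketch of that pattern with `functools.partial`, assuming this is how the factory works; `make_block` and `get_partial_block` are hypothetical stand-ins that return a dict instead of a Flax module.

```python
import functools
from typing import Any, Callable, Dict


def make_block(config: Any, *args: Any, **kwargs: Any) -> Dict[str, Any]:
    """Hypothetical stand-in for mLSTMBlock: records its constructor arguments."""
    return {"config": config, "args": args, "kwargs": kwargs}


def get_partial_block(config: Any, *args: Any, **kwargs: Any) -> Callable[..., Dict[str, Any]]:
    """Hypothetical stand-in for get_partial_mLSTMBlock."""
    # Bind the config and any extra arguments now; the caller supplies the rest later.
    return functools.partial(make_block, config, *args, **kwargs)


block_fn = get_partial_block("block-config", name="block_0")
# Keyword arguments merge: `name` comes from the partial, `dtype` from this call.
block = block_fn(dtype="float32")
```

Binding the config up front lets code that builds many identical blocks pass around a single constructor instead of threading the config through every call.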
- xlstm_jax.models.xlstm_parallel.blocks.mlstm.block.mLSTMBlock(config, *args, **kwargs)
- Parameters:
config (mLSTMBlockConfig)
- Return type:
flax.linen.Module
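To show where the optional `feedforward` sub-layer and the `add_post_norm` normalization could sit in the returned module, here is a minimal sketch in plain Python rather than Flax. It assumes a pre-norm residual layout; `block_forward`, `rms_norm` and the list-of-floats "activations" are hypothetical, and the mLSTM layer and feedforward are arbitrary callables. The actual `mLSTMBlock` may order or choose its normalizations differently.

```python
import math
from typing import Callable, List, Optional

Vector = List[float]
Layer = Callable[[Vector], Vector]


def rms_norm(x: Vector, eps: float = 1e-6) -> Vector:
    """Stand-in normalization: scale x to (roughly) unit root-mean-square."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def block_forward(
    x: Vector,
    mlstm_layer: Layer,
    feedforward: Optional[Layer] = None,
    add_post_norm: bool = False,
) -> Vector:
    """Hypothetical pre-norm residual block with an optional post-norm."""

    def sublayer(inp: Vector, layer: Layer) -> Vector:
        out = layer(rms_norm(inp))  # normalize the input, then apply the sub-layer
        if add_post_norm:
            out = rms_norm(out)  # extra normalization on the sub-layer output
        return [a + b for a, b in zip(inp, out)]  # residual connection

    x = sublayer(x, mlstm_layer)
    if feedforward is not None:  # mirrors `feedforward: ... | None = None`
        x = sublayer(x, feedforward)
    return x
```

With `feedforward=None` the block reduces to a single residual mLSTM sub-layer; with `add_post_norm=True` both sub-layer outputs are normalized before being added back, matching the field description above.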