xlstm_jax.models.xlstm_parallel.xlstm_block_stack
Classes
- xLSTMBlockStackConfig: Sub-model configuration.
- xLSTMBlockStack
- BlockStack
Module Contents
- class xlstm_jax.models.xlstm_parallel.xlstm_block_stack.xLSTMBlockStackConfig
Bases:
xlstm_jax.models.configs.SubModelConfig
Sub-model configuration.
This class is currently a quick fix that allows post-init style model configs (configs that derive some fields in __post_init__), such as the xlstm-clean model we ported from the original xLSTM codebase. Once the config system is more mature, we should remove this class and make all configs subclasses of ModelConfig.
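The post-init pattern, together with the block map built by _create_block_map (documented below), can be illustrated with a plain dataclass. This is a minimal sketch, not the xlstm_jax implementation: the class name BlockStackConfigSketch and the fields num_blocks, slstm_at and block_map are hypothetical. The documented config only lists an mlstm_block field, so the sLSTM entry merely illustrates how a map could mix block types.

```python
from dataclasses import dataclass, field


@dataclass
class BlockStackConfigSketch:
    # Hypothetical fields; the real xLSTMBlockStackConfig may use other names.
    num_blocks: int = 4
    slstm_at: list[int] = field(default_factory=list)
    # Derived field: not passed by the user, filled in __post_init__.
    block_map: str = field(init=False)

    def __post_init__(self):
        # Post-init style: compute derived values once all fields are set.
        self.block_map = self._create_block_map()

    def _create_block_map(self) -> str:
        """Return a comma-separated map: 0 = mLSTM block, 1 = sLSTM block (assumed encoding)."""
        block_map = [0] * self.num_blocks
        for idx in self.slstm_at:
            if not 0 <= idx < self.num_blocks:
                raise ValueError(f"Invalid block index {idx} for {self.num_blocks} blocks")
            block_map[idx] = 1
        return ",".join(map(str, block_map))


cfg = BlockStackConfigSketch(num_blocks=4, slstm_at=[1])
print(cfg.block_map)  # 0,1,0,0
```

Computing the map in __post_init__ means every consumer of the config sees the same derived layout without recomputing it.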
- mlstm_block: xlstm_jax.models.xlstm_parallel.blocks.mlstm.block.mLSTMBlockConfig | None = None
- parallel: xlstm_jax.models.configs.ParallelConfig | None = None
- init_distribution_embed: xlstm_jax.models.shared.InitDistribution = 'normal'
Distribution type from which to sample the embeddings.
- init_distribution_out: xlstm_jax.models.shared.InitDistribution = 'normal'
Distribution type from which to sample the LM output head.
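The page does not say how init_distribution_embed and init_distribution_out are turned into initializers. One plausible mapping, sketched below under the assumption that InitDistribution takes the values "normal", "truncated_normal" and "uniform", uses jax.nn.initializers.variance_scaling, which accepts exactly those three distribution names. The helper make_initializer and its scaling rule are hypothetical.

```python
import jax
import jax.numpy as jnp


def make_initializer(distribution: str = "normal", scale: float = 1.0, mode: str = "fan_out"):
    """Hypothetical mapping from an InitDistribution string to a JAX initializer."""
    # Assumed set of InitDistribution values; xlstm_jax may define others.
    if distribution not in ("normal", "truncated_normal", "uniform"):
        raise ValueError(f"Unknown init distribution: {distribution!r}")
    # variance_scaling samples with variance scale / fan; for a (vocab, dim)
    # embedding table, fan_out is the last axis, i.e. the embedding dimension.
    return jax.nn.initializers.variance_scaling(scale=scale, mode=mode, distribution=distribution)


key = jax.random.PRNGKey(0)
embed_init = make_initializer("normal")
table = embed_init(key, (1000, 64), jnp.float32)  # (vocab_size, embedding_dim)
print(table.shape)  # (1000, 64)
print(float(table.std()))  # roughly 1 / sqrt(64) = 0.125
```

With mode="fan_out", the embedding entries keep a standard deviation of about 1/sqrt(embedding_dim) regardless of vocabulary size; an output head with a different layout may need a different mode.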
- _create_block_map()
Creates the block map, which specifies which block is used at each position.
- Return type:
- property _dtype: jax.numpy.dtype
Returns the actual JAX dtype instead of the string stored in the config.
- Returns:
The jnp dtype corresponding to the string value.
- Return type:
jax.numpy.dtype
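As a minimal sketch of the string-to-dtype conversion that _dtype performs, assuming the config stores names such as "float32" or "bfloat16", a string can be resolved through an explicit lookup table. The table _DTYPES and the helper resolve_dtype are hypothetical; the set of names accepted by xlstm_jax may differ.

```python
import jax.numpy as jnp

# Hypothetical lookup table from config strings to JAX scalar types.
_DTYPES = {
    "float32": jnp.float32,
    "bfloat16": jnp.bfloat16,
    "float16": jnp.float16,
}


def resolve_dtype(name: str) -> jnp.dtype:
    """Convert a dtype string from the config into a jnp dtype object."""
    try:
        return jnp.dtype(_DTYPES[name])
    except KeyError:
        raise ValueError(f"Unsupported dtype string: {name!r}") from None


print(resolve_dtype("bfloat16").name)  # bfloat16
```

An explicit table rejects typos early with a clear error instead of failing later inside a layer.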
- class xlstm_jax.models.xlstm_parallel.xlstm_block_stack.xLSTMBlockStack
Bases:
flax.linen.Module
- config: xLSTMBlockStackConfig
- class xlstm_jax.models.xlstm_parallel.xlstm_block_stack.BlockStack
Bases:
flax.linen.Module
- config: xLSTMBlockStackConfig
- _create_blocks(config)
- Parameters:
config (xLSTMBlockStackConfig)