xlstm_jax.models.xlstm_parallel.xlstm_lm_model#
Classes#
xLSTMLMModelConfig | Sub-model configuration. |
xLSTMLMModel | |
Module Contents#
- class xlstm_jax.models.xlstm_parallel.xlstm_lm_model.xLSTMLMModelConfig#
Bases:
xlstm_jax.models.xlstm_parallel.xlstm_block_stack.xLSTMBlockStackConfig
Sub-model configuration.
This class is currently a stopgap that allows post-init style model configs, such as the xlstm-clean config ported from the original xlstm codebase. Once the config system is more mature, this class should be removed and every config should become a subclass of ModelConfig.
- norm_type: Literal['layernorm', 'rmsnorm'] = 'layernorm'#
Type of normalization layer to use.
- logits_soft_cap: float | None = None#
Soft cap for the LM output logits. If None, no cap is applied.
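As an illustration of what a logit soft cap typically does, the sketch below uses the common tanh formulation, `cap * tanh(logits / cap)`. This formula is an assumption: the page does not state how xLSTMLMModel applies the cap, and `soft_cap_logits` is a hypothetical helper, not part of xlstm_jax.

```python
import jax.numpy as jnp


def soft_cap_logits(logits, cap):
    """Hypothetical helper: squash logits smoothly into (-cap, cap).

    Assumes the tanh-based soft cap; the real model may differ.
    """
    if cap is None:
        # Mirrors the documented behaviour: no cap is applied when None.
        return logits
    return cap * jnp.tanh(logits / cap)


logits = jnp.array([-100.0, -1.0, 0.0, 1.0, 100.0])
capped = soft_cap_logits(logits, 30.0)
uncapped = soft_cap_logits(logits, None)
```

Under this formulation, small logits pass through almost unchanged, while large ones saturate just below the cap in magnitude.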
- lm_head_dtype: str = 'float32'#
Data type in which the LM head Dense layer is computed. The output is always cast to float32 for numerical stability.
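The following minimal sketch shows the pattern described here: run the head projection in a lower-precision dtype, then cast the logits to float32. The function `lm_head`, its arguments, and the bias-free matrix product are illustrative assumptions, not the library's implementation.

```python
import jax.numpy as jnp


def lm_head(hidden, kernel, compute_dtype=jnp.bfloat16):
    """Hypothetical LM head: project in compute_dtype, return float32 logits."""
    # Do the matrix product in the configured (possibly low-precision) dtype.
    logits = jnp.dot(hidden.astype(compute_dtype), kernel.astype(compute_dtype))
    # Cast the result to float32 regardless of the compute dtype.
    return logits.astype(jnp.float32)


hidden = jnp.ones((2, 8), dtype=jnp.float32)   # (batch, embedding_dim)
kernel = jnp.ones((8, 16), dtype=jnp.float32)  # (embedding_dim, vocab_size)
logits = lm_head(hidden, kernel)
```

Keeping the logits in float32 matters mostly for the softmax and loss that follow, which are sensitive to rounding in low-precision formats.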
- parallel: xlstm_jax.models.configs.ParallelConfig | None = None#
- property _lm_head_dtype: jax.numpy.dtype#
Return the real dtype instead of the str in config.
- Returns:
Dtype corresponding to the respective str attribute.
- Return type:
jax.numpy.dtype
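A property of this kind can be sketched as a lookup from the configured string to a JAX dtype. The class `HeadConfigSketch` and the mapping `_STR_TO_DTYPE` below are hypothetical stand-ins; how xlstm_jax actually resolves the string is not shown on this page.

```python
from dataclasses import dataclass

import jax.numpy as jnp

# Hypothetical mapping; the set of names the library supports is an assumption.
_STR_TO_DTYPE = {
    "float32": jnp.float32,
    "bfloat16": jnp.bfloat16,
    "float16": jnp.float16,
}


@dataclass
class HeadConfigSketch:
    """Illustrative stand-in for the config, not the real xLSTMLMModelConfig."""

    lm_head_dtype: str = "float32"

    @property
    def _lm_head_dtype(self):
        # Resolve the string stored in the config to a concrete dtype object.
        return jnp.dtype(_STR_TO_DTYPE[self.lm_head_dtype])


default_cfg = HeadConfigSketch()
bf16_cfg = HeadConfigSketch(lm_head_dtype="bfloat16")
```

Storing the dtype as a string keeps the config easy to serialize, with the property resolving it only when the model needs it.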
- mlstm_block: xlstm_jax.models.xlstm_parallel.blocks.mlstm.block.mLSTMBlockConfig | None = None#
- init_distribution_embed: xlstm_jax.models.shared.InitDistribution = 'normal'#
Distribution type from which to sample the embeddings.
- init_distribution_out: xlstm_jax.models.shared.InitDistribution = 'normal'#
Distribution type from which to sample the LM output head.
- _create_block_map()#
Creates the block map, which specifies the block used at each position.
- Return type:
- property _dtype: jax.numpy.dtype#
Returns the real dtype instead of the str from configs.
- Returns:
The jnp dtype corresponding to the string value.
- Return type:
jax.numpy.dtype
- class xlstm_jax.models.xlstm_parallel.xlstm_lm_model.xLSTMLMModel#
Bases:
flax.linen.Module
- config: xLSTMLMModelConfig#