xlstm_jax.models.xlstm_clean.xlstm_lm_model
Classes
Module Contents
- class xlstm_jax.models.xlstm_clean.xlstm_lm_model.xLSTMLMModelConfig
Bases: xlstm_jax.models.xlstm_clean.xlstm_block_stack.xLSTMBlockStackConfig
- mlstm_block: xlstm_jax.models.xlstm_clean.blocks.mlstm.block.mLSTMBlockConfig | None = None
- _create_block_map()
Creates the block map, which specifies the block used at each position in the stack.
- Return type:
- property _dtype: jax.numpy.dtype
Returns the actual JAX dtype instead of the dtype string stored in the config.
- Returns:
The jnp dtype corresponding to the string value.
- Return type:
jax.numpy.dtype
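To illustrate what the two private helpers above do, here is a minimal sketch using a stand-in config class. It is hypothetical throughout: HypotheticalStackConfig and its fields num_blocks, slstm_at and dtype are assumptions, as are the 0/1 encoding of the block map (0 for an mLSTM block, 1 for an sLSTM block) and the list return type of _create_block_map. None of this is the library's documented behavior.

```python
from dataclasses import dataclass, field

import jax.numpy as jnp


@dataclass
class HypotheticalStackConfig:
    # Hypothetical fields: the real xLSTMBlockStackConfig may name or type these differently.
    num_blocks: int = 4
    slstm_at: list[int] = field(default_factory=list)  # positions assumed to use an sLSTM block
    dtype: str = "float32"  # dtype kept as a string, as it would be in a YAML/JSON config

    def _create_block_map(self) -> list[int]:
        """Return one entry per position: 0 = mLSTM block, 1 = sLSTM block (assumed encoding)."""
        block_map = [0] * self.num_blocks
        for idx in self.slstm_at:
            if not 0 <= idx < self.num_blocks:
                raise ValueError(f"Invalid sLSTM position {idx} for {self.num_blocks} blocks")
            block_map[idx] = 1
        return block_map

    @property
    def _dtype(self) -> jnp.dtype:
        """Resolve the dtype string (e.g. "bfloat16") to the corresponding jnp dtype."""
        # getattr maps the name to a JAX scalar type such as jnp.bfloat16; jnp.dtype wraps it.
        return jnp.dtype(getattr(jnp, self.dtype))


cfg = HypotheticalStackConfig(num_blocks=4, slstm_at=[1], dtype="bfloat16")
print(cfg._create_block_map())  # [0, 1, 0, 0]
print(cfg._dtype)
```

Keeping the dtype as a string in the config and converting it in a property lets configs stay serializable with plain strings while model code receives a real dtype object.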
- class xlstm_jax.models.xlstm_clean.xlstm_lm_model.xLSTMLMModel
Bases: flax.linen.Module
- config: xLSTMLMModelConfig