xlstm_jax.models.xlstm_clean.xlstm_lm_model#

Classes#

Module Contents#

class xlstm_jax.models.xlstm_clean.xlstm_lm_model.xLSTMLMModelConfig#

Bases: xlstm_jax.models.xlstm_clean.xlstm_block_stack.xLSTMBlockStackConfig

vocab_size: int = -1#
tie_weights: bool = False#
weight_decay_on_embedding: bool = False#
add_embedding_dropout: bool = False#
mlstm_block: xlstm_jax.models.xlstm_clean.blocks.mlstm.block.mLSTMBlockConfig | None = None#
slstm_block: None = None#
context_length: int = -1#
num_blocks: int = 1#
embedding_dim: int = 128#
add_post_blocks_norm: bool = True#
bias: bool = False#
dropout: float = 0.0#
dtype: str = 'bfloat16'#
slstm_at: list[int] | Literal['all'] = []#
_block_map: str = None#
property block_map: list[int]#
Return type:

list[int]

_create_block_map()#

Creates the block map, which specifies the block used at each position.

Return type:

str

property _dtype: jax.numpy.dtype#

Returns the actual dtype object rather than the string stored in the config.

Returns:

The jnp dtype corresponding to the string value.

Return type:

jax.numpy.dtype

class xlstm_jax.models.xlstm_clean.xlstm_lm_model.xLSTMLMModel#

Bases: flax.linen.Module

config: xLSTMLMModelConfig#