xlstm_jax.models.xlstm_parallel.xlstm_lm_model#

Classes#

xLSTMLMModelConfig

Sub-model configuration.

xLSTMLMModel

Module Contents#

class xlstm_jax.models.xlstm_parallel.xlstm_lm_model.xLSTMLMModelConfig#

Bases: xlstm_jax.models.xlstm_parallel.xlstm_block_stack.xLSTMBlockStackConfig

Sub-model configuration.

This class is currently a quick fix to allow for post-init style model configs, like the xlstm-clean we ported from the original xlstm codebase. Once the config system is more mature, we should remove this class and make all configs subclasses of ModelConfig.

vocab_size: int = -1#
tie_weights: bool = False#
weight_decay_on_embedding: bool = False#
add_embedding_dropout: bool = False#
norm_eps: float = 1e-06#

Epsilon value for numerical stability in the normalization layer.

norm_type: Literal['layernorm', 'rmsnorm'] = 'layernorm'#

Type of normalization layer to use.

logits_soft_cap: float | None = None#

Soft cap for the LM output logits. If None, no cap is applied.

lm_head_dtype: str = 'float32'#

Data type to perform the LM Head Dense layer in. The output will always be cast to float32 for numerical stability.

parallel: xlstm_jax.models.configs.ParallelConfig | None = None#
property _lm_head_dtype: jax.numpy.dtype#

Return the real dtype instead of the str in config.

Returns:

Dtype corresponding to the respective str attribute.

Return type:

jax.numpy.dtype

mlstm_block: xlstm_jax.models.xlstm_parallel.blocks.mlstm.block.mLSTMBlockConfig | None = None#
slstm_block: Any | None = None#
context_length: int = -1#
num_blocks: int = 1#
embedding_dim: int = 128#
add_post_blocks_norm: bool = True#
bias: bool = False#
dropout: float = 0.0#
scan_blocks: bool = False#
dtype: str = 'bfloat16'#
init_distribution_embed: xlstm_jax.models.shared.InitDistribution = 'normal'#

Distribution type from which to sample the embeddings.

init_distribution_out: xlstm_jax.models.shared.InitDistribution = 'normal'#

Distribution type from which to sample the LM output head.

slstm_at: list[int] = []#
_block_map: str | None = None#
property block_map: list[int]#
Return type:

list[int]

_create_block_map()#

Creates the block map that specifies which block is used at which position.

Return type:

str

property _dtype: jax.numpy.dtype#

Returns the real dtype instead of the str from configs.

Returns:

The jnp dtype corresponding to the string value.

Return type:

jax.numpy.dtype

to_dict()#

Converts the config to a dictionary.

Helpful for saving to disk or logging.

Return type:

dict

class xlstm_jax.models.xlstm_parallel.xlstm_lm_model.xLSTMLMModel#

Bases: flax.linen.Module

config: xLSTMLMModelConfig#