xlstm_jax.models.xlstm_parallel.xlstm_block_stack#

Classes#

Module Contents#

class xlstm_jax.models.xlstm_parallel.xlstm_block_stack.xLSTMBlockStackConfig#

Bases: xlstm_jax.models.configs.SubModelConfig

Sub-model configuration.

This class is currently a quick fix to allow for post-init style model configs, such as xlstm-clean, which we ported from the original xlstm codebase. Once the config system is more mature, we should remove this class so that everything becomes a subclass of ModelConfig.

mlstm_block: xlstm_jax.models.xlstm_parallel.blocks.mlstm.block.mLSTMBlockConfig | None = None#
slstm_block: Any | None = None#
context_length: int = -1#
num_blocks: int = 1#
embedding_dim: int = 128#
add_post_blocks_norm: bool = True#
bias: bool = False#
dropout: float = 0.0#
scan_blocks: bool = False#
dtype: str = 'bfloat16'#
parallel: xlstm_jax.models.configs.ParallelConfig | None = None#
init_distribution_embed: xlstm_jax.models.shared.InitDistribution = 'normal'#

Distribution type from which to sample the embeddings.

init_distribution_out: xlstm_jax.models.shared.InitDistribution = 'normal'#

Distribution type from which to sample the LM output head.

slstm_at: list[int] = []#
_block_map: str | None = None#
property block_map: list[int]#
Return type:

list[int]

_create_block_map()#

Creates the block map that specifies which block is used at each position.

Return type:

str

property _dtype: jax.numpy.dtype#

Returns the actual jnp dtype rather than the string stored in the config.

Returns:

The jnp dtype corresponding to the string value.

Return type:

jax.numpy.dtype

to_dict()#

Converts the config to a dictionary.

Helpful for saving to disk or logging.

Return type:

dict

class xlstm_jax.models.xlstm_parallel.xlstm_block_stack.xLSTMBlockStack#

Bases: flax.linen.Module

config: xLSTMBlockStackConfig#
class xlstm_jax.models.xlstm_parallel.xlstm_block_stack.BlockStack#

Bases: flax.linen.Module

config: xLSTMBlockStackConfig#
_create_blocks(config)#
Parameters:

config (xLSTMBlockStackConfig)