xlstm_jax.models.xlstm_clean.blocks.mlstm.layer

Classes

Module Contents

class xlstm_jax.models.xlstm_clean.blocks.mlstm.layer.mLSTMLayerConfig

Bases: xlstm_jax.models.xlstm_clean.utils.UpProjConfigMixin

conv1d_kernel_size: int = 4
qkv_proj_blocksize: int = 4
num_heads: int = 4
proj_factor: float = 2.0
vmap_qk: bool = False
embedding_dim: int = -1
bias: bool = False
dropout: float = 0.0
context_length: int = -1
dtype: str = 'bfloat16'
_num_blocks: int = 1
_inner_embedding_dim: int = None
mlstm_cell: xlstm_jax.models.xlstm_clean.blocks.mlstm.cell.mLSTMCellConfig
property _dtype: jax.numpy.dtype

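Returns the actual JAX dtype for this config, converting the `dtype` string field (e.g. 'bfloat16') into the corresponding `jax.numpy` dtype.

Returns:

The `jax.numpy` dtype corresponding to the `dtype` string value.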

Return type:

jax.numpy.dtype

round_proj_up_dim_up: bool = True
round_proj_up_to_multiple_of: int = 64
_proj_up_dim: int = None
_set_proj_up_dim(embedding_dim)
Parameters:

embedding_dim (int)

Return type:

None

class xlstm_jax.models.xlstm_clean.blocks.mlstm.layer.mLSTMLayer

Bases: flax.linen.Module

config: mLSTMLayerConfig