xlstm_jax.models.xlstm_clean.blocks.mlstm.cell

Classes

mLSTMCellConfig
mLSTMCell

Module Contents

class xlstm_jax.models.xlstm_clean.blocks.mlstm.cell.mLSTMCellConfig
context_length: int = -1
embedding_dim: int = -1
num_heads: int = -1
backend: xlstm_jax.models.xlstm_clean.blocks.mlstm.backend.mLSTMBackendNameAndKwargs
dtype: str = 'bfloat16'
property _dtype: jax.numpy.dtype

Returns the actual JAX dtype object instead of the dtype string stored in the config.

Returns:

The jnp dtype corresponding to the string value of the `dtype` field.

Return type:

jax.numpy.dtype

class xlstm_jax.models.xlstm_clean.blocks.mlstm.cell.mLSTMCell

Bases: flax.linen.Module

config: mLSTMCellConfig