xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.triton_kernels#

Attributes#

Classes#

Module Contents#

xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.triton_kernels.BackendNameType#
class xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.triton_kernels.mLSTMBackendTritonConfig#
autocast_dtype: str | None = None#

Dtype to use for the kernel computation. If None, uses the query dtype.

chunk_size: int = 64#

Chunk size for the kernel computation.

reduce_slicing: bool = True#

Whether to reduce slicing operations before the kernel computation. This speeds up computation during training, but may limit the use of initial states and the forwarding of states during inference.

backend_name: BackendNameType = 'max_triton_noslice'#

Backend name for the kernel type used.

eps: float = 1e-06#

Epsilon value used in the kernel.

norm_val: float = 1.0#

Normalizer upper bound value, used as $\max(\text{norm\_val} \cdot e^{-m}, |n q|)$.

stabilize_correctly: bool = True#

Whether to stabilize correctly, i.e. scale norm_val with the maximizer state m (the e^{-m} factor in the description of norm_val).

assign_model_config_params(model_config)#
class xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.triton_kernels.mLSTMBackendTriton#

Bases: xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.config.mLSTMBackend

config_class#
property can_vmap_over_heads: bool#

Whether the backend can be vmapped over the heads dimension.

Triton kernels already handle the head dimension, so this backend must not be vmapped over it.

Returns:

False

Return type:

bool

config: Any#