xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.triton_kernels
Attributes
Classes
Module Contents
- xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.triton_kernels.BackendNameType
- class xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.triton_kernels.mLSTMBackendTritonConfig
- autocast_dtype: str | None = None
Dtype to use for the kernel computation. If None, uses the query dtype.
- reduce_slicing: bool = True
Whether to reduce slicing operations before the kernel computation. This speeds up training, but may restrict the use of initial states and state forwarding during inference.
- backend_name: BackendNameType = 'max_triton_noslice'
Name of the Triton kernel backend to use.
- stabilize_correctly: bool = True
Whether to stabilize correctly, i.e., scale norm_val with the maximizer state.
- assign_model_config_params(model_config)#
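The configuration fields above can be illustrated with a small stand-in. This is a minimal sketch, not the library's code: TritonConfigSketch, resolve_kernel_dtype, and normalizer are hypothetical names. It mirrors the documented field defaults and the documented fallback of autocast_dtype=None to the query dtype. For stabilize_correctly, it assumes the stabilized mLSTM normalizer from the xLSTM paper, where the unstabilized denominator is max(|n^T q|, 1) and the states are rescaled by exp(-m) with maximizer state m; under that assumption the lower bound must be rescaled to exp(-m) as well.

```python
from dataclasses import dataclass
from typing import Optional

import jax.numpy as jnp


@dataclass
class TritonConfigSketch:
    """Hypothetical stand-in mirroring the documented defaults of mLSTMBackendTritonConfig."""

    autocast_dtype: Optional[str] = None
    reduce_slicing: bool = True
    backend_name: str = "max_triton_noslice"
    stabilize_correctly: bool = True


def resolve_kernel_dtype(config: TritonConfigSketch, q: jnp.ndarray) -> jnp.dtype:
    # Documented behavior: autocast_dtype=None falls back to the query dtype.
    if config.autocast_dtype is None:
        return q.dtype
    return jnp.dtype(config.autocast_dtype)


def normalizer(
    n_stab: jnp.ndarray,
    q: jnp.ndarray,
    m: jnp.ndarray,
    stabilize_correctly: bool = True,
) -> jnp.ndarray:
    # Assumption: n_stab is the stabilized normalizer state n * exp(-m) and
    # m is the maximizer state. The unstabilized lower bound of 1 becomes
    # exp(-m) in the rescaled space when stabilizing correctly.
    lower = jnp.exp(-m) if stabilize_correctly else jnp.ones_like(m)
    return jnp.maximum(jnp.abs(jnp.dot(n_stab, q)), lower)
```

With stabilize_correctly=True, normalizer(n * exp(-m), q, m) * exp(m) recovers the unstabilized max(|n^T q|, 1). With False, the bound of 1 is applied in the rescaled space, which changes the result whenever that lower bound is the active term.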
- class xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.triton_kernels.mLSTMBackendTriton
Bases: xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.config.mLSTMBackend
- config_class
- property can_vmap_over_heads: bool
Whether the backend can be vmapped over the heads dimension.
The Triton kernels already handle the head dimension, so this backend must not be vmapped over it.
- Returns:
False
- Return type:
bool
- config: Any
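The effect of can_vmap_over_heads on a caller can be sketched as follows. This is a minimal JAX illustration under stated assumptions: apply_backend, per_head_kernel, and multi_head_kernel are hypothetical, and the kernels are toy, unmasked attention-like products rather than mLSTM kernels. A backend that reports True is written for a single head and is mapped over the head axis with jax.vmap; a backend that reports False, like this Triton backend, already processes the head axis and receives the full [heads, seq, dim] tensors.

```python
from typing import Callable

import jax
import jax.numpy as jnp


def per_head_kernel(q, k, v):
    # Hypothetical single-head kernel: q, k, v have shape [seq, dim].
    scores = q @ k.T
    return scores @ v


def multi_head_kernel(q, k, v):
    # Hypothetical kernel that handles the head axis itself, as the Triton
    # kernels are documented to do: q, k, v have shape [heads, seq, dim].
    scores = jnp.einsum("hsd,htd->hst", q, k)
    return jnp.einsum("hst,htd->hsd", scores, v)


def apply_backend(kernel: Callable, can_vmap_over_heads: bool, q, k, v):
    # Inputs have shape [heads, seq, dim]. Only single-head kernels are
    # mapped over axis 0; multi-head kernels receive the full tensors.
    if can_vmap_over_heads:
        return jax.vmap(kernel, in_axes=0, out_axes=0)(q, k, v)
    return kernel(q, k, v)
```

Both paths produce the same [heads, seq, dim] output; the flag only decides whether the head loop is supplied by jax.vmap or by the kernel itself. Vmapping a kernel that already iterates over heads would instead strip the head axis before the kernel sees it.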