xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.recurrent_triton#
Classes#
mLSTMBackendRecurrentTritonConfig | Configuration class for the mLSTM recurrent backend using Triton kernels.
mLSTMBackendRecurrentTriton | mLSTM recurrent backend using Triton kernels.
Module Contents#
- class xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.recurrent_triton.mLSTMBackendRecurrentTritonConfig#
Configuration class for the mLSTM recurrent backend using Triton kernels.
- state_dtype: str | None = None#
Data type for the state tensors. If None, the data type is inferred from the input tensors.
- assign_model_config_params(model_config)#
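The `state_dtype` fallback described for this configuration class can be illustrated with a minimal sketch. It is not the library's implementation: `RecurrentTritonConfigSketch` and `resolve_state_dtype` are hypothetical names introduced here, and the only behavior taken from the documentation is that a `state_dtype` of `None` means the state dtype follows the input tensors.

```python
from dataclasses import dataclass
from typing import Optional

import jax.numpy as jnp


@dataclass
class RecurrentTritonConfigSketch:
    """Hypothetical stand-in for mLSTMBackendRecurrentTritonConfig (illustration only)."""

    # Mirrors the documented field: None means "infer from the inputs".
    state_dtype: Optional[str] = None


def resolve_state_dtype(config, inputs):
    """Hypothetical helper: pick the dtype to use for the state tensors."""
    if config.state_dtype is None:
        # No explicit setting, so follow the dtype of the input tensors.
        return inputs.dtype
    # An explicit string such as "float32" overrides the input dtype.
    return jnp.dtype(config.state_dtype)


# Example input in half precision (the shape is arbitrary for this sketch).
q = jnp.ones((2, 4, 8), dtype=jnp.float16)

inferred = resolve_state_dtype(RecurrentTritonConfigSketch(), q)
explicit = resolve_state_dtype(RecurrentTritonConfigSketch(state_dtype="float32"), q)
```

With the default configuration the state dtype matches the float16 input. Setting `state_dtype="float32"` would keep the recurrent state in higher precision than the inputs, which is one plausible reason for exposing the option.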
- class xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.recurrent_triton.mLSTMBackendRecurrentTriton#
Bases:
xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.config.mLSTMBackend
mLSTM recurrent backend using Triton kernels.
This backend uses Triton kernels for the mLSTM recurrent cell.
- Parameters:
config – Configuration object for the backend.
config_class – Configuration class for the backend.
- config_class#