xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.recurrent_triton#

Classes#

mLSTMBackendRecurrentTritonConfig

Configuration class for the mLSTM recurrent backend using Triton kernels.

mLSTMBackendRecurrentTriton

mLSTM recurrent backend using Triton kernels.

Module Contents#

class xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.recurrent_triton.mLSTMBackendRecurrentTritonConfig#

Configuration class for the mLSTM recurrent backend using Triton kernels.

eps: float = 1e-06#

Epsilon value used in the kernel.

state_dtype: str | None = None#

Data type for the state tensors. If None, the data type is inferred from the input tensors.
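The fallback described above can be sketched as follows. This is only an illustration: `RecurrentTritonConfigSketch` and `resolve_state_dtype` are hypothetical names that are not part of xlstm_jax, and the dataclass merely mirrors the documented fields and defaults.

```python
from dataclasses import dataclass
from typing import Optional

import jax.numpy as jnp


@dataclass
class RecurrentTritonConfigSketch:
    """Hypothetical stand-in mirroring the documented fields and defaults."""

    eps: float = 1e-6
    state_dtype: Optional[str] = None
    use_scan: bool = False


def resolve_state_dtype(cfg, x):
    """Return the dtype for state tensors (hypothetical helper)."""
    if cfg.state_dtype is None:
        # No explicit setting: fall back to the dtype of the input tensor.
        return x.dtype
    # Explicit setting: convert the string name into a dtype object.
    return jnp.dtype(cfg.state_dtype)


x = jnp.ones((2, 3), dtype=jnp.float16)
inferred = resolve_state_dtype(RecurrentTritonConfigSketch(), x)
explicit = resolve_state_dtype(RecurrentTritonConfigSketch(state_dtype="float32"), x)
```

Here `inferred` follows the half-precision input, while `explicit` follows the configured string.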

use_scan: bool = False#

Whether to use scan for the recurrent sequence.
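As a rough illustration of what such a flag typically selects between, the sketch below runs a toy gated recurrence once with a Python loop and once with `jax.lax.scan`. It assumes that "scan" here refers to `jax.lax.scan`, which the source does not state. The `step` function is a simplified stand-in, not the mLSTM Triton kernel, and all names are hypothetical.

```python
import jax
import jax.numpy as jnp


def step(c, inputs):
    """One toy recurrent step: c_t = f_t * c_{t-1} + i_t * x_t."""
    f, i, x = inputs
    c_new = f * c + i * x
    return c_new, c_new  # (new carry, per-step output)


def run_loop(c0, f, i, x):
    """Step through the sequence with a plain Python loop."""
    c = c0
    outs = []
    for t in range(x.shape[0]):
        c, y = step(c, (f[t], i[t], x[t]))
        outs.append(y)
    return c, jnp.stack(outs)


def run_scan(c0, f, i, x):
    """Same recurrence expressed with jax.lax.scan over the leading axis."""
    return jax.lax.scan(step, c0, (f, i, x))


T, D = 8, 4  # sequence length, feature size
kf, ki, kx = jax.random.split(jax.random.PRNGKey(0), 3)
f = jax.nn.sigmoid(jax.random.normal(kf, (T, D)))
i = jax.nn.sigmoid(jax.random.normal(ki, (T, D)))
x = jax.random.normal(kx, (T, D))
c0 = jnp.zeros((D,))

c_loop, ys_loop = run_loop(c0, f, i, x)
c_scan, ys_scan = run_scan(c0, f, i, x)
```

Both variants compute the same states. The loop is unrolled when traced, whereas the scan traces `step` once, which usually keeps compile time manageable for long sequences.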

assign_model_config_params(model_config)#

class xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.recurrent_triton.mLSTMBackendRecurrentTriton#

Bases: xlstm_jax.models.xlstm_parallel.blocks.mlstm.backend.config.mLSTMBackend

mLSTM recurrent backend using Triton kernels.

This backend uses Triton kernels for the mLSTM recurrent cell.

Parameters:
  • config – Configuration object for the backend.

  • config_class – Configuration class for the backend.

config_class#

property can_vmap_over_heads: bool#

Whether the backend can be vmapped over the heads dimension.

The Triton kernels already handle the head dimension, so the backend must not be vmapped over it.

Returns:

False

Return type:

bool
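To illustrate how a caller might act on such a flag, the sketch below contrasts a per-head function that needs `jax.vmap` with a function that consumes the head axis itself. All names (`scores_single_head`, `scores_all_heads`, `run`) are hypothetical, and the dispatch logic is an assumption about how the property could be used, not the library's actual code.

```python
import jax
import jax.numpy as jnp


def scores_single_head(q, k):
    """Operates on one head: q, k have shape (S, D)."""
    return q @ k.T


def scores_all_heads(q, k):
    """Handles the head axis itself: q, k have shape (NH, S, D)."""
    return jnp.einsum("hsd,htd->hst", q, k)


def run(fn, can_vmap_over_heads, q, k):
    """Hypothetical dispatcher that maps over heads only when allowed."""
    if can_vmap_over_heads:
        return jax.vmap(fn)(q, k)  # map fn over the leading head axis
    return fn(q, k)  # fn already processes every head


NH, S, D = 2, 5, 4  # heads, sequence length, head dimension
q = jax.random.normal(jax.random.PRNGKey(0), (NH, S, D))
k = jax.random.normal(jax.random.PRNGKey(1), (NH, S, D))

out_vmapped = run(scores_single_head, True, q, k)
out_native = run(scores_all_heads, False, q, k)
```

Both calls produce the same per-head scores. A backend reporting `False`, like this one, corresponds to the second path.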