xlstm_jax.kernels.mlstm_recurrent.triton_fused_fw

Functions

recurrent_step_fw(matC_state, vecN_state, scaM_state, ...)
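Judging by its name and arguments, `recurrent_step_fw` computes one forward step of the mLSTM in recurrent form: it takes the matrix memory `matC_state`, normalizer `vecN_state` and stabilizer `scaM_state`, plus the current query, key, value and gate inputs. The following is a minimal NumPy reference sketch of the stabilized mLSTM update described in the xLSTM paper. It is not the Triton kernel. It rests on several assumptions that this page does not confirm: `scaI` and `scaF` are input and forget gate pre-activations, the forget gate is log-sigmoid activated, `qk_scale=None` means 1/sqrt(DHQK), and inputs are a single head without batch dimensions. The function name `mlstm_recurrent_step_ref` is hypothetical, and the kernel's optional `matC_new`, `vecN_new` and `scaM_new` arguments are not modeled.

```python
import numpy as np


def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)) = -log(1 + exp(-x)).
    return -np.logaddexp(0.0, -x)


def mlstm_recurrent_step_ref(matC_state, vecN_state, scaM_state,
                             vecQ, vecK, vecV, scaI, scaF,
                             qk_scale=None, eps=1e-6):
    """Hypothetical NumPy reference for one stabilized mLSTM step (one head).

    Assumed shapes: matC_state (DHQK, DHV); vecN_state, vecQ, vecK (DHQK,);
    vecV (DHV,); scaM_state, scaI, scaF are scalars (scaI, scaF are the
    input/forget gate pre-activations).
    """
    if qk_scale is None:
        # Assumption: default to 1/sqrt(DHQK), as in scaled dot-product attention.
        qk_scale = vecQ.shape[-1] ** -0.5

    # Gates in log space, stabilized by the running max state m.
    scaF_log = log_sigmoid(scaF)
    scaM_new = np.maximum(scaF_log + scaM_state, scaI)
    scaF_act = np.exp(scaF_log + scaM_state - scaM_new)
    scaI_act = np.exp(scaI - scaM_new)

    # Matrix memory and normalizer updates: C <- f*C + i*k v^T, n <- f*n + i*k.
    matC_new = scaF_act * matC_state + scaI_act * np.outer(vecK, vecV)
    vecN_new = scaF_act * vecN_state + scaI_act * vecK

    # Readout: h = (q^T C) / (max(|q^T n|, exp(-m)) + eps).
    vecQ_scaled = vecQ * qk_scale
    h_num = vecQ_scaled @ matC_new
    qn_dot = vecQ_scaled @ vecN_new
    h_denom = np.maximum(np.abs(qn_dot), np.exp(-scaM_new)) + eps
    vecH = h_num / h_denom
    return vecH, (matC_new, vecN_new, scaM_new)
```

Starting from a zero state with `scaM_state=0.0`, looping this step gives the same hidden states (up to `eps`) as the unstabilized recurrence with forget gate `sigmoid(scaF)` and exponential input gate `exp(scaI)`. The stabilizer `m` rescales both `C` and `n` by `exp(-m)`, so it cancels in the readout ratio while keeping the exponentials bounded.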

Module Contents

xlstm_jax.kernels.mlstm_recurrent.triton_fused_fw.recurrent_step_fw(matC_state, vecN_state, scaM_state, vecQ, vecK, vecV, scaI, scaF, matC_new=None, vecN_new=None, scaM_new=None, qk_scale=None, eps=1e-06)
Parameters: