xlstm_jax.kernels.mlstm_recurrent.triton_fused_fw
Functions
Module Contents
- xlstm_jax.kernels.mlstm_recurrent.triton_fused_fw.recurrent_step_fw(matC_state, vecN_state, scaM_state, vecQ, vecK, vecV, scaI, scaF, matC_new=None, vecN_new=None, scaM_new=None, qk_scale=None, eps=1e-06)
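This page documents only the signature. As a reading aid, here is a pure-JAX sketch of the stabilized mLSTM recurrent step from the xLSTM paper, which a function with these arguments plausibly computes. It is not the Triton kernel, and `recurrent_step_reference` is a hypothetical name that is not part of xlstm_jax. The sketch rests on several assumptions. The `mat`/`vec`/`sca` prefixes are taken to mean per-head matrix, vector, and scalar values. `scaI` and `scaF` are treated as input and forget gate pre-activations, with a log-sigmoid forget gate. `qk_scale` is assumed to default to `1/sqrt(DHQK)`. The return value is assumed to be `(h, (C, n, m))`. The `matC_new`, `vecN_new` and `scaM_new` arguments appear to be optional preallocated outputs, so they are omitted.

```python
import jax
import jax.numpy as jnp


def recurrent_step_reference(matC_state, vecN_state, scaM_state, vecQ, vecK, vecV,
                             scaI, scaF, qk_scale=None, eps=1e-6):
    """Hypothetical pure-JAX sketch of one stabilized mLSTM recurrent step.

    Assumed shapes (leading dims ... are batch/head dims):
      matC_state: (..., DHQK, DHHV)  vecN_state: (..., DHQK)  scaM_state: (...)
      vecQ, vecK: (..., DHQK)         vecV: (..., DHHV)        scaI, scaF: (...)
    scaI and scaF are assumed to be gate pre-activations.
    """
    if qk_scale is None:
        # Assumed default: 1/sqrt(DHQK), as in scaled dot-product attention.
        qk_scale = vecQ.shape[-1] ** -0.5

    # Log-space forget gate and stabilizer update: m_t = max(log f_t + m_{t-1}, i~_t).
    scaF_log = jax.nn.log_sigmoid(scaF)
    scaM_new = jnp.maximum(scaF_log + scaM_state, scaI)

    # Stabilized gate activations (both are <= 1 by construction of m_t).
    scaF_act = jnp.exp(scaF_log + scaM_state - scaM_new)
    scaI_act = jnp.exp(scaI - scaM_new)

    # State updates: C_t = f * C_{t-1} + i * k v^T,  n_t = f * n_{t-1} + i * k.
    outer_kv = vecK[..., :, None] * vecV[..., None, :]
    matC_new = scaF_act[..., None, None] * matC_state + scaI_act[..., None, None] * outer_kv
    vecN_new = scaF_act[..., None] * vecN_state + scaI_act[..., None] * vecK

    # Readout: h = (q^T C) / (max(|q^T n|, exp(-m)) + eps).
    vecQ_scaled = vecQ * qk_scale
    h_num = jnp.einsum("...d,...dv->...v", vecQ_scaled, matC_new)
    qn = jnp.einsum("...d,...d->...", vecQ_scaled, vecN_new)
    h_denom = jnp.maximum(jnp.abs(qn), jnp.exp(-scaM_new)) + eps
    vecH = h_num / h_denom[..., None]
    return vecH, (matC_new, vecN_new, scaM_new)
```

The stabilizer `m` keeps `exp` from overflowing. Multiplying the returned `C` and `n` by `exp(m)` recovers the unstabilized states `sigmoid(F) * C_prev + exp(I) * k v^T` (assuming `m_prev = 0`). Likewise, the `exp(-m)` floor in the denominator is the paper's lower bound of 1 on `|q^T n|`, expressed in stabilized units.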