xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.fwbw
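Classes
- mLSTMfwbwConfig
- mLSTMfwbw
Functions
- rev_cumsum_off(x)
- rev_cumsum(x)
- causal_forget_matrix(forget_gates)
- mLSTMTorchFunction(config): Returns an autograd function that computes the gradient itself.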
Module Contents
- xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.fwbw.rev_cumsum_off(x)
- xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.fwbw.rev_cumsum(x)
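`rev_cumsum` and `rev_cumsum_off` are documented by signature only. Reading the names as a reverse (right-to-left) cumulative sum over the last dimension, with the `_off` variant shifted by one so that each position sums only the entries strictly after it, a minimal PyTorch sketch could look like this. Both the semantics and the `_sketch` function names are assumptions, not the module's actual code.

```python
import torch
import torch.nn.functional as F


def rev_cumsum_sketch(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical: reverse inclusive cumsum over the last dim, y[i] = sum_{j >= i} x[j]."""
    # Flip, take an ordinary left-to-right cumsum, then flip back.
    return torch.flip(torch.cumsum(torch.flip(x, dims=(-1,)), dim=-1), dims=(-1,))


def rev_cumsum_off_sketch(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical: reverse cumsum offset by one, y[i] = sum_{j > i} x[j]; last entry is 0."""
    # Shift the inclusive result left by one position and pad a zero on the right.
    return F.pad(rev_cumsum_sketch(x)[..., 1:], (0, 1))
```

For `x = [1, 2, 3, 4]` the inclusive sketch gives `[10, 9, 7, 4]` and the offset sketch gives `[9, 7, 4, 0]`. Building the offset by slicing and zero-padding, rather than subtracting `x` from the inclusive sum, avoids an extra floating-point rounding step.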
- xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.fwbw.causal_forget_matrix(forget_gates)
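`causal_forget_matrix(forget_gates)` is likewise undocumented. In the parallel mLSTM form of the xLSTM paper, the gate matrix has entries log F[i, j] = sum of log f[k] for k = j+1..i when j <= i, and is masked when j > i. Assuming `forget_gates` holds log forget-gate values of shape `(..., S)` (the source does not say whether they are pre-activations, probabilities, or logs), such a matrix can be built from prefix sums. The function name below is hypothetical.

```python
import torch


def causal_forget_matrix_sketch(log_fgates: torch.Tensor) -> torch.Tensor:
    """Hypothetical: log_fgates has shape (..., S); returns (..., S, S) with
    out[..., i, j] = sum_{k=j+1}^{i} log_fgates[..., k] for j <= i, and -inf for j > i."""
    S = log_fgates.shape[-1]
    # Inclusive prefix sums: c[i] = sum_{k <= i} log_fgates[k].
    c = torch.cumsum(log_fgates, dim=-1)
    # Pairwise difference c[i] - c[j] = sum_{k=j+1}^{i} log_fgates[k] (rows i, columns j).
    diff = c.unsqueeze(-1) - c.unsqueeze(-2)
    # Mask the strictly upper triangle (future positions) with -inf.
    causal = torch.tril(torch.ones(S, S, dtype=torch.bool, device=log_fgates.device))
    return diff.masked_fill(~causal, float("-inf"))
```

With three forget gates of 0.5, exponentiating the result gives the lower-triangular matrix `[[1, 0, 0], [0.5, 1, 0], [0.25, 0.5, 1]]`. Differencing prefix sums is the simplest construction; for long sequences it can lose precision compared with summing each segment directly.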
- class xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.fwbw.mLSTMfwbwConfig
  - scale = None
  - assign_model_config_params(model_config)
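Only the `scale = None` default and the `assign_model_config_params(model_config)` hook are documented. One common pattern, sketched here purely as an assumption, is to leave `scale` unset and fill it from the model configuration, for example with the 1/sqrt(d) key scaling used in the xLSTM paper. The dataclass name and the `head_dim` attribute are hypothetical.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FwbwConfigSketch:
    """Hypothetical stand-in for mLSTMfwbwConfig; only `scale` is documented."""

    scale: Optional[float] = None

    def assign_model_config_params(self, model_config) -> None:
        # Assumption: model_config exposes a per-head dimension called `head_dim`
        # (hypothetical attribute). Fill in the scale only if it was left unset.
        if self.scale is None:
            self.scale = model_config.head_dim ** -0.5
```

Keeping `None` as the default lets an explicitly set scale survive the call, while an unset one is derived from the model.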
- xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.fwbw.mLSTMTorchFunction(config)
  Returns an autograd function that computes its own gradient, i.e. one with a hand-written backward pass.
  - Parameters:
    config (mLSTMfwbwConfig): Configuration for mLSTMTorchFunc.
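The docstring says the returned autograd function computes the gradient itself. The sketch below illustrates that pattern: a factory that closes over a config and returns a `torch.autograd.Function` subclass with a hand-written `backward`. To stay short it implements a simplified decayed attention, H = ((Q K^T) * scale * D) V with a fixed matrix D, not the full mLSTM with input gates and normalizer; the factory and class names are hypothetical, and `torch.autograd.gradcheck` confirms the analytic gradients.

```python
from types import SimpleNamespace

import torch


def make_decayed_attention_function(config):
    """Hypothetical factory: returns a torch.autograd.Function subclass with a
    hand-written backward for H = ((Q @ K^T) * scale * D) @ V, where D is a fixed,
    non-differentiable (S, S) decay/causal matrix."""

    class DecayedAttentionFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, q, k, v, d):
            # Use config.scale if set, otherwise default to 1/sqrt(head_dim).
            scale = config.scale if config.scale is not None else q.shape[-1] ** -0.5
            c = (q @ k.transpose(-2, -1)) * scale * d  # (..., S, S) weighted scores
            ctx.save_for_backward(q, k, v, d)
            ctx.scale = scale
            return c @ v

        @staticmethod
        def backward(ctx, dh):
            q, k, v, d = ctx.saved_tensors
            # Recompute the (S, S) score matrix instead of storing it in forward.
            c = (q @ k.transpose(-2, -1)) * ctx.scale * d
            dv = c.transpose(-2, -1) @ dh  # dL/dV = C^T dH
            dc = dh @ v.transpose(-2, -1)  # dL/dC = dH V^T
            ds = dc * d * ctx.scale        # back through the elementwise scale and mask
            dq = ds @ k                    # dL/dQ = dS K
            dk = ds.transpose(-2, -1) @ q  # dL/dK = dS^T Q
            return dq, dk, dv, None        # D is treated as a constant

    return DecayedAttentionFn


# Usage: check the hand-written gradients against finite differences.
fn = make_decayed_attention_function(SimpleNamespace(scale=None))
S, DH = 4, 3
q, k, v = (torch.randn(2, S, DH, dtype=torch.double, requires_grad=True) for _ in range(3))
d = torch.tril(torch.ones(S, S, dtype=torch.double))  # plain causal mask as the decay matrix
assert torch.autograd.gradcheck(fn.apply, (q, k, v, d))
```

Recomputing `q @ k^T` in `backward` trades extra compute for not storing an (S, S) activation between the passes; the source does not say whether the real backend makes the same choice.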
- class xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.fwbw.mLSTMfwbw(config)
  Bases: torch.nn.Module
  - Parameters:
    config (mLSTMfwbwConfig)
  - config_class
  - config
  - func
  - forward(*args)
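`mLSTMfwbw` exposes `config_class`, `config`, `func`, and `forward(*args)`. A plausible reading, assumed here, is a thin `nn.Module` that builds the autograd function once from its config and forwards all arguments to `func.apply`. The stand-in config and the trivial scaling function below are hypothetical placeholders for `mLSTMfwbwConfig` and `mLSTMTorchFunction`.

```python
import torch
from torch import nn


class ScaleFnConfig:
    """Hypothetical stand-in for mLSTMfwbwConfig."""

    def __init__(self, scale=None):
        self.scale = scale


def make_scale_function(config):
    """Hypothetical stand-in for mLSTMTorchFunction: y = x * scale, hand-written backward."""

    class ScaleFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.scale = 1.0 if config.scale is None else config.scale
            return x * ctx.scale

        @staticmethod
        def backward(ctx, dy):
            # dL/dx = scale * dL/dy
            return dy * ctx.scale

    return ScaleFn


class FwbwModuleSketch(nn.Module):
    """Hypothetical wrapper mirroring the documented members:
    config_class, config, func, and forward(*args)."""

    config_class = ScaleFnConfig

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Build the autograd function once, bound to this config.
        self.func = make_scale_function(config)

    def forward(self, *args):
        # Delegate to the custom autograd function; gradients flow through its backward.
        return self.func.apply(*args)
```

Calling `FwbwModuleSketch(ScaleFnConfig(scale=2.0))` on `torch.ones(3, requires_grad=True)` returns a tensor of 2s, and backpropagating its sum leaves a gradient of 2s on the input.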