xlstm_jax.models.xlstm_pytorch.blocks.mlstm.cell
Classes
Module Contents
- class xlstm_jax.models.xlstm_pytorch.blocks.mlstm.cell.mLSTMCellConfig
- class xlstm_jax.models.xlstm_pytorch.blocks.mlstm.cell.mLSTMCell(config)
Bases: torch.nn.Module
- Parameters:
config (mLSTMCellConfig)
- config_class
- config
- backend_fn
- backend_fn_step
- igate
- fgate
- outnorm
- forward(q, k, v)
- Parameters:
q (torch.Tensor)
k (torch.Tensor)
v (torch.Tensor)
- Return type:
- step(q, k, v, mlstm_state=None)
- Parameters:
q (torch.Tensor)
k (torch.Tensor)
v (torch.Tensor)
mlstm_state (tuple[torch.Tensor, torch.Tensor, torch.Tensor])
- Return type:
tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
- reset_parameters()
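This page does not state what forward(q, k, v) computes. The xLSTM paper gives the mLSTM a parallel form over a whole sequence, and the cell's igate and fgate attributes suggest that the cell produces per-step input and forget gate preactivations for it. What follows is a minimal NumPy sketch of a stabilized parallel mLSTM for a single head, written under the assumption that forward follows that formulation. The function names mlstm_parallel and log_sigmoid, the (S, d) shapes without batch or head dimensions, the sigmoid forget gate handled in log space, the exponential input gate, and the 1/sqrt(d) key scaling are all assumptions. None of them is the documented behavior of mLSTMCell.forward. The sketch also leaves out the small epsilon that real implementations usually add to the denominator.

```python
import numpy as np


def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)) = -log(1 + exp(-x)).
    return -np.logaddexp(0.0, -x)


def mlstm_parallel(q, k, v, igate_preact, fgate_preact):
    """Hypothetical stabilized parallel mLSTM for one head.

    q, k, v: arrays of shape (S, d)
    igate_preact, fgate_preact: arrays of shape (S,)
    Returns h of shape (S, d).
    """
    S, d = q.shape
    log_f = log_sigmoid(fgate_preact)                 # (S,)
    cum = np.cumsum(log_f)                            # cum[t] = sum_{j<=t} log f_j
    # log_fg[t, s] = sum_{j=s+1}^{t} log f_j for s <= t
    log_fg = cum[:, None] - cum[None, :]
    causal = np.tril(np.ones((S, S), dtype=bool))
    # Log of the decay matrix D[t, s] = (prod of forget gates) * exp(igate_preact[s]).
    log_D = np.where(causal, log_fg + igate_preact[None, :], -np.inf)
    # Row-wise stabilizer; finite because the diagonal is never masked.
    m = log_D.max(axis=1, keepdims=True)
    D = np.exp(log_D - m)                             # masked entries become 0
    C = (q @ k.T) / np.sqrt(d) * D                    # (S, S) gated attention scores
    # Stabilized form of max(|sum_s C_ts|, 1) in unstabilized units.
    normalizer = np.maximum(np.abs(C.sum(axis=1, keepdims=True)), np.exp(-m))
    return (C / normalizer) @ v
```

Because every entry in row t is rescaled by the same factor exp(-m_t), the stabilizer cancels between numerator and normalizer. The result is therefore the unstabilized weighted average with a denominator of at least 1. The subtraction of m only keeps exp(igate_preact) from overflowing.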
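step(q, k, v, mlstm_state=None) takes an optional triple of tensors and returns the output along with an updated triple. One plausible reading follows the recurrent mLSTM in the xLSTM paper. Under that reading, the triple holds the matrix memory C, the normalizer n, and a log-space stabilizer m. What follows is a minimal NumPy sketch of one such step for a single head without batch dimensions, assuming the (C, n, m) ordering and a zero initial state when none is given. The function mlstm_step, its scalar gate preactivation arguments, the sigmoid forget gate, and the 1/sqrt(d) key scaling are hypothetical. They are not taken from this page.

```python
import numpy as np


def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)) = -log(1 + exp(-x)).
    return -np.logaddexp(0.0, -x)


def mlstm_step(q, k, v, igate_preact, fgate_preact, state=None):
    """Hypothetical stabilized recurrent mLSTM step for one head.

    q, k, v: arrays of shape (d,)
    igate_preact, fgate_preact: scalars
    state: (C, n, m) with C of shape (d, d), n of shape (d,), m a scalar
    Returns (h, (C_new, n_new, m_new)).
    """
    d = q.shape[0]
    if state is None:
        # Assumed zero initial state.
        state = (np.zeros((d, d)), np.zeros(d), 0.0)
    C, n, m = state
    k = k / np.sqrt(d)                                # key scaling
    log_f = log_sigmoid(fgate_preact)
    m_new = max(log_f + m, igate_preact)              # updated stabilizer
    f_act = np.exp(log_f + m - m_new)                 # stabilized forget gate, <= 1
    i_act = np.exp(igate_preact - m_new)              # stabilized input gate, <= 1
    C_new = f_act * C + i_act * np.outer(v, k)        # matrix memory update
    n_new = f_act * n + i_act * k                     # normalizer update
    denom = max(abs(n_new @ q), np.exp(-m_new))       # stabilized max(|n.q|, 1)
    h = (C_new @ q) / denom
    return h, (C_new, n_new, m_new)
```

Usage: call mlstm_step once per token and thread the returned triple into the next call. The sketch keeps C and n scaled by exp(-m), so h equals the unstabilized C q / max(|n.q|, 1) at every step. For the same gate preactivations, a token-by-token loop and a parallel quadratic formulation should then agree up to floating-point error.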