xlstm_jax.models.xlstm_pytorch.blocks.mlstm.cell

Classes

mLSTMCellConfig
mLSTMCell

Module Contents

class xlstm_jax.models.xlstm_pytorch.blocks.mlstm.cell.mLSTMCellConfig
context_length: int = -1
embedding_dim: int = -1
num_heads: int = -1
backend: xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.mLSTMBackendNameAndKwargs
class xlstm_jax.models.xlstm_pytorch.blocks.mlstm.cell.mLSTMCell(config)

Bases: torch.nn.Module

Parameters:

config (mLSTMCellConfig)

config_class
config
backend_fn
backend_fn_step
igate
fgate
outnorm
forward(q, k, v)
Parameters:

q
k
v

Return type:

torch.Tensor

step(q, k, v, mlstm_state=None)
Parameters:

q
k
v
mlstm_state (optional; defaults to None)

Return type:

tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]

reset_parameters()