xlstm_jax.models.xlstm_pytorch.blocks.mlstm.layer
Classes
Module Contents
- class xlstm_jax.models.xlstm_pytorch.blocks.mlstm.layer.mLSTMLayerConfig
Bases:
xlstm_jax.models.xlstm_pytorch.utils.UpProjConfigMixin
- class xlstm_jax.models.xlstm_pytorch.blocks.mlstm.layer.mLSTMLayer(config)
Bases:
torch.nn.Module
- Parameters:
config (mLSTMLayerConfig)
- config_class
- config
- proj_up
- q_proj
- k_proj
- v_proj
- conv1d
- conv_act_fn
- mlstm_cell
- ogate_act_fn
- learnable_skip
- proj_down
- dropout
- forward(x)
- Parameters:
x (torch.Tensor)
- Return type:
torch.Tensor
- step(x, mlstm_state=None, conv_state=None)
- Parameters:
x (torch.Tensor)
mlstm_state (tuple[torch.Tensor, torch.Tensor, torch.Tensor])
conv_state (tuple[torch.Tensor])
- Return type:
tuple[torch.Tensor, dict[str, tuple[torch.Tensor, Ellipsis]]]
- reset_parameters()
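
The `step` signature above suggests a recurrent usage pattern: call it once per time step and pass the returned state back in on the next call. The following is a minimal sketch of that loop, not the library's own code. It assumes the keys of the returned dict match the keyword arguments of `step` (`mlstm_state`, `conv_state`), which this page does not confirm. `run_stepwise` and `ToyLayer` are hypothetical names introduced only for illustration; `ToyLayer` is a plain-Python stand-in that mimics the shape of `step`, not the real `mLSTMLayer`.

```python
from typing import Any


def run_stepwise(layer: Any, inputs: list):
    """Feed `inputs` one time step at a time, threading the recurrent state."""
    state: dict = {}  # empty on the first call, so step() sees its None defaults
    outputs = []
    for x_t in inputs:
        # ASSUMPTION: the returned dict's keys match step()'s keyword
        # arguments (mlstm_state, conv_state), so it can be unpacked back in.
        y_t, state = layer.step(x_t, **state)
        outputs.append(y_t)
    return outputs, state


class ToyLayer:
    """Hypothetical stand-in with the same step() shape; NOT the real mLSTMLayer."""

    def step(self, x, mlstm_state=None, conv_state=None):
        # Running sum plays the role of the cell state.
        total = x + (mlstm_state[0] if mlstm_state is not None else 0)
        # Step counter plays the role of the convolution buffer.
        count = 1 + (conv_state[0] if conv_state is not None else 0)
        return total, {"mlstm_state": (total,), "conv_state": (count,)}


outputs, final_state = run_stepwise(ToyLayer(), [1, 2, 3])
```

With a real layer, `inputs` would be a sequence of `torch.Tensor` slices, one per time step, and the state tuples would hold tensors rather than integers.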