xlstm_jax.models.xlstm_pytorch.blocks.slstm.layer
Classes
Module Contents
- class xlstm_jax.models.xlstm_pytorch.blocks.slstm.layer.sLSTMLayerConfig
Bases: xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.sLSTMCellConfig
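Because sLSTMLayerConfig subclasses sLSTMCellConfig, a layer config is a cell config plus any layer-level options. The sketch below illustrates that inheritance pattern with Python dataclasses. The class names and every field (embedding_dim, num_heads, dropout, group_norm_weight) are hypothetical and chosen for illustration only; they are not taken from the library, whose actual config fields may differ.

```python
from dataclasses import dataclass


@dataclass
class CellConfigSketch:
    # Hypothetical stand-in for sLSTMCellConfig; field names are illustrative.
    embedding_dim: int = 64
    num_heads: int = 4

    @property
    def head_dim(self) -> int:
        # Each head owns an equal slice of the embedding dimension.
        return self.embedding_dim // self.num_heads


@dataclass
class LayerConfigSketch(CellConfigSketch):
    # Hypothetical layer-level extras, loosely mirroring the layer's
    # group_norm and dropout members.
    dropout: float = 0.0
    group_norm_weight: bool = True

    def __post_init__(self) -> None:
        # Validate inherited and layer-level fields together.
        if self.embedding_dim % self.num_heads != 0:
            raise ValueError("embedding_dim must be divisible by num_heads")
        if not 0.0 <= self.dropout < 1.0:
            raise ValueError("dropout must be in [0, 1)")


cfg = LayerConfigSketch(embedding_dim=128, num_heads=4, dropout=0.1)
print(isinstance(cfg, CellConfigSketch), cfg.head_dim)  # True 32
```

Since the layer config is-a cell config, the same object can be handed to code that expects only the cell's settings.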
- class xlstm_jax.models.xlstm_pytorch.blocks.slstm.layer.sLSTMLayer(config)
Bases: torch.nn.Module
- Parameters:
config (sLSTMLayerConfig)
- config_class
- config
- fgate
- igate
- zgate
- ogate
- slstm_cell
- group_norm
- dropout
- reset_parameters()
- forward(x, initial_state=None, return_last_state=False)
- Parameters:
x (torch.Tensor)
initial_state (torch.Tensor | None)
- Return type:
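forward takes an input sequence x, an optional initial_state, and a return_last_state flag. The following sketch shows one plausible way the listed members fit together, written in plain Python instead of torch so it runs without dependencies. It assumes: one linear map per gate standing in for fgate/igate/zgate/ogate; the stabilized exponential-gating sLSTM recurrence described in the xLSTM paper standing in for slstm_cell, with a sigmoid forget gate computed in log space and simplified to diagonal recurrent weights (head size 1); a weight-free per-head normalization standing in for group_norm; and inverted dropout. All function and parameter names are hypothetical, not the library's API, and the real layer's tensor shapes and state layout may differ.

```python
import math
import random
from typing import Dict, List, Optional, Tuple

Vec = List[float]
Mat = List[List[float]]
# Hypothetical state layout per unit: hidden h, cell c, normalizer n, stabilizer m.
State = Tuple[Vec, Vec, Vec, Vec]


def log_sigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x)) = min(x, 0) - log1p(exp(-|x|)).
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))


def linear(w: Mat, b: Vec, x: Vec) -> Vec:
    return [sum(wi * xi for wi, xi in zip(row, x)) + bi for row, bi in zip(w, b)]


def slstm_cell_step(pre: Dict[str, Vec], r: Dict[str, Vec], state: State) -> State:
    """One stabilized sLSTM step; r holds diagonal recurrent weights per gate."""
    h, c, n, m = state
    hs, cs, ns, ms = [], [], [], []
    for k in range(len(h)):
        # Gate pre-activations: input projection plus recurrent term from h_{t-1}.
        i_t, f_t, z_t, o_t = (pre[g][k] + r[g][k] * h[k] for g in "ifzo")
        log_f_plus_m = log_sigmoid(f_t) + m[k]
        m_new = max(i_t, log_f_plus_m)           # stabilizer keeps exp() bounded
        i_gate = math.exp(i_t - m_new)           # stabilized exponential input gate
        f_gate = math.exp(log_f_plus_m - m_new)  # stabilized forget gate
        c_new = f_gate * c[k] + i_gate * math.tanh(z_t)
        n_new = f_gate * n[k] + i_gate           # normalizer, always > 0
        o_gate = math.exp(log_sigmoid(o_t))      # sigmoid output gate
        hs.append(o_gate * c_new / n_new)
        cs.append(c_new)
        ns.append(n_new)
        ms.append(m_new)
    return hs, cs, ns, ms


def group_norm(y: Vec, num_heads: int, eps: float = 1e-5) -> Vec:
    # Normalize each head's slice to zero mean, unit variance (no affine weight).
    size = len(y) // num_heads
    out: Vec = []
    for g in range(num_heads):
        seg = y[g * size:(g + 1) * size]
        mu = sum(seg) / size
        var = sum((v - mu) ** 2 for v in seg) / size
        out.extend((v - mu) / math.sqrt(var + eps) for v in seg)
    return out


def dropout(y: Vec, p: float, training: bool) -> Vec:
    # Inverted dropout; identity at inference time.
    if not training or p == 0.0:
        return y
    return [0.0 if random.random() < p else v / (1.0 - p) for v in y]


def slstm_layer_forward(x: List[Vec], params: dict, num_heads: int,
                        initial_state: Optional[State] = None,
                        return_last_state: bool = False,
                        p: float = 0.0, training: bool = False):
    d = len(x[0])
    state = initial_state
    if state is None:
        # Zero state; m = -inf means "no history yet", so the first step uses m = i~.
        state = ([0.0] * d, [0.0] * d, [0.0] * d, [-math.inf] * d)
    outputs = []
    for x_t in x:
        # Hypothetical per-gate projections standing in for fgate/igate/zgate/ogate.
        pre = {g: linear(params["w" + g], params["b" + g], x_t) for g in "ifzo"}
        state = slstm_cell_step(pre, params["r"], state)
        outputs.append(dropout(group_norm(state[0], num_heads), p, training))
    return (outputs, state) if return_last_state else outputs


random.seed(0)
d, heads = 4, 2
params = {"r": {g: [0.1] * d for g in "ifzo"}}
for g in "ifzo":
    params["w" + g] = [[random.uniform(-0.5, 0.5) for _ in range(d)] for _ in range(d)]
    params["b" + g] = [0.0] * d
xs = [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(3)]
ys, last = slstm_layer_forward(xs, params, heads, return_last_state=True)
print(len(ys), len(ys[0]), len(last))  # 3 4 4
```

Returning the last state lets a long sequence be processed in chunks: passing that state back as initial_state continues the recurrence exactly where it stopped. Because c and n accumulate with the same gate weights, c/n stays a weighted average of tanh values, which keeps each hidden unit bounded by the output gate.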