xlstm_jax.models.xlstm_pytorch.blocks.slstm.layer

Classes

sLSTMLayerConfig
sLSTMLayer

Module Contents

class xlstm_jax.models.xlstm_pytorch.blocks.slstm.layer.sLSTMLayerConfig

Bases: xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.sLSTMCellConfig

embedding_dim: int = -1
num_heads: int = 4
conv1d_kernel_size: int = 4
group_norm_weight: bool = True
dropout: float = 0.0
class xlstm_jax.models.xlstm_pytorch.blocks.slstm.layer.sLSTMLayer(config)

Bases: torch.nn.Module

Parameters:

config (sLSTMLayerConfig)

config_class
config
fgate
igate
zgate
ogate
slstm_cell
group_norm
dropout
reset_parameters()
forward(x, initial_state=None, return_last_state=False)
Parameters:

x
initial_state (default: None)
return_last_state (default: False)

Return type:

torch.Tensor