xlstm_jax.models.xlstm_pytorch.blocks.xlstm_block

Classes

xLSTMBlockConfig

xLSTMBlock: An xLSTM block can be either an sLSTM Block or an mLSTM Block.

Module Contents

class xlstm_jax.models.xlstm_pytorch.blocks.xlstm_block.xLSTMBlockConfig
mlstm: xlstm_jax.models.xlstm_pytorch.blocks.mlstm.layer.mLSTMLayerConfig | None = None
slstm: xlstm_jax.models.xlstm_pytorch.blocks.slstm.layer.sLSTMLayerConfig | None = None
feedforward: xlstm_jax.models.xlstm_pytorch.components.feedforward.FeedForwardConfig | None = None
_num_blocks: int | None = None
_block_idx: int | None = None
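The field list above does not spell out how the fields interact. The sketch below is a minimal, hypothetical mirror of xLSTMBlockConfig in plain Python. It assumes, following the class description that a block is either an sLSTM Block or an mLSTM Block, that exactly one of mlstm and slstm must be set. The propagation of _num_blocks, _block_idx and the embedding dimension into the sub-configs is likewise an assumption. LayerConfigStub, FeedForwardConfigStub and BlockConfigSketch are stand-ins invented for illustration, not classes from xlstm_jax.

```python
from dataclasses import dataclass
from typing import Optional


# Hypothetical stand-in for mLSTMLayerConfig / sLSTMLayerConfig;
# only the fields needed for this sketch are modelled.
@dataclass
class LayerConfigStub:
    embedding_dim: int
    _num_blocks: Optional[int] = None
    _block_idx: Optional[int] = None


# Hypothetical stand-in for FeedForwardConfig.
@dataclass
class FeedForwardConfigStub:
    embedding_dim: Optional[int] = None
    _num_blocks: Optional[int] = None


@dataclass
class BlockConfigSketch:
    mlstm: Optional[LayerConfigStub] = None
    slstm: Optional[LayerConfigStub] = None
    feedforward: Optional[FeedForwardConfigStub] = None
    _num_blocks: Optional[int] = None
    _block_idx: Optional[int] = None

    def __post_init__(self) -> None:
        # Assumption: a block holds exactly one sequence-mixing layer.
        if (self.mlstm is None) == (self.slstm is None):
            raise ValueError("exactly one of mlstm or slstm must be provided")
        layer = self.mlstm if self.mlstm is not None else self.slstm
        # Assumption: the block's position in the stack is forwarded to the layer config.
        layer._num_blocks = self._num_blocks
        layer._block_idx = self._block_idx
        # Assumption: the feedforward sub-block inherits the layer's embedding dimension.
        if self.feedforward is not None:
            self.feedforward.embedding_dim = layer.embedding_dim
            self.feedforward._num_blocks = self._num_blocks
```

For example, BlockConfigSketch(mlstm=LayerConfigStub(embedding_dim=128), feedforward=FeedForwardConfigStub()) leaves the feedforward config with embedding_dim == 128, while passing both or neither of mlstm and slstm raises ValueError.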
class xlstm_jax.models.xlstm_pytorch.blocks.xlstm_block.xLSTMBlock(config)

Bases: torch.nn.Module

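An xLSTM block can be either an sLSTM Block or an mLSTM Block.

It contains the pre-LayerNorms and the skip (residual) connections: the sLSTM or mLSTM layer and the optional feedforward sub-block each receive a LayerNorm-normalized input, and each adds its output back to its input.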
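The pre-norm residual structure described above can be sketched on a single feature vector in plain Python. This is an illustration under stated assumptions, not the xlstm_jax implementation: it uses a parameter-free LayerNorm instead of a learnable one, Python lists instead of torch.Tensor, arbitrary callables in place of the sLSTM/mLSTM layer and the feedforward module, and it assumes the feedforward sub-block runs after the xLSTM layer. The function names layer_norm, add and block_forward are hypothetical.

```python
import math
from typing import Callable, List, Optional

Vector = List[float]


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    # Parameter-free LayerNorm over the feature dimension (no learned scale or bias).
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def add(a: Vector, b: Vector) -> Vector:
    # Element-wise sum used for the skip connections.
    return [u + v for u, v in zip(a, b)]


def block_forward(
    x: Vector,
    xlstm_layer: Callable[[Vector], Vector],
    ffn: Optional[Callable[[Vector], Vector]] = None,
) -> Vector:
    # Sub-block 1: pre-LayerNorm -> sLSTM/mLSTM layer -> skip connection.
    x = add(x, xlstm_layer(layer_norm(x)))
    # Sub-block 2 (optional, assumed order): pre-LayerNorm -> feedforward -> skip connection.
    if ffn is not None:
        x = add(x, ffn(layer_norm(x)))
    return x
```

Because each sub-block adds its output to its input, the block maps a vector to one of the same length, and a layer that returns all zeros leaves x unchanged.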

Parameters:

config (xLSTMBlockConfig)

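config_class
config
Configuration (xLSTMBlockConfig) passed to the constructor.
xlstm_norm
Pre-LayerNorm applied to the input of the sLSTM or mLSTM layer.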
forward(x, **kwargs)
Parameters:

x (torch.Tensor)

Return type:

torch.Tensor

step(x, **kwargs)
Parameters:

x (torch.Tensor)

Return type:

tuple[torch.Tensor, dict[str, tuple[torch.Tensor, ...]]]
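The return type suggests that step() is meant for recurrent, one-step-at-a-time inference: it returns the block output together with a dict mapping state names to tuples of tensors. A plausible calling pattern, assumed here rather than confirmed by this page, is to pass that dict back as **kwargs on the next call. The sketch below uses a hypothetical ToyRecurrentBlock with a single scalar state named cell_state to show only this state threading; the real state names and tensor shapes depend on the mLSTM or sLSTM layer inside the block.

```python
from typing import Dict, List, Optional, Tuple

State = Dict[str, Tuple[float, ...]]


class ToyRecurrentBlock:
    """Hypothetical stand-in that mimics the step() contract: (output, state dict)."""

    def __init__(self, decay: float = 0.5) -> None:
        self.decay = decay

    def step(
        self, x: float, cell_state: Optional[Tuple[float, ...]] = None
    ) -> Tuple[float, State]:
        # Start from a zero state when no state is passed in (first step).
        (c,) = cell_state if cell_state is not None else (0.0,)
        c = self.decay * c + x  # toy recurrent update
        y = x + c  # skip connection around the recurrent part
        return y, {"cell_state": (c,)}


def run_stepwise(block: ToyRecurrentBlock, xs: List[float]) -> List[float]:
    outputs: List[float] = []
    state: State = {}
    for x in xs:
        # Feed the state returned by the previous call back in as keyword arguments.
        y, state = block.step(x, **state)
        outputs.append(y)
    return outputs
```

For example, run_stepwise(ToyRecurrentBlock(), [1.0, 0.0, 0.0]) returns [2.0, 0.5, 0.25]: the state carried between calls lets the first input keep influencing later outputs.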

reset_parameters()
Return type:

None