xlstm_jax.models.xlstm_pytorch.blocks.xlstm_block
Classes
An xLSTM block can be either an sLSTM block or an mLSTM block.
Module Contents
- class xlstm_jax.models.xlstm_pytorch.blocks.xlstm_block.xLSTMBlockConfig
- feedforward: xlstm_jax.models.xlstm_pytorch.components.feedforward.FeedForwardConfig | None = None
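This page lists only the optional `feedforward` field of the config. The sketch below shows one way a block config could encode "either an sLSTM block or an mLSTM block" and pass the layer width on to an optional feed-forward config. It is a minimal illustration under stated assumptions, not the library's code: the `mlstm`/`slstm` fields, the stub config classes, and `BlockConfigSketch` are all hypothetical.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MLSTMLayerConfigStub:
    # Hypothetical stand-in for an mLSTM layer config.
    embedding_dim: int = 128


@dataclass
class SLSTMLayerConfigStub:
    # Hypothetical stand-in for an sLSTM layer config.
    embedding_dim: int = 128


@dataclass
class FeedForwardConfigStub:
    # Hypothetical stand-in for FeedForwardConfig; width is filled in later.
    embedding_dim: Optional[int] = None


@dataclass
class BlockConfigSketch:
    # Hypothetical fields: exactly one of these selects the block type.
    mlstm: Optional[MLSTMLayerConfigStub] = None
    slstm: Optional[SLSTMLayerConfigStub] = None
    # Mirrors the documented optional field: no feed-forward sub-block when None.
    feedforward: Optional[FeedForwardConfigStub] = None

    def __post_init__(self) -> None:
        # Reject both "neither" and "both" in a single check.
        if (self.mlstm is None) == (self.slstm is None):
            raise ValueError("Provide exactly one of `mlstm` or `slstm`.")
        layer = self.mlstm if self.mlstm is not None else self.slstm
        # Assumption: the feed-forward width follows the layer's embedding size.
        if self.feedforward is not None:
            self.feedforward.embedding_dim = layer.embedding_dim

    @property
    def block_type(self) -> str:
        return "mLSTM" if self.mlstm is not None else "sLSTM"
```

For example, `BlockConfigSketch(mlstm=MLSTMLayerConfigStub(embedding_dim=64), feedforward=FeedForwardConfigStub())` describes an mLSTM block whose feed-forward config inherits a width of 64, while passing both or neither layer config raises `ValueError`.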
- class xlstm_jax.models.xlstm_pytorch.blocks.xlstm_block.xLSTMBlock(config)
Bases: torch.nn.Module
An xLSTM block can be either an sLSTM block or an mLSTM block.
It contains the pre-LayerNorms and the skip connections.
- Parameters:
config (xLSTMBlockConfig)
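The pre-LayerNorm and skip-connection structure can be sketched in plain PyTorch. This is a minimal illustration, not the library's implementation: it assumes the block computes `x + core(xlstm_norm(x))` and, when a feed-forward sub-block is configured, a second `x + ffn(ffn_norm(x))`. `PreNormResidualBlock`, `core`, `ffn`, and `ffn_norm` are hypothetical names; only `xlstm_norm` mirrors an attribute listed on this page.

```python
from typing import Optional

import torch
from torch import nn


class PreNormResidualBlock(nn.Module):
    """Sketch: x -> x + core(norm(x)), then optionally x -> x + ffn(ffn_norm(x))."""

    def __init__(self, embedding_dim: int, core: nn.Module, ffn: Optional[nn.Module] = None) -> None:
        super().__init__()
        # Pre-LayerNorm applied before the sequence-mixing layer.
        self.xlstm_norm = nn.LayerNorm(embedding_dim)
        # Hypothetical stand-in for the sLSTM or mLSTM layer.
        self.core = core
        # Optional feed-forward sub-block with its own pre-LayerNorm.
        self.ffn = ffn
        self.ffn_norm = nn.LayerNorm(embedding_dim) if ffn is not None else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # First skip connection, around the normalized core layer.
        x = x + self.core(self.xlstm_norm(x))
        if self.ffn is not None:
            # Second skip connection, around the normalized feed-forward layer.
            x = x + self.ffn(self.ffn_norm(x))
        return x
```

Because every sub-layer is wrapped as `x + f(norm(x))`, the output has the same shape as the input, which is what lets such blocks be stacked; with a core that outputs zeros, the block reduces to the identity.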
- config_class
- config
- xlstm_norm
- forward(x, **kwargs)
- Parameters:
x (torch.Tensor)
- Return type:
- step(x, **kwargs)
- Parameters:
x (torch.Tensor)
- Return type:
tuple[torch.Tensor, dict[str, tuple[torch.Tensor, ...]]]
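The `step` return type pairs the block output with a dictionary of recurrent state whose values are tuples of tensors. The sketch below illustrates that contract for one-time-step-at-a-time inference, using a toy core that keeps a running sum. The `state` keyword, the `"running_sum"` key, and all class names are assumptions for illustration; the documented signature only shows `**kwargs`.

```python
from typing import Optional

import torch
from torch import nn

# Same shape as the documented state: str -> tuple of tensors.
State = dict[str, tuple[torch.Tensor, ...]]


class RunningSumCore(nn.Module):
    """Hypothetical toy recurrent layer: its output is the running sum of its inputs."""

    def step(self, x: torch.Tensor, state: Optional[State] = None) -> tuple[torch.Tensor, State]:
        prev = torch.zeros_like(x) if state is None else state["running_sum"][0]
        total = prev + x
        return total, {"running_sum": (total,)}


class SteppableBlock(nn.Module):
    """Sketch of a pre-norm residual block that processes a single time step."""

    def __init__(self, embedding_dim: int) -> None:
        super().__init__()
        self.xlstm_norm = nn.LayerNorm(embedding_dim)
        self.core = RunningSumCore()  # hypothetical stand-in for an sLSTM/mLSTM layer

    def step(self, x: torch.Tensor, state: Optional[State] = None) -> tuple[torch.Tensor, State]:
        # x holds one time step, e.g. shape (batch, 1, embedding_dim).
        y, new_state = self.core.step(self.xlstm_norm(x), state=state)
        # Skip connection, then hand the updated state back to the caller.
        return x + y, new_state
```

A caller feeds the returned dictionary into the next call, for example `y_t, state = block.step(x_t, state=state)` inside a loop over time steps, starting from `state = None`.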
- reset_parameters()
- Return type:
None
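`reset_parameters()` returns `None`; by its name and return type it presumably re-initializes the block's weights in place. A minimal sketch, assuming the block simply delegates to the `reset_parameters()` methods that `torch.nn.LayerNorm` and `torch.nn.Linear` already provide; `ResettableBlock` and `proj` are hypothetical names.

```python
import torch
from torch import nn


class ResettableBlock(nn.Module):
    """Sketch: reset_parameters() re-initializes each sub-module in place."""

    def __init__(self, embedding_dim: int) -> None:
        super().__init__()
        self.xlstm_norm = nn.LayerNorm(embedding_dim)
        self.proj = nn.Linear(embedding_dim, embedding_dim)  # hypothetical sub-layer

    def reset_parameters(self) -> None:
        # LayerNorm.reset_parameters sets weight to 1 and bias to 0.
        self.xlstm_norm.reset_parameters()
        # Linear.reset_parameters redraws weight and bias with its default init.
        self.proj.reset_parameters()
```

Delegating to each sub-module keeps every layer's own initialization scheme, so the block does not need to know how its children are initialized.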