xlstm_jax.models.xlstm_pytorch.xlstm_block_stack
Classes
Module Contents
- class xlstm_jax.models.xlstm_pytorch.xlstm_block_stack.xLSTMBlockStackConfig
- mlstm_block: xlstm_jax.models.xlstm_pytorch.blocks.mlstm.block.mLSTMBlockConfig | None = None
- slstm_block: xlstm_jax.models.xlstm_pytorch.blocks.slstm.block.sLSTMBlockConfig | None = None
- class xlstm_jax.models.xlstm_pytorch.xlstm_block_stack.xLSTMBlockStack(config)
Bases: torch.nn.Module
- Parameters:
config (xLSTMBlockStackConfig)
- config_class
- config
- blocks
- _create_blocks(config)
- Parameters:
config (xLSTMBlockStackConfig)
- reset_parameters()
- Return type:
None
- forward(x, **kwargs)
- Parameters:
x (torch.Tensor)
- Return type:
- step(x, state=None)
- Parameters:
x (torch.Tensor)
- Return type:
tuple[torch.Tensor, dict[str, dict[str, tuple[torch.Tensor, ...]]]]
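The nested return type of step can be hard to read from the signature alone. The following is a minimal sketch of the calling pattern it implies, not the library's implementation: ToyBlock and ToyBlockStack are hypothetical stand-ins, plain floats replace torch.Tensor so the sketch runs with the standard library only, and the "block_{i}" and "toy_state" keys are assumptions that this page does not confirm.

```python
from typing import Optional

# Floats stand in for torch.Tensor; the nesting mirrors the documented type
# dict[str, dict[str, tuple[torch.Tensor, ...]]].
BlockState = dict[str, tuple[float, ...]]
StackState = dict[str, BlockState]


class ToyBlock:
    """Hypothetical recurrent block whose only state is a running sum."""

    def step(self, x: float, state: Optional[BlockState] = None) -> tuple[float, BlockState]:
        # A missing state means "start of sequence".
        (total,) = state["toy_state"] if state else (0.0,)
        total = total + x
        return total, {"toy_state": (total,)}


class ToyBlockStack:
    """Hypothetical stack that follows the documented step(x, state=None) shape."""

    def __init__(self, num_blocks: int) -> None:
        self.blocks = [ToyBlock() for _ in range(num_blocks)]

    def step(self, x: float, state: Optional[StackState] = None) -> tuple[float, StackState]:
        state = {} if state is None else state
        new_state: StackState = {}
        for i, block in enumerate(self.blocks):
            key = f"block_{i}"  # assumed key naming, for illustration only
            # Each block consumes the previous block's output and its own state.
            x, new_state[key] = block.step(x, state.get(key))
        return x, new_state


# Step through a sequence one element at a time, feeding the state back in.
stack = ToyBlockStack(num_blocks=2)
state = None
outputs = []
for token in [1.0, 2.0, 3.0]:
    y, state = stack.step(token, state)
    outputs.append(y)
```

The point of the pattern is that the caller owns the state: passing state=None starts a fresh sequence, and passing back the dictionary returned by the previous call continues it.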