xlstm_jax.models.xlstm_pytorch.xlstm_block_stack#

Classes#

Module Contents#

class xlstm_jax.models.xlstm_pytorch.xlstm_block_stack.xLSTMBlockStackConfig#
mlstm_block: xlstm_jax.models.xlstm_pytorch.blocks.mlstm.block.mLSTMBlockConfig | None = None#
slstm_block: xlstm_jax.models.xlstm_pytorch.blocks.slstm.block.sLSTMBlockConfig | None = None#
context_length: int = -1#
num_blocks: int = 1#
embedding_dim: int = 128#
add_post_blocks_norm: bool = True#
bias: bool = False#
dropout: float = 0.0#
slstm_at: list[int] | Literal['all'] = []#
_block_map: str = None#
property block_map: list[int]#
Return type:

list[int]

_create_block_map()#

Creates the block map, which specifies the block type used at each position in the stack.

Return type:

str

class xlstm_jax.models.xlstm_pytorch.xlstm_block_stack.xLSTMBlockStack(config)#

Bases: torch.nn.Module

Parameters:

config (xLSTMBlockStackConfig)

config_class#
config#
blocks#
_create_blocks(config)#
Parameters:

config (xLSTMBlockStackConfig)

reset_parameters()#
Return type:

None

forward(x, **kwargs)#
Parameters:

x (torch.Tensor)

Return type:

torch.Tensor

step(x, state=None)#
Parameters:

x (torch.Tensor)

state (dict[str, dict[str, tuple[torch.Tensor, ...]]] | None)

Return type:

tuple[torch.Tensor, dict[str, dict[str, tuple[torch.Tensor, ...]]]]