xlstm_jax.models.xlstm_pytorch.blocks.slstm.src.vanilla

Submodules

Attributes

slstm_pointwise_function_registry

Functions

slstm_forward(x, states, R, b, pointwise_forward[, ...])

slstm_forward_step(x, states, R, b, pointwise_forward)

Package Contents

xlstm_jax.models.xlstm_pytorch.blocks.slstm.src.vanilla.slstm_pointwise_function_registry: dict[str, collections.abc.Callable]
xlstm_jax.models.xlstm_pytorch.blocks.slstm.src.vanilla.slstm_forward(x, states, R, b, pointwise_forward, constants=None)
Parameters: x, states, R, b, pointwise_forward, constants (optional, defaults to None)
Return type: tuple[torch.Tensor, torch.Tensor, torch.Tensor]

xlstm_jax.models.xlstm_pytorch.blocks.slstm.src.vanilla.slstm_forward_step(x, states, R, b, pointwise_forward, constants=None)
Parameters: x, states, R, b, pointwise_forward, constants (optional, defaults to None)
Return type: tuple[torch.Tensor, torch.Tensor]