xlstm_jax.models.xlstm_pytorch.blocks.slstm.src.vanilla
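Submodules
Attributes
- slstm_pointwise_function_registry
Functions
- slstm_forward(x, states, R, b, pointwise_forward, constants=None)
- slstm_forward_step(x, states, R, b, pointwise_forward, constants=None)
Package Contents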
- xlstm_jax.models.xlstm_pytorch.blocks.slstm.src.vanilla.slstm_pointwise_function_registry: dict[str, collections.abc.Callable]
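The registry is typed `dict[str, Callable]`. Its name suggests that it maps a cell-variant name to a pointwise update function of the kind accepted by the `pointwise_forward` parameter of the functions below, but that pairing is an assumption. The following is a minimal sketch under that assumption. It reads the `pointwise_forward` arguments as (input pre-activations, recurrent pre-activations, bias, stacked states, constants). `lstm_pointwise_sketch`, `pointwise_registry`, the gate order and all shapes are hypothetical and are not the library's implementation.

```python
from collections.abc import Callable

import torch


def lstm_pointwise_sketch(
    Wx: torch.Tensor,
    Ry: torch.Tensor,
    b: torch.Tensor,
    states: torch.Tensor,
    constants: dict[str, float],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical pointwise LSTM update, not the library's implementation.

    Wx, Ry: [B, 4*H] input and recurrent pre-activations for one time step.
    b: [4*H] bias. states: [2, B, H] holding (h, c).
    `constants` is accepted only to match the documented callable signature.
    """
    raw = Wx + Ry + b
    i, f, z, o = raw.chunk(4, dim=-1)  # assumed gate order: input, forget, cell, output
    c = states[1]
    c_new = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(z)
    h_new = torch.sigmoid(o) * torch.tanh(c_new)
    new_states = torch.stack([h_new, c_new], dim=0)  # [2, B, H]
    gates = torch.stack([i, f, z, o], dim=0)  # [4, B, H] raw gate pre-activations
    return new_states, gates


# Hypothetical registry with the documented type dict[str, Callable].
pointwise_registry: dict[str, Callable] = {"lstm": lstm_pointwise_sketch}

B, H = 2, 3
pointwise = pointwise_registry["lstm"]  # select a cell variant by name
new_states, gates = pointwise(
    torch.zeros(B, 4 * H),
    torch.zeros(B, 4 * H),
    torch.zeros(4 * H),
    torch.zeros(2, B, H),
    {},
)
print(new_states.shape, gates.shape)  # torch.Size([2, 2, 3]) torch.Size([4, 2, 3])
```

Looking up the update function by string key lets a caller choose the cell variant from configuration without changing the recurrence loop that consumes it.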
- xlstm_jax.models.xlstm_pytorch.blocks.slstm.src.vanilla.slstm_forward(x, states, R, b, pointwise_forward, constants=None)
- Parameters:
  - x (torch.Tensor)
  - states (torch.Tensor)
  - R (torch.Tensor)
  - b (torch.Tensor)
  - pointwise_forward (collections.abc.Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, float]], tuple[torch.Tensor, torch.Tensor]])
- Return type:
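The signature pairs a recurrent weight tensor `R` and a bias `b` with a `pointwise_forward` callable. This suggests a loop over the time axis: at each step the previous hidden state is projected through `R`, and the cell update is delegated to the callable. The following is a minimal sketch of that pattern. It assumes a single head, `x` already holding the input pre-activations with time as the leading axis, and `states[0]` being the hidden state. `forward_sketch`, `rnn_pointwise`, the shapes and the return values are hypothetical and do not describe the library's actual return type, which is not documented here.

```python
from collections.abc import Callable
from typing import Optional

import torch

PointwiseFn = Callable[
    [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, float]],
    tuple[torch.Tensor, torch.Tensor],
]


def rnn_pointwise(Wx, Ry, b, states, constants):
    """Hypothetical pointwise update for a plain tanh RNN with one state (h)."""
    pre = Wx + Ry + b  # [B, H]
    h_new = torch.tanh(pre)
    return h_new.unsqueeze(0), pre.unsqueeze(0)  # new states [1, B, H], gates [1, B, H]


def forward_sketch(
    x: torch.Tensor,
    states: torch.Tensor,
    R: torch.Tensor,
    b: torch.Tensor,
    pointwise_forward: PointwiseFn,
    constants: Optional[dict[str, float]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical sequence loop; all shapes are assumptions for this sketch.

    x: [T, B, G*H] input pre-activations, states: [S, B, H] with states[0] = h,
    R: [H, G*H] single-head recurrent weights, b: [G*H].
    Returns all states [S, T+1, B, H] (initial state included) and the final states.
    """
    constants = {} if constants is None else constants
    history = [states]
    for Wx_t in x.unbind(dim=0):  # iterate over the time axis
        Ry = states[0] @ R  # recurrent pre-activations from the previous h
        states, _gates = pointwise_forward(Wx_t, Ry, b, states, constants)
        history.append(states)
    return torch.stack(history, dim=1), states


torch.manual_seed(0)
T, B, H = 5, 2, 3
x = torch.randn(T, B, H)
R = 0.1 * torch.randn(H, H)
all_states, final_states = forward_sketch(
    x, torch.zeros(1, B, H), R, torch.zeros(H), rnn_pointwise
)
print(all_states.shape)  # torch.Size([1, 6, 2, 3])
```

The input projection does not depend on the recurrence, so it can be computed for the whole sequence up front. Only the `states[0] @ R` product and the pointwise update have to run step by step.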
- xlstm_jax.models.xlstm_pytorch.blocks.slstm.src.vanilla.slstm_forward_step(x, states, R, b, pointwise_forward, constants=None)
- Parameters:
  - x (torch.Tensor)
  - states (torch.Tensor)
  - R (torch.Tensor)
  - b (torch.Tensor)
  - pointwise_forward (collections.abc.Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, float]], tuple[torch.Tensor, torch.Tensor]])
- Return type:
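`slstm_forward_step` takes the same parameters as `slstm_forward`, and its name suggests that it advances the recurrence by a single time step. That would suit token-by-token inference, where the caller carries `states` between calls. The following is a minimal sketch under that assumption. It treats `x_t` as a single step without a time axis and uses a plain tanh RNN as a hypothetical stand-in for the pointwise function. `step_sketch` and `rnn_pointwise` are illustrative names, not library functions. The check at the end confirms that repeated single steps reproduce the recurrence written out directly.

```python
from collections.abc import Callable
from typing import Optional

import torch


def rnn_pointwise(Wx, Ry, b, states, constants):
    """Hypothetical pointwise update for a plain tanh RNN with one state (h)."""
    pre = Wx + Ry + b
    return torch.tanh(pre).unsqueeze(0), pre.unsqueeze(0)


def step_sketch(
    x_t: torch.Tensor,
    states: torch.Tensor,
    R: torch.Tensor,
    b: torch.Tensor,
    pointwise_forward: Callable,
    constants: Optional[dict[str, float]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical single-step update: x_t is [B, G*H] (no time axis assumed)."""
    Ry = states[0] @ R  # recurrent pre-activations from the previous h
    return pointwise_forward(x_t, Ry, b, states, {} if constants is None else constants)


torch.manual_seed(0)
T, B, H = 4, 2, 3
x = torch.randn(T, B, H)
R = 0.1 * torch.randn(H, H)
b = torch.zeros(H)

# Token-by-token use: carry the state across calls instead of passing a whole sequence.
states = torch.zeros(1, B, H)
for t in range(T):
    states, _gates = step_sketch(x[t], states, R, b, rnn_pointwise)

# Reference: the same recurrence written out directly.
h = torch.zeros(B, H)
for t in range(T):
    h = torch.tanh(x[t] + h @ R + b)
print(torch.allclose(states[0], h))  # True
```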