xlstm_jax.models.xlstm_pytorch.components.linear_headwise
Classes
- LinearHeadwiseExpand: This is a structured projection layer that projects the input to a higher dimension.
Module Contents
- class xlstm_jax.models.xlstm_pytorch.components.linear_headwise.LinearHeadwiseExpandConfig
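This page does not list the fields of `LinearHeadwiseExpandConfig`. The following is a minimal sketch of what such a config might validate, given the documented constraint that only integer up-projection factors are allowed. The class `HeadwiseExpandConfigSketch` and its fields `in_features`, `num_heads`, `expand_factor_up` and `bias` are hypothetical names chosen for illustration, not the library's API.

```python
from dataclasses import dataclass


@dataclass
class HeadwiseExpandConfigSketch:
    """Hypothetical stand-in for LinearHeadwiseExpandConfig (field names are assumptions)."""

    in_features: int = 0       # width of the last input dimension
    num_heads: int = 1         # number of independent heads the input is split into
    expand_factor_up: int = 1  # integer up-projection factor (output = factor * input)
    bias: bool = True          # whether a bias vector is added after the projection

    def __post_init__(self) -> None:
        if self.num_heads <= 0:
            raise ValueError("num_heads must be positive")
        if self.in_features % self.num_heads != 0:
            raise ValueError("in_features must be divisible by num_heads")
        if not isinstance(self.expand_factor_up, int) or self.expand_factor_up < 1:
            raise ValueError("expand_factor_up must be a positive integer")

    @property
    def out_features(self) -> int:
        # The output dimension is always an integer multiple of the input dimension.
        return self.expand_factor_up * self.in_features

    @property
    def in_per_head(self) -> int:
        return self.in_features // self.num_heads

    @property
    def out_per_head(self) -> int:
        return self.out_features // self.num_heads


cfg = HeadwiseExpandConfigSketch(in_features=8, num_heads=4, expand_factor_up=2)
print(cfg.out_features, cfg.in_per_head, cfg.out_per_head)  # 16 2 4
```

Validating in `__post_init__` means an invalid head count or a non-integer factor fails at construction time rather than inside the forward pass.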
- class xlstm_jax.models.xlstm_pytorch.components.linear_headwise.LinearHeadwiseExpand(config)
Bases: torch.nn.Module
This is a structured projection layer that projects the input to a higher dimension.
It only allows integer up-projection factors, i.e. the output dimension is a multiple of the input dimension.
- Parameters:
config (LinearHeadwiseExpandConfig)
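One way to read "structured projection" is that the input features are split into heads and each head gets its own small weight matrix. That is equivalent to a dense layer whose weight is block-diagonal, but with `num_heads` times fewer parameters. The sketch below illustrates this in NumPy under an assumed weight layout of `(num_heads, out_per_head, in_per_head)`. The helpers `headwise_expand` and `block_diagonal` are hypothetical and are not taken from the module's code.

```python
import numpy as np


def headwise_expand(x: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Apply an independent linear map to each head of the last dimension.

    Assumed (hypothetical) layout: weight has shape (num_heads, out_per_head, in_per_head).
    """
    num_heads, out_per_head, in_per_head = weight.shape
    # Split the last dimension into (num_heads, in_per_head).
    xh = x.reshape(*x.shape[:-1], num_heads, in_per_head)
    # Project each head with its own matrix: (..., h, d) x (h, o, d) -> (..., h, o).
    yh = np.einsum("...hd,hod->...ho", xh, weight)
    # Concatenate the heads back into a single feature dimension.
    return yh.reshape(*x.shape[:-1], num_heads * out_per_head)


def block_diagonal(weight: np.ndarray) -> np.ndarray:
    """Build the equivalent dense (out_features, in_features) matrix."""
    num_heads, out_per_head, in_per_head = weight.shape
    dense = np.zeros((num_heads * out_per_head, num_heads * in_per_head))
    for h in range(num_heads):
        dense[h * out_per_head:(h + 1) * out_per_head,
              h * in_per_head:(h + 1) * in_per_head] = weight[h]
    return dense


rng = np.random.default_rng(0)
in_features, num_heads, factor = 8, 4, 2
weight = rng.standard_normal(
    (num_heads, factor * in_features // num_heads, in_features // num_heads)
)
x = rng.standard_normal((3, in_features))

y = headwise_expand(x, weight)
y_dense = x @ block_diagonal(weight).T
print(y.shape)                                   # (3, 16)
print(np.allclose(y, y_dense))                   # True
print(weight.size, block_diagonal(weight).size)  # 32 128
```

With 4 heads and factor 2, the headwise weight holds 32 values, while the equivalent dense matrix holds 128, most of them structural zeros.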
- config_class
- config
- weight
- reset_parameters()
- forward(x)
- Parameters:
x (torch.Tensor)
- Return type:
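The signature of `forward(x)` only states that `x` is a `torch.Tensor`. The following is a hedged NumPy sketch of how a headwise forward pass can accept arbitrary leading dimensions such as `(batch, sequence, features)`, add an optional bias, and keep heads independent. The function name `headwise_forward`, the weight layout `(num_heads, out_per_head, in_per_head)` and the bias shape are assumptions for illustration, not the module's documented behavior.

```python
import numpy as np


def headwise_forward(x, weight, bias=None):
    """Hypothetical headwise forward pass over arbitrary leading dimensions.

    x:      (..., in_features)
    weight: (num_heads, out_per_head, in_per_head)   # assumed layout
    bias:   (num_heads * out_per_head,) or None       # assumed optional bias
    """
    num_heads, out_per_head, in_per_head = weight.shape
    if x.shape[-1] != num_heads * in_per_head:
        raise ValueError("last input dimension must equal num_heads * in_per_head")
    lead = x.shape[:-1]
    # Only the last axis is split into heads; all leading axes pass through.
    xh = x.reshape(*lead, num_heads, in_per_head)
    y = np.einsum("...hd,hod->...ho", xh, weight).reshape(*lead, num_heads * out_per_head)
    if bias is not None:
        y = y + bias  # broadcasts over every leading dimension
    return y


rng = np.random.default_rng(1)
weight = rng.standard_normal((2, 6, 3))  # 2 heads, each maps 3 -> 6 features (factor 2)
bias = np.arange(12, dtype=float)
x = rng.standard_normal((4, 5, 6))       # (batch, sequence, in_features)
y = headwise_forward(x, weight, bias)
print(y.shape)  # (4, 5, 12)

# Perturbing only head 1 of the input leaves head 0 of the output unchanged.
x2 = x.copy()
x2[..., 3:] += 1.0
y2 = headwise_forward(x2, weight, bias)
print(np.allclose(y[..., :6], y2[..., :6]))  # True
print(np.allclose(y[..., 6:], y2[..., 6:]))  # False
```

The head-independence check is what distinguishes this layer from a dense projection: information never mixes across heads inside the layer.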
- extra_repr()