xlstm_jax.models.xlstm_pytorch.components.linear_headwise

Classes

LinearHeadwiseExpandConfig

LinearHeadwiseExpand: A structured projection layer that projects the input to a higher dimension.

Module Contents

class xlstm_jax.models.xlstm_pytorch.components.linear_headwise.LinearHeadwiseExpandConfig
in_features: int = 0
num_heads: int = -1
expand_factor_up: float = 1
_out_features: int = -1
bias: bool = True
trainable_weight: bool = True
trainable_bias: bool = True
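The fields above suggest how the output width is derived. The following is a minimal sketch. It assumes, without confirmation from this page, that `_out_features` is computed as `round(expand_factor_up * in_features)` whenever it is left at its default of `-1`. It also assumes that both `in_features` and the output width must divide evenly by `num_heads`. The class name `HeadwiseConfigSketch` is hypothetical and is not the library's config class.

```python
from dataclasses import dataclass


@dataclass
class HeadwiseConfigSketch:
    # Hypothetical mirror of LinearHeadwiseExpandConfig's fields and defaults.
    in_features: int = 0
    num_heads: int = -1
    expand_factor_up: float = 1
    _out_features: int = -1
    bias: bool = True
    trainable_weight: bool = True
    trainable_bias: bool = True

    def __post_init__(self) -> None:
        # Assumed check: the default num_heads=-1 means "not set yet".
        if self.num_heads <= 0:
            raise ValueError("num_heads must be set to a positive value")
        # Assumed check: every head receives the same number of input features.
        if self.in_features % self.num_heads != 0:
            raise ValueError("in_features must be a multiple of num_heads")
        # Assumed derivation: a negative _out_features means "use the expand factor".
        if self._out_features < 0:
            self._out_features = round(self.expand_factor_up * self.in_features)
        # Assumed check: every head produces the same number of output features.
        if self._out_features % self.num_heads != 0:
            raise ValueError("the output width must be a multiple of num_heads")


cfg = HeadwiseConfigSketch(in_features=8, num_heads=4, expand_factor_up=2)
print(cfg._out_features)  # → 16
```

Setting `_out_features` explicitly (for example `HeadwiseConfigSketch(in_features=6, num_heads=3, _out_features=12)`) bypasses the expand factor in this sketch.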
class xlstm_jax.models.xlstm_pytorch.components.linear_headwise.LinearHeadwiseExpand(config)

Bases: torch.nn.Module

This is a structured, head-wise projection layer that projects the input to a higher dimension.

It only allows integer up-projection factors, i.e., the output dimension must be an integer multiple of the input dimension.
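To make "structured projection" concrete, here is one plausible reading, assumed rather than taken from this page. The last dimension is split into `num_heads` equal chunks, and each chunk gets its own small weight matrix, so the full weight is block-diagonal. The pure-Python sketch below applies that idea to a single input vector. The function `headwise_project` is hypothetical; the real layer operates on batched `torch.Tensor` inputs.

```python
from typing import List, Optional

Vector = List[float]
Matrix = List[List[float]]  # shape (out_per_head, in_per_head)


def headwise_project(x: Vector, weights: List[Matrix], bias: Optional[Vector] = None) -> Vector:
    """Hypothetical head-wise projection of one input vector.

    weights[h] is the (out_per_head x in_per_head) matrix for head h; outputs
    are concatenated head by head, so head 0's outputs come first.
    """
    num_heads = len(weights)
    if num_heads == 0:
        raise ValueError("need at least one head")
    if len(x) % num_heads != 0:
        raise ValueError("len(x) must be a multiple of the number of heads")
    in_per_head = len(x) // num_heads
    out: Vector = []
    for h, w in enumerate(weights):
        if any(len(row) != in_per_head for row in w):
            raise ValueError(f"head {h}: each weight row needs {in_per_head} entries")
        # This head only sees its own contiguous slice of the input features.
        chunk = x[h * in_per_head:(h + 1) * in_per_head]
        # Dense matrix-vector product inside the head (one diagonal block).
        out.extend(sum(wi * xi for wi, xi in zip(row, chunk)) for row in w)
    if bias is not None:
        if len(bias) != len(out):
            raise ValueError("bias must have one entry per output feature")
        out = [o + b for o, b in zip(out, bias)]
    return out


# 4 input features, 2 heads, up-projection factor 2 -> 8 output features.
x = [1.0, 2.0, 3.0, 4.0]
w_head0 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, -1.0]]  # sees x[0:2] only
w_head1 = [[2.0, 0.0], [0.0, 2.0], [1.0, 1.0], [0.0, 0.0]]   # sees x[2:4] only
print(headwise_project(x, [w_head0, w_head1]))
# → [1.0, 2.0, 3.0, -1.0, 6.0, 8.0, 7.0, 0.0]
```

Each output mixes only the features of its own head. Under this assumption, the weights therefore hold `num_heads * (out_features / num_heads) * (in_features / num_heads) = in_features * out_features / num_heads` entries, compared with `in_features * out_features` for a dense layer. In the example, that is 16 instead of 32. With a single head, the sketch reduces to an ordinary linear map.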

Parameters:

config (LinearHeadwiseExpandConfig)

config_class
config
weight
reset_parameters()
forward(x)
Parameters:

x (torch.Tensor)

Return type:

torch.Tensor
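The `forward(x)` signature only states that a tensor goes in and a tensor comes out. Under the same head-wise assumption as above, the shape bookkeeping can be sketched as follows. Leading (batch, sequence, ...) dimensions pass through unchanged, and only the last dimension changes from `in_features` to `out_features`. The helper `headwise_forward_shapes` is hypothetical and only traces shapes; it does not call PyTorch.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def headwise_forward_shapes(input_shape: Shape, num_heads: int, out_features: int) -> Tuple[Shape, Shape, Shape]:
    """Return the (split, projected, merged) shapes of an assumed head-wise forward pass."""
    *batch_dims, in_features = input_shape
    if num_heads <= 0:
        raise ValueError("num_heads must be positive")
    if in_features % num_heads or out_features % num_heads:
        raise ValueError("in_features and out_features must both be multiples of num_heads")
    lead = tuple(batch_dims)
    split = lead + (num_heads, in_features // num_heads)        # (..., heads, in_per_head)
    projected = lead + (num_heads, out_features // num_heads)   # (..., heads, out_per_head)
    merged = lead + (out_features,)                             # (..., out_features)
    return split, projected, merged


# A (batch=2, seq=5, features=8) input with 4 heads and an up-projection factor of 2.
print(headwise_forward_shapes((2, 5, 8), num_heads=4, out_features=16))
# → ((2, 5, 4, 2), (2, 5, 4, 4), (2, 5, 16))
```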

extra_repr()