xlstm_jax.models.xlstm_pytorch.components.conv#

Classes#

CausalConv1dConfig

CausalConv1d

Causal depth-wise convolution of a time series tensor.

Functions#

conv1d_step(x, conv_state, conv1d_weight[, conv1d_bias])

Single-step convolution over x and conv_state; the shape legend (B, S, D, KS) is given under Module Contents.

Module Contents#

class xlstm_jax.models.xlstm_pytorch.components.conv.CausalConv1dConfig#
feature_dim: int = None#
kernel_size: int = 4#
causal_conv_bias: bool = True#
channel_mixing: bool = False#
conv1d_kwargs: dict#
xlstm_jax.models.xlstm_pytorch.components.conv.conv1d_step(x, conv_state, conv1d_weight, conv1d_bias=None)#

B: batch size
S: sequence length
D: feature dimension
KS: kernel size

Parameters:
  • x – (B, S, D)

  • conv_state – (B, KS, D)

  • conv1d_weight – (KS, D)

  • conv1d_bias – optional, defaults to None

Return type:

tuple[torch.Tensor, torch.Tensor]
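
The shapes above suggest a rolling-buffer update: the state holds the last KS inputs, the new input replaces the oldest one, and the output is a per-feature weighted sum over the kernel axis. The following is a minimal sketch of that reading, not the library's implementation. The name conv1d_step_sketch, the assumption that S == 1 for a step, and the oldest-first ordering of the state are all illustrative assumptions.

```python
import torch


def conv1d_step_sketch(x, conv_state, conv1d_weight, conv1d_bias=None):
    """Illustrative single-step depth-wise convolution (hypothetical helper).

    x:             (B, 1, D)   one new time step (assumes S == 1)
    conv_state:    (B, KS, D)  the last KS inputs, oldest first (assumption)
    conv1d_weight: (KS, D)     one length-KS filter per feature
    conv1d_bias:   (D,) or None
    """
    # Drop the oldest entry and append the new input at the end.
    new_state = torch.cat([conv_state[:, 1:, :], x], dim=1)
    # Weighted sum over the kernel axis, independently for each feature.
    y = (new_state * conv1d_weight).sum(dim=1, keepdim=True)  # (B, 1, D)
    if conv1d_bias is not None:
        y = y + conv1d_bias  # broadcasts over batch and time
    return y, new_state


B, KS, D = 2, 4, 3
state = torch.zeros(B, KS, D)
weight = torch.ones(KS, D)

# Feed KS steps of all-ones input; with all-ones weights the output
# simply counts how many real inputs are currently in the buffer.
outputs = []
for _ in range(KS):
    y, state = conv1d_step_sketch(torch.ones(B, 1, D), state, weight)
    outputs.append(y[0, 0, 0].item())
```

Returning a new state tensor instead of mutating conv_state in place is a design choice of this sketch; the documented return type only indicates that an output tensor and a state tensor come back together.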

class xlstm_jax.models.xlstm_pytorch.components.conv.CausalConv1d(config)#

Bases: torch.nn.Module

Parameters:

config (CausalConv1dConfig)

config_class#

Causal depth-wise convolution of a time series tensor.

Input: Tensor of shape (B,T,F), i.e. (batch, time, feature)
Output: Tensor of shape (B,T,F)

Parameters:
  • feature_dim – Number of features in the input tensor.

  • kernel_size – Size of the kernel for the depth-wise convolution.

  • causal_conv_bias – Whether to use a bias in the depth-wise convolution.

  • channel_mixing – Whether to use channel mixing (i.e. groups=1) or not (i.e. groups=feature_dim). If True, it mixes the convolved features across channels. If False, all the features are convolved independently.
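
To make the groups distinction concrete, here is a hedged sketch built directly on torch.nn.Conv1d. It is one common way to obtain a causal convolution (pad by kernel_size - 1 and trim the trailing outputs) and is not necessarily how CausalConv1d does it internally. The helpers make_conv and causal_apply are hypothetical names introduced only for this example.

```python
import torch

B, T, F, KS = 2, 6, 3, 4  # batch, time, feature, kernel size


def make_conv(channel_mixing: bool) -> torch.nn.Conv1d:
    # groups=1 mixes channels; groups=F convolves each feature on its own.
    groups = 1 if channel_mixing else F
    return torch.nn.Conv1d(F, F, kernel_size=KS, padding=KS - 1,
                           groups=groups, bias=True)


def causal_apply(conv: torch.nn.Conv1d, x: torch.Tensor) -> torch.Tensor:
    # Conv1d expects (B, F, T), so swap the time and feature axes.
    y = conv(x.transpose(1, 2))
    # padding=KS-1 pads both ends; keep the first T outputs, which only
    # see the current and earlier time steps.
    y = y[:, :, : x.shape[1]]
    return y.transpose(1, 2)  # back to (B, T, F)


depthwise = make_conv(channel_mixing=False)
mixing = make_conv(channel_mixing=True)

x = torch.randn(B, T, F)
with torch.no_grad():
    y = causal_apply(depthwise, x)
    # Perturb only the last time step; earlier outputs must not change.
    x_future = x.clone()
    x_future[:, -1, :] += 1.0
    y_future = causal_apply(depthwise, x_future)

causal_ok = torch.allclose(y[:, :-1], y_future[:, :-1])
depthwise_weight_shape = tuple(depthwise.weight.shape)  # (F, 1, KS)
mixing_weight_shape = tuple(mixing.weight.shape)        # (F, F, KS)
```

The weight shapes show the cost of channel mixing: the depth-wise variant stores one length-KS filter per feature, while the mixing variant stores F of them per output feature.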

config#
groups#
reset_parameters()#
_create_weight_decay_optim_groups()#
Return type:

tuple[set[torch.nn.Parameter], set[torch.nn.Parameter]]

forward(x)#
Parameters:

x (torch.Tensor)

Return type:

torch.Tensor

step(x, conv_state=None)#
Parameters:
  • x

  • conv_state – optional, defaults to None

Return type:

tuple[torch.Tensor, tuple[torch.Tensor]]
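
A step-wise interface is typically meant to reproduce the full-sequence result one token at a time. The sketch below checks that property for a plain depth-wise causal convolution written with torch.nn.functional, under the assumptions that the state starts at zero and that the weight uses the (KS, D) layout documented for conv1d_step. It does not call CausalConv1d itself, and the variable names are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, D, KS = 2, 5, 3, 4
x = torch.randn(B, T, D)
weight = torch.randn(KS, D)  # (KS, D): one filter per feature
bias = torch.randn(D)

# Full-sequence path: left-pad KS-1 zeros along time, then run a
# depth-wise convolution (groups=D) over the whole sequence at once.
x_padded = F.pad(x.transpose(1, 2), (KS - 1, 0))      # (B, D, T + KS - 1)
conv_weight = weight.t().unsqueeze(1)                 # (D, 1, KS)
y_full = F.conv1d(x_padded, conv_weight, bias, groups=D).transpose(1, 2)

# Streaming path: keep the last KS inputs in a buffer and emit one
# output per incoming time step.
state = torch.zeros(B, KS, D)
ys = []
for t in range(T):
    state = torch.cat([state[:, 1:], x[:, t:t + 1]], dim=1)
    ys.append((state * weight).sum(dim=1, keepdim=True) + bias)
y_stream = torch.cat(ys, dim=1)                       # (B, T, D)

max_diff = (y_full - y_stream).abs().max().item()
```

The zero-initialised buffer plays the same role as the left padding in the full-sequence path, which is why the two outputs agree up to floating-point error.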