xlstm_jax.models.xlstm_pytorch.components.conv
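Classes
- CausalConv1dConfig
- CausalConv1d – Causal depth-wise convolution of a time series tensor.
Functions
- conv1d_step(x, conv_state, conv1d_weight, conv1d_bias=None)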
Module Contents
- class xlstm_jax.models.xlstm_pytorch.components.conv.CausalConv1dConfig
- xlstm_jax.models.xlstm_pytorch.components.conv.conv1d_step(x, conv_state, conv1d_weight, conv1d_bias=None)
Shape legend: B: batch size, S: sequence length, D: feature dimension, KS: kernel size.
- Parameters:
x (torch.Tensor) – Input of shape (B, S, D).
conv_state (torch.Tensor) – Convolution state of shape (B, KS, D).
conv1d_weight (torch.Tensor) – Convolution weight of shape (KS, D).
conv1d_bias (torch.Tensor) – Optional bias; defaults to None.
- Return type:
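The docstring above specifies only tensor shapes. As a rough illustration of a single-step causal convolution driven by a rolling state of shape (B, KS, D), here is a minimal PyTorch sketch. It is not the library's code: the name conv1d_step_sketch is hypothetical, and it assumes S == 1 (one new time step per call), that conv_state holds the last KS inputs oldest-first, and that the bias has shape (D,). It also returns a fresh state tensor instead of updating conv_state in place, which the library may or may not do.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def conv1d_step_sketch(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    conv1d_weight: torch.Tensor,
    conv1d_bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical single-step causal depth-wise convolution.

    x:             (B, 1, D)  newest time step (S == 1 is an assumption)
    conv_state:    (B, KS, D) last KS inputs, oldest first (assumed ordering)
    conv1d_weight: (KS, D)    one kernel per feature
    conv1d_bias:   (D,)       optional (shape assumed)
    """
    assert x.shape[1] == 1, "this sketch processes one time step per call"
    # torch.roll returns a new tensor, so the caller's state is left untouched.
    new_state = torch.roll(conv_state, shifts=-1, dims=1)
    # The oldest step wrapped around to the end; overwrite it with the new input.
    new_state[:, -1:, :] = x
    # Depth-wise: weight each feature by its own kernel and sum over the window.
    y = (new_state * conv1d_weight).sum(dim=1, keepdim=True)  # (B, 1, D)
    if conv1d_bias is not None:
        y = y + conv1d_bias
    return y, new_state


# Streaming a sequence step by step should match a left-padded depth-wise
# convolution computed over the whole sequence at once.
torch.manual_seed(0)
B, T, KS, D = 2, 6, 4, 3
weight, bias = torch.randn(KS, D), torch.randn(D)
xs = torch.randn(B, T, D)
state = torch.zeros(B, KS, D)
outs = []
for t in range(T):
    y, state = conv1d_step_sketch(xs[:, t : t + 1], state, weight, bias)
    outs.append(y)
streamed = torch.cat(outs, dim=1)  # (B, T, D)
# F.conv1d wants (B, D, T) input and a (D, 1, KS) weight when groups=D.
reference = F.conv1d(
    F.pad(xs.transpose(1, 2), (KS - 1, 0)), weight.t().unsqueeze(1), bias, groups=D
).transpose(1, 2)
print(torch.allclose(streamed, reference, atol=1e-5))  # True
```

Starting from an all-zero state plays the same role as the KS - 1 zeros of left padding in the parallel computation, which is why the two paths agree.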
- class xlstm_jax.models.xlstm_pytorch.components.conv.CausalConv1d(config)
Bases: torch.nn.Module
- Parameters:
config (CausalConv1dConfig)
- config_class
Causal depth-wise convolution of a time series tensor.
Input: Tensor of shape (B, T, F), i.e. (batch, time, feature).
Output: Tensor of shape (B, T, F).
- Parameters:
feature_dim – Number of features in the input tensor.
kernel_size – Size of the kernel for the depth-wise convolution.
causal_conv_bias – Whether to use a bias in the depth-wise convolution.
channel_mixing – Whether to use channel mixing (i.e., groups=1) or not (i.e., groups=feature_dim). If True, the convolved features are mixed across channels. If False, each feature is convolved independently.
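A causal convolution with the (B, T, F) to (B, T, F) contract above can be sketched by left-padding the time axis by kernel_size - 1 and choosing groups from channel_mixing. The code below is a hypothetical stand-in, not the library's CausalConv1d: SketchConvConfig and CausalConv1dSketch are illustrative names, only the four field names come from the parameter list above, and the default values are assumptions.

```python
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class SketchConvConfig:  # hypothetical stand-in for CausalConv1dConfig
    feature_dim: int
    kernel_size: int = 4  # assumed default
    causal_conv_bias: bool = True  # assumed default
    channel_mixing: bool = False  # assumed default


class CausalConv1dSketch(nn.Module):
    """Hypothetical causal convolution over (B, T, F) inputs; output is (B, T, F)."""

    def __init__(self, config: SketchConvConfig):
        super().__init__()
        self.config = config
        # groups=1 mixes channels; groups=feature_dim convolves each feature on its own.
        self.groups = 1 if config.channel_mixing else config.feature_dim
        self.conv = nn.Conv1d(
            in_channels=config.feature_dim,
            out_channels=config.feature_dim,
            kernel_size=config.kernel_size,
            groups=self.groups,
            bias=config.causal_conv_bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, T, F) -> (B, F, T), the layout nn.Conv1d expects.
        x = x.transpose(1, 2)
        # Pad only on the left so output t depends on inputs up to t.
        x = F.pad(x, (self.config.kernel_size - 1, 0))
        y = self.conv(x)  # length back to T: (T + KS - 1) - KS + 1
        return y.transpose(1, 2)


torch.manual_seed(0)
layer = CausalConv1dSketch(SketchConvConfig(feature_dim=8, kernel_size=4))
print(layer(torch.randn(2, 16, 8)).shape)  # torch.Size([2, 16, 8])
```

Padding only on the left is what makes the layer causal: output step t sees inputs t - kernel_size + 1 through t and nothing later, and the sequence length is preserved.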
- config
- groups
- reset_parameters()
- _create_weight_decay_optim_groups()
- forward(x)
- Parameters:
x (torch.Tensor)
- Return type:
- step(x, conv_state=None)
- Parameters:
x (torch.Tensor)
conv_state (tuple[torch.Tensor])
- Return type:
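The step method takes an optional conv_state typed as tuple[torch.Tensor], which suggests a streaming mode that keeps the most recent inputs between calls. The sketch below shows one way such a step can work with an nn.Conv1d. The helper causal_step_sketch is hypothetical, and it assumes x has shape (B, 1, F), that the tuple holds a single (B, KS, F) window, that a missing state means zeros, and that the new state is returned alongside the output; the actual return format of step is not documented here.

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


def causal_step_sketch(
    conv: nn.Conv1d,
    x: torch.Tensor,
    conv_state: Optional[tuple[torch.Tensor]] = None,
) -> tuple[torch.Tensor, tuple[torch.Tensor]]:
    """Hypothetical one-step update for a causal nn.Conv1d.

    x: (B, 1, F); conv_state: 1-tuple holding the last KS inputs, shape (B, KS, F).
    """
    B, _, F_ = x.shape
    KS = conv.kernel_size[0]
    if conv_state is None:
        # Zeros stand in for the left padding used by the parallel forward pass.
        window = x.new_zeros(B, KS, F_)
    else:
        (window,) = conv_state
    # Drop the oldest step and append the newest one.
    window = torch.cat([window[:, 1:], x], dim=1)
    # Convolving exactly KS inputs with no padding yields a single output step.
    y = conv(window.transpose(1, 2)).transpose(1, 2)  # (B, 1, F)
    return y, (window,)


torch.manual_seed(0)
B, T, Fd, KS = 2, 6, 3, 4
conv = nn.Conv1d(Fd, Fd, KS, groups=Fd)  # depth-wise; groups=1 would mix channels
x = torch.randn(B, T, Fd)
parallel = conv(F.pad(x.transpose(1, 2), (KS - 1, 0))).transpose(1, 2)
state, outs = None, []
for t in range(T):
    y, state = causal_step_sketch(conv, x[:, t : t + 1], state)
    outs.append(y)
print(torch.allclose(torch.cat(outs, dim=1), parallel, atol=1e-5))  # True
```

Keeping only the last KS inputs makes each step cost O(KS) per feature regardless of how long the sequence has run, which is the usual motivation for a stepwise mode during autoregressive generation.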