xlstm_jax.models.xlstm_clean.components.conv
Classes
Module Contents
- class xlstm_jax.models.xlstm_clean.components.conv.CausalConv1dConfig
- property _dtype: jax.numpy.dtype
Returns the real dtype instead of the str from configs.
- Returns:
The jnp dtype corresponding to the string value.
- Return type:
jax.numpy.dtype
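The `_dtype` property converts a dtype name stored as a string in the config into an actual JAX dtype. The sketch below shows one way such a property can work. It assumes a hypothetical config class `DtypeConfigSketch` with a string field named `dtype`. This page does not list the real fields of `CausalConv1dConfig`, so both names are illustrative only.

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass
class DtypeConfigSketch:
    # Hypothetical config field: the dtype is kept as a plain string,
    # as is typical for configs loaded from YAML or JSON files.
    dtype: str = "float32"

    @property
    def _dtype(self) -> jnp.dtype:
        # jnp.dtype accepts a dtype name such as "float16" and returns
        # the corresponding dtype object usable in array constructors.
        return jnp.dtype(self.dtype)


cfg = DtypeConfigSketch(dtype="float16")
x = jnp.ones((2, 3), dtype=cfg._dtype)
print(x.dtype)  # float16
```

Keeping the string in the config and converting it lazily makes the config easy to serialize, while module code can pass `cfg._dtype` directly to array and layer constructors.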
- class xlstm_jax.models.xlstm_clean.components.conv.CausalConv1d
Bases: flax.linen.Module
- config: CausalConv1dConfig
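Implements a causal depthwise convolution over a time-series tensor. Causal means that the output at time step t depends only on inputs at step t and earlier.
Input: tensor of shape (B, T, F), i.e. (batch, time, feature).
Output: tensor of shape (B, T, F).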
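The causal, shape-preserving behaviour described above can be illustrated with a plain `jax.numpy` sketch. The approach is to left-pad the time axis with `kernel_size - 1` zeros, so each output step only sees current and past inputs, and to give every feature its own filter (depthwise). The function name `causal_depthwise_conv1d` and the `(K, F)` kernel layout are hypothetical choices for illustration, not the module's actual implementation.

```python
import jax.numpy as jnp


def causal_depthwise_conv1d(x: jnp.ndarray, kernel: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical reference implementation of a causal depthwise 1D conv.

    x:      (B, T, F) input
    kernel: (K, F), one length-K filter per feature
    returns (B, T, F)
    """
    k = kernel.shape[0]
    t = x.shape[1]
    # Left-pad only the time axis with K-1 zeros: output[t] uses x[t-K+1 .. t].
    x_pad = jnp.pad(x, ((0, 0), (k - 1, 0), (0, 0)))
    # Sum of K shifted windows; kernel[i] (shape (F,)) broadcasts over B and T,
    # so each feature is filtered independently (no channel mixing).
    return sum(kernel[i] * x_pad[:, i : i + t, :] for i in range(k))


x = jnp.arange(2 * 8 * 3, dtype=jnp.float32).reshape(2, 8, 3)
kernel = jnp.ones((4, 3))
y = causal_depthwise_conv1d(x, kernel)
print(y.shape)  # (2, 8, 3)

# Perturbing time step 5 must leave all earlier outputs unchanged.
x2 = x.at[:, 5, :].add(100.0)
y2 = causal_depthwise_conv1d(x2, kernel)
print(jnp.allclose(y[:, :5], y2[:, :5]))  # True
```

Padding only on the left is what separates a causal convolution from a "same"-padded one, which would pad both sides and let future steps leak into the output.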
- Parameters:
feature_dim – number of features in the input tensor
kernel_size – size of the kernel for the depthwise convolution
causal_conv_bias – whether to use bias in the depthwise convolution
channel_mixing – whether to mix information across channels. If True, the convolution uses groups=1, so each output feature combines all input features. If False, it uses groups=feature_dim, so every feature is convolved independently (depthwise).