xlstm_jax.models.xlstm_clean.components.conv

Classes

Module Contents

class xlstm_jax.models.xlstm_clean.components.conv.CausalConv1dConfig
feature_dim: int = None
kernel_size: int = 4
causal_conv_bias: bool = True
channel_mixing: bool = False
conv1d_kwargs: dict
dtype: str = 'bfloat16'
property _dtype: jax.numpy.dtype

Returns the actual JAX dtype rather than the string stored in the config.

Returns:

The jnp dtype corresponding to the string value.

Return type:

jax.numpy.dtype
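The string-to-dtype conversion described above can be sketched as follows. This is a minimal illustration, not the library's implementation: the class name `DtypeConfigSketch` and the lookup table `_SUPPORTED_DTYPES` are hypothetical, and the set of accepted strings is an assumption.

```python
from dataclasses import dataclass

import jax.numpy as jnp

# Hypothetical lookup table; the strings the real config accepts may differ.
_SUPPORTED_DTYPES = {
    "float32": jnp.float32,
    "bfloat16": jnp.bfloat16,
    "float16": jnp.float16,
}


@dataclass
class DtypeConfigSketch:
    """Hypothetical stand-in for a config that stores its dtype as a string."""

    dtype: str = "bfloat16"

    @property
    def _dtype(self):
        # Resolve the configured string to a concrete dtype object.
        if self.dtype not in _SUPPORTED_DTYPES:
            raise ValueError(f"Unsupported dtype string: {self.dtype!r}")
        return jnp.dtype(_SUPPORTED_DTYPES[self.dtype])
```

Storing the dtype as a string keeps the config easy to serialize, while the property gives model code a value it can pass directly to array constructors, e.g. `jnp.zeros((2,), dtype=cfg._dtype)`.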

class xlstm_jax.models.xlstm_clean.components.conv.CausalConv1d

Bases: flax.linen.Module

config: CausalConv1dConfig

Implements a causal depthwise convolution over a time-series tensor.

Input: tensor of shape (B, T, F), i.e. (batch, time, feature).

Output: tensor of shape (B, T, F).

Parameters:
  • feature_dim – number of features in the input tensor

  • kernel_size – size of the kernel for the depthwise convolution

  • causal_conv_bias – whether to use bias in the depthwise convolution

  • channel_mixing – whether to mix channels (groups=1) or not (groups=feature_dim). If True, the convolution mixes features across channels. If False, each feature is convolved independently.
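To make the parameters above concrete, here is a minimal sketch of a causal 1D convolution written directly with `jax.lax.conv_general_dilated`. It is not the module's actual code: the function name `causal_conv1d` and the kernel layout `(K, I, F)` are assumptions chosen for illustration. Causality comes from left-padding the time axis with `kernel_size - 1` zeros, so the output at step t depends only on inputs at steps up to t. The group count is inferred from the kernel shape: `I = 1` gives the depthwise case (groups=feature_dim), and `I = F` gives the channel-mixing case (groups=1).

```python
import jax
import jax.numpy as jnp


def causal_conv1d(x, kernel, bias=None):
    """Causal 1D convolution over the time axis (illustrative sketch).

    x:      (B, T, F) input, i.e. (batch, time, feature).
    kernel: (K, I, F) weights; I == 1 is depthwise, I == F mixes channels.
    bias:   optional (F,) bias added to every time step.
    Returns an array of shape (B, T, F).
    """
    kernel_size, in_per_group, feature_dim = kernel.shape
    groups = feature_dim // in_per_group
    y = jax.lax.conv_general_dilated(
        x,
        kernel,
        window_strides=(1,),
        # Pad only on the left so no output step sees future inputs.
        padding=[(kernel_size - 1, 0)],
        dimension_numbers=("NWC", "WIO", "NWC"),
        feature_group_count=groups,
    )
    if bias is not None:
        y = y + bias
    return y
```

A quick way to check causality is to change the input from some time step onward and confirm that all earlier output steps stay the same. With `kernel_size = 4`, each output step sees the current input and the three preceding ones.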