xlstm_jax.models.xlstm_parallel.components.linear_headwise

Classes

LinearHeadwiseExpandConfig

LinearHeadwiseExpand

This is a structured projection layer that projects the input to a higher dimension.

Module Contents

class xlstm_jax.models.xlstm_parallel.components.linear_headwise.LinearHeadwiseExpandConfig
in_features: int = 0
num_heads: int = -1
expand_factor_up: float = 1
_out_features: int = -1
bias: bool = True
trainable_weight: bool = True
trainable_bias: bool = True
dtype: str = 'float32'
property _dtype: jax.numpy.dtype

Returns the actual JAX dtype instead of the string stored in the config.

Returns:

The jnp dtype corresponding to the string value.

Return type:

jax.numpy.dtype
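The relationship between the config fields can be illustrated with a hypothetical stand-in dataclass. This is a minimal sketch, not the library's class: the name `HeadwiseConfigSketch`, the divisibility check in `__post_init__`, and the rule that a negative `_out_features` is derived as `round(expand_factor_up * in_features)` are all assumptions made for illustration. The `_dtype` property shows one way to turn the string `dtype` into a real dtype.

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass
class HeadwiseConfigSketch:
    """Hypothetical stand-in mirroring the documented LinearHeadwiseExpandConfig fields."""

    in_features: int = 0
    num_heads: int = -1
    expand_factor_up: float = 1
    _out_features: int = -1
    bias: bool = True
    trainable_weight: bool = True
    trainable_bias: bool = True
    dtype: str = "float32"

    def __post_init__(self):
        # Assumed validation: every head must receive the same number of input features.
        assert self.num_heads > 0, "num_heads must be set"
        assert self.in_features % self.num_heads == 0, "in_features must be divisible by num_heads"
        # Assumed rule: a negative _out_features means "derive it from expand_factor_up".
        if self._out_features < 0:
            self._out_features = round(self.expand_factor_up * self.in_features)

    @property
    def _dtype(self) -> jnp.dtype:
        # Look up the scalar type by name (e.g. "float32" -> jnp.float32) and return its dtype.
        return jnp.dtype(getattr(jnp, self.dtype))


cfg = HeadwiseConfigSketch(in_features=8, num_heads=4, expand_factor_up=2)
print(cfg._out_features, cfg._dtype)
```

Under these assumptions, setting `_out_features` explicitly to a positive value would override the expansion factor.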

class xlstm_jax.models.xlstm_parallel.components.linear_headwise.LinearHeadwiseExpand

Bases: flax.linen.Module

This is a structured projection layer that projects the input to a higher dimension.

It only allows integer up-projection factors, i.e., the output dimension must be an integer multiple of the input dimension.

config: LinearHeadwiseExpandConfig
kernel_init: Any = None
bias_init: callable
extra_repr()