xlstm_jax.models.xlstm_clean.components.linear_headwise
Classes
- LinearHeadwiseExpandConfig
- LinearHeadwiseExpand: This is a structured projection layer that projects the input to a higher dimension.
Module Contents
- class xlstm_jax.models.xlstm_clean.components.linear_headwise.LinearHeadwiseExpandConfig
- property _dtype: jax.numpy.dtype
Returns the actual JAX dtype instead of the dtype string stored in the config.
- Returns:
The jnp dtype corresponding to the string value.
- Return type:
jax.numpy.dtype
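The `_dtype` property follows a common config pattern: the config keeps the dtype as a plain string (easy to serialize), and a property converts it to a real `jnp` dtype on demand. A minimal sketch of that pattern, assuming the string lives in a field named `dtype`; the class and field names here are hypothetical, since this page does not show the config's fields.

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass
class HeadwiseConfigSketch:
    # Hypothetical field: the dtype is stored as a string such as
    # "float32" or "bfloat16" so the config stays serializable.
    dtype: str = "float32"

    @property
    def _dtype(self) -> jnp.dtype:
        # Look up the JAX scalar type by name (e.g. jnp.bfloat16)
        # and convert it to a concrete dtype object.
        return jnp.dtype(getattr(jnp, self.dtype))


cfg = HeadwiseConfigSketch(dtype="bfloat16")
# The converted dtype can be passed directly to array constructors.
x = jnp.ones((2, 3), dtype=cfg._dtype)
```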
- class xlstm_jax.models.xlstm_clean.components.linear_headwise.LinearHeadwiseExpand
Bases: flax.linen.Module
This is a structured projection layer that projects the input to a higher dimension.
It only allows integer up-projection factors, i.e., the output dimension is an integer multiple of the input dimension.
- config: LinearHeadwiseExpandConfig
- kernel_init: Any = None
- bias_init: callable
- extra_repr()
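This page states the layer's contract (a structured projection with an integer up-projection factor, configurable `kernel_init` and `bias_init`) but not its internals. The sketch below is one plausible reading in plain `jax.numpy`, assuming "headwise" means the last axis is split into `num_heads` equal groups and each group is projected by its own small matrix. It is written as plain functions rather than a `flax.linen.Module` to stay dependency-light. The function names, the `in_features`/`num_heads`/`expand_factor` parameters, and the fallback used when `kernel_init` is `None` are all hypothetical, not the module's actual API.

```python
import math

import jax
import jax.numpy as jnp
from jax.scipy.linalg import block_diag


def init_headwise_params(key, in_features, num_heads, expand_factor,
                         kernel_init=None,
                         bias_init=jax.nn.initializers.zeros,
                         dtype=jnp.float32):
    """Hypothetical parameter init for a head-wise expanding projection."""
    if in_features % num_heads != 0:
        raise ValueError("in_features must be divisible by num_heads")
    in_per_head = in_features // num_heads
    out_per_head = expand_factor * in_per_head  # integer up-projection only
    if kernel_init is None:
        # Assumed fallback: normal init scaled by the per-head fan-in.
        kernel_init = jax.nn.initializers.normal(stddev=1.0 / math.sqrt(in_per_head))
    k_key, b_key = jax.random.split(key)
    # One small (out_per_head, in_per_head) matrix per head.
    kernel = kernel_init(k_key, (num_heads, out_per_head, in_per_head), dtype)
    bias = bias_init(b_key, (num_heads * out_per_head,), dtype)
    return kernel, bias


def linear_headwise_expand(x, kernel, bias=None):
    """Apply a separate linear map to each head of the last axis of x."""
    num_heads, _, in_per_head = kernel.shape
    # (..., in_features) -> (..., num_heads, in_per_head)
    xh = x.reshape(*x.shape[:-1], num_heads, in_per_head)
    # Per-head matmul: head h of the input only meets kernel[h].
    yh = jnp.einsum("...hd,hod->...ho", xh, kernel)
    # (..., num_heads, out_per_head) -> (..., num_heads * out_per_head)
    y = yh.reshape(*x.shape[:-1], -1)
    return y if bias is None else y + bias


def as_dense_kernel(kernel):
    """Equivalent dense (in_features, out_features) block-diagonal matrix."""
    return block_diag(*[kernel[h].T for h in range(kernel.shape[0])])


key = jax.random.PRNGKey(0)
kernel, bias = init_headwise_params(key, in_features=8, num_heads=4, expand_factor=2)
x = jax.random.normal(jax.random.PRNGKey(1), (2, 8))
y = linear_headwise_expand(x, kernel, bias)  # shape (2, 16)
```

The `as_dense_kernel` helper illustrates why such a layer counts as "structured": it matches a dense layer whose weight is block-diagonal, so it stores only `1/num_heads` of the dense parameters (32 instead of 128 in the example above).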