xlstm_jax.models.xlstm_clean.components.feedforward
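Attributes
- _act_fn_registry
Classes
- FeedForwardConfig
- GatedFeedForward
Functions
- get_act_fn(act_fn_name)
- create_feedforward(config)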
Module Contents
- xlstm_jax.models.xlstm_clean.components.feedforward._act_fn_registry
- xlstm_jax.models.xlstm_clean.components.feedforward.get_act_fn(act_fn_name)
- Parameters:
act_fn_name (str)
- Return type:
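This page does not document what `_act_fn_registry` contains or what `get_act_fn` returns. A plausible reading of the names is a dictionary mapping activation names to callables, with `get_act_fn` doing the lookup. The sketch below assumes exactly that; the registry entries, the name `get_act_fn_sketch`, and the `ValueError` on unknown names are illustrative assumptions, not the library's code.

```python
from typing import Callable, Dict

import jax
import jax.numpy as jnp

ActFn = Callable[[jax.Array], jax.Array]

# Hypothetical registry: the real contents of _act_fn_registry are not documented.
_ACT_FN_REGISTRY_SKETCH: Dict[str, ActFn] = {
    "gelu": jax.nn.gelu,
    "relu": jax.nn.relu,
    "silu": jax.nn.silu,
    "swish": jax.nn.swish,
    "sigmoid": jax.nn.sigmoid,
}


def get_act_fn_sketch(act_fn_name: str) -> ActFn:
    """Return the activation registered under act_fn_name (hypothetical helper)."""
    try:
        return _ACT_FN_REGISTRY_SKETCH[act_fn_name]
    except KeyError:
        valid = ", ".join(sorted(_ACT_FN_REGISTRY_SKETCH))
        raise ValueError(
            f"Unknown activation {act_fn_name!r}; expected one of: {valid}"
        ) from None


# Usage: look up an activation by name and apply it elementwise.
relu = get_act_fn_sketch("relu")
out = relu(jnp.array([-1.0, 2.0]))
```

Listing the valid names in the error message makes a misspelled activation in a config file quick to diagnose.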
- class xlstm_jax.models.xlstm_clean.components.feedforward.FeedForwardConfig
Bases:
xlstm_jax.models.xlstm_clean.utils.UpProjConfigMixin
- ff_type: Literal['ffn_gated'] = 'ffn_gated'
- property _dtype: jax.numpy.dtype
Returns the JAX dtype object instead of the dtype string stored in the config.
- Returns:
The jnp dtype corresponding to the string value.
- Return type:
jax.numpy.dtype
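`FeedForwardConfig` inherits from `UpProjConfigMixin`, and its `_dtype` property turns a string into a dtype. The sketch below illustrates both ideas under stated assumptions: the mixin is assumed to derive the hidden width as `proj_factor * embedding_dim` rounded up to a multiple, and the config is assumed to store its dtype as a string field named `dtype`. All class names, field names, and defaults here are hypothetical.

```python
import math
from dataclasses import dataclass
from typing import Literal

import jax.numpy as jnp


@dataclass
class UpProjConfigSketch:
    """Hypothetical up-projection mixin: derives the hidden width from a factor."""

    proj_factor: float = 1.3  # assumed default
    round_proj_up_to_multiple_of: int = 64  # assumed rounding granularity

    def proj_up_dim(self, embedding_dim: int) -> int:
        # Round proj_factor * embedding_dim up to the next multiple.
        multiple = self.round_proj_up_to_multiple_of
        return math.ceil(self.proj_factor * embedding_dim / multiple) * multiple


@dataclass
class FeedForwardConfigSketch(UpProjConfigSketch):
    """Hypothetical stand-in for FeedForwardConfig."""

    ff_type: Literal["ffn_gated"] = "ffn_gated"
    act_fn: str = "gelu"  # assumed field, resolved by name at build time
    dtype: str = "bfloat16"  # assumed field, stored as a plain string

    @property
    def _dtype(self) -> jnp.dtype:
        # Look up the jnp scalar type by name, then convert it to a dtype object.
        return jnp.dtype(getattr(jnp, self.dtype))


# Usage
cfg = FeedForwardConfigSketch()
hidden = cfg.proj_up_dim(256)  # 1.3 * 256 = 332.8, rounded up to 384
```

Keeping the dtype as a string presumably keeps the config easy to serialize; the property resolves it only when a module actually needs a dtype.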
- class xlstm_jax.models.xlstm_clean.components.feedforward.GatedFeedForward
Bases:
flax.linen.Module
- config: FeedForwardConfig
- xlstm_jax.models.xlstm_clean.components.feedforward.create_feedforward(config)
- Parameters:
config (FeedForwardConfig)
- Return type:
flax.linen.Module