xlstm_jax.models.xlstm_parallel.components.feedforward
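Attributes
- _act_fn_registry
Classes
- FeedForwardConfig: Sub-model configuration.
- GatedFeedForward
- FeedForward
Functions
- get_act_fn(act_fn_name)
- create_feedforward(config, name='ffn')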
Module Contents
- xlstm_jax.models.xlstm_parallel.components.feedforward._act_fn_registry
- xlstm_jax.models.xlstm_parallel.components.feedforward.get_act_fn(act_fn_name)
- Parameters:
act_fn_name (str)
- Return type:
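get_act_fn resolves an activation function from its string name through the module-level _act_fn_registry. This page does not list the registry's entries or say what happens for an unknown name, so the following is only a minimal sketch of the name-to-callable lookup pattern, assuming a plain dict registry. The entries "gelu", "relu" and "silu" are hypothetical, the ValueError on unknown names is an assumed design choice, and scalar functions from Python's math module stand in for whatever array functions the real registry holds.

```python
import math
from typing import Callable, Dict

# Hypothetical registry: the real _act_fn_registry contents are not documented here.
# Scalar math-based functions keep this sketch dependency-free.
_act_fn_registry: Dict[str, Callable[[float], float]] = {
    "gelu": lambda x: 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0))),  # exact (erf) GELU
    "relu": lambda x: max(x, 0.0),
    "silu": lambda x: x / (1.0 + math.exp(-x)),  # x * sigmoid(x)
}


def get_act_fn(act_fn_name: str) -> Callable[[float], float]:
    """Look up an activation by name; unknown names raise ValueError (assumed behavior)."""
    try:
        return _act_fn_registry[act_fn_name]
    except KeyError:
        raise ValueError(
            f"Unknown activation {act_fn_name!r}; available: {sorted(_act_fn_registry)}"
        ) from None


act = get_act_fn("relu")
print(act(-2.0), act(3.0))  # 0.0 3.0
```

Failing with the list of valid names makes a typo in a config string surface at model construction instead of deep inside a forward pass.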
- class xlstm_jax.models.xlstm_parallel.components.feedforward.FeedForwardConfig
Bases: xlstm_jax.models.xlstm_parallel.utils.UpProjConfigMixin
Sub-model configuration.
This class is currently a quick fix that allows post-init style model configs, such as the xlstm-clean configs ported from the original xlstm codebase. Once the config system is more mature, this class should be removed and all sub-model configs should become subclasses of ModelConfig.
- init_distribution: xlstm_jax.models.shared.InitDistribution = 'normal'
Distribution type from which to sample the weights.
- output_init_fn: xlstm_jax.models.shared.InitFnName = 'wang'
Initialization function for the output projection layer.
- ff_type: Literal['ffn_gated', 'ffn'] = 'ffn_gated'
- parallel: xlstm_jax.models.configs.ParallelConfig | None = None
- property _dtype: jax.numpy.dtype
Returns the actual dtype instead of the string stored in the config.
- Returns:
The jnp dtype corresponding to the string value.
- Return type:
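Post-init style usually means a dataclass whose __post_init__ validates fields and derives values after construction. The sketch below mirrors the fields and defaults documented above (init_distribution, output_init_fn, ff_type, parallel) and the _dtype string-to-dtype property. Everything else is an assumption: the class name FeedForwardConfigSketch, the embedding_dim, proj_factor and dtype fields, and the derived proj_up_dim are hypothetical stand-ins for the UpProjConfigMixin fields this page does not document, and np.dtype replaces jnp.dtype so the example runs without JAX.

```python
from dataclasses import dataclass
from typing import Literal, Optional

import numpy as np


@dataclass
class FeedForwardConfigSketch:
    """Hypothetical stand-in for FeedForwardConfig using the documented defaults."""

    embedding_dim: int = 64          # assumption: model width (not documented here)
    proj_factor: float = 4.0         # assumption: up-projection factor (not documented here)
    dtype: str = "float32"           # assumption: dtype kept as a string in the config
    init_distribution: str = "normal"
    output_init_fn: str = "wang"
    ff_type: Literal["ffn_gated", "ffn"] = "ffn_gated"
    parallel: Optional[object] = None  # ParallelConfig in the real class

    def __post_init__(self) -> None:
        # Post-init style: validate and derive values once all fields are set.
        if self.ff_type not in ("ffn_gated", "ffn"):
            raise ValueError(f"Unsupported ff_type {self.ff_type!r}")
        # Hypothetical derived field: hidden width of the up-projection.
        self.proj_up_dim = int(round(self.proj_factor * self.embedding_dim))

    @property
    def _dtype(self) -> np.dtype:
        # Convert the config string into a real dtype object (np.dtype here, jnp in the real class).
        return np.dtype(self.dtype)


cfg = FeedForwardConfigSketch(ff_type="ffn")
print(cfg.proj_up_dim, cfg._dtype)  # 256 float32
```

Keeping the dtype as a string keeps the config easy to serialize, while the property gives modules a proper dtype object at use time.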
- class xlstm_jax.models.xlstm_parallel.components.feedforward.GatedFeedForward
Bases: flax.linen.Module
- config: FeedForwardConfig
- class xlstm_jax.models.xlstm_parallel.components.feedforward.FeedForward
Bases: flax.linen.Module
- config: FeedForwardConfig
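This reference does not describe how GatedFeedForward and FeedForward differ internally. Going only by the names, a common reading is the distinction between a plain two-layer MLP and a GLU-style gated MLP, where one up-projection yields a gate and a value and the activated gate multiplies the value elementwise. The NumPy sketch below illustrates that distinction under those assumptions only; it is not the modules' actual code, the tanh-approximated GELU is an arbitrary choice, and weights are passed explicitly instead of being held as Flax parameters. jax.numpy provides the same calls used here.

```python
import numpy as np


def gelu(x: np.ndarray) -> np.ndarray:
    # tanh approximation of GELU (an arbitrary choice for this sketch)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def feedforward(x, w_up, w_down, act=gelu):
    """Plain FFN: up-project, apply the activation, down-project."""
    return act(x @ w_up) @ w_down


def gated_feedforward(x, w_up, w_down, act=gelu):
    """Gated FFN: the up-projection is split into gate and value halves."""
    gate, value = np.split(x @ w_up, 2, axis=-1)
    return (act(gate) * value) @ w_down


rng = np.random.default_rng(0)
d, h = 8, 32  # model width and hidden width
x = rng.normal(size=(2, d))
y_plain = feedforward(x, rng.normal(size=(d, h)), rng.normal(size=(h, d)))
# The gated variant's up-projection is twice as wide to produce both halves.
y_gated = gated_feedforward(x, rng.normal(size=(d, 2 * h)), rng.normal(size=(h, d)))
print(y_plain.shape, y_gated.shape)  # (2, 8) (2, 8)
```

Both variants map width d back to width d, so either can sit in the same residual block; the gated one spends extra up-projection parameters on the gate.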
- xlstm_jax.models.xlstm_parallel.components.feedforward.create_feedforward(config, name='ffn')
- Parameters:
config (FeedForwardConfig)
name (str)
- Return type:
flax.linen.Module
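create_feedforward returns a flax.linen.Module built from a FeedForwardConfig, and the documented ff_type literal ('ffn_gated' or 'ffn') suggests it dispatches between GatedFeedForward and FeedForward. A minimal sketch of that dispatch, assuming exactly this mapping: SketchConfig and the two placeholder dataclasses are hypothetical stand-ins for the real config and Flax modules (which likewise take config as a field and name as Flax's built-in module attribute), and raising ValueError for an unknown ff_type is an assumption.

```python
from dataclasses import dataclass
from typing import Literal


@dataclass
class SketchConfig:
    # Hypothetical minimal config: only the documented ff_type field matters for dispatch.
    ff_type: Literal["ffn_gated", "ffn"] = "ffn_gated"


@dataclass
class GatedFeedForward:
    # Placeholder for the Flax module of the same name.
    config: SketchConfig
    name: str = "ffn"


@dataclass
class FeedForward:
    # Placeholder for the Flax module of the same name.
    config: SketchConfig
    name: str = "ffn"


# Assumed mapping from ff_type to module class, inferred from the names.
_FF_CLASSES = {"ffn_gated": GatedFeedForward, "ffn": FeedForward}


def create_feedforward(config: SketchConfig, name: str = "ffn"):
    """Choose the feed-forward class from config.ff_type and instantiate it."""
    try:
        cls = _FF_CLASSES[config.ff_type]
    except KeyError:
        raise ValueError(f"Unknown ff_type {config.ff_type!r}") from None
    return cls(config=config, name=name)


block = create_feedforward(SketchConfig(ff_type="ffn"), name="mlp")
print(type(block).__name__, block.name)  # FeedForward mlp
```

Routing construction through one factory lets callers switch feed-forward variants by changing a config string, without touching the model code that builds the block.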