xlstm_jax.models.xlstm_clean.components.feedforward#

Attributes#

_act_fn_registry

Classes#

FeedForwardConfig

GatedFeedForward

Functions#

get_act_fn(act_fn_name)

create_feedforward(config)

Module Contents#

xlstm_jax.models.xlstm_clean.components.feedforward._act_fn_registry#
xlstm_jax.models.xlstm_clean.components.feedforward.get_act_fn(act_fn_name)#
Parameters:

act_fn_name (str)

Return type:

collections.abc.Callable[[jax.Array], jax.Array]

class xlstm_jax.models.xlstm_clean.components.feedforward.FeedForwardConfig#

Bases: xlstm_jax.models.xlstm_clean.utils.UpProjConfigMixin

proj_factor: float = 1.3#
act_fn: str = 'gelu'#
embedding_dim: int = -1#
dropout: float = 0.0#
bias: bool = False#
ff_type: Literal['ffn_gated'] = 'ffn_gated'#
dtype: str = 'bfloat16'#
_num_blocks: int = 1#
property _dtype: jax.numpy.dtype#

Returns the actual jnp dtype rather than the dtype string stored in the config.

Returns:

The jnp dtype corresponding to the string value.

Return type:

jax.numpy.dtype

round_proj_up_dim_up: bool = True#
round_proj_up_to_multiple_of: int = 64#
_proj_up_dim: int = None#
_set_proj_up_dim(embedding_dim)#
Parameters:

embedding_dim (int)

Return type:

None

class xlstm_jax.models.xlstm_clean.components.feedforward.GatedFeedForward#

Bases: flax.linen.Module

config: FeedForwardConfig#

xlstm_jax.models.xlstm_clean.components.feedforward.create_feedforward(config)#
Parameters:

config (FeedForwardConfig)

Return type:

flax.linen.Module