xlstm_jax.models.xlstm_parallel.components.feedforward#

Attributes#

_act_fn_registry

Classes#

FeedForwardConfig

GatedFeedForward

FeedForward

Functions#

get_act_fn(act_fn_name)

create_feedforward(config[, name])

Module Contents#

xlstm_jax.models.xlstm_parallel.components.feedforward._act_fn_registry#
xlstm_jax.models.xlstm_parallel.components.feedforward.get_act_fn(act_fn_name)#
Parameters:

act_fn_name (str)

Return type:

collections.abc.Callable[[jax.Array], jax.Array]

class xlstm_jax.models.xlstm_parallel.components.feedforward.FeedForwardConfig#

Bases: xlstm_jax.models.xlstm_parallel.utils.UpProjConfigMixin

Sub-model configuration.

This class is currently a quick fix to allow for post-init style model configs, like the xlstm-clean we ported from the original xlstm codebase. Once the config system is more mature, we should remove this class and make every config a subclass of ModelConfig.

proj_factor: float = 1.3#
act_fn: str = 'gelu'#
embedding_dim: int = -1#
dropout: float = 0.0#
bias: bool = False#
init_distribution: xlstm_jax.models.shared.InitDistribution = 'normal'#

Distribution type from which to sample the weights.

output_init_fn: xlstm_jax.models.shared.InitFnName = 'wang'#

Initialization function for the output projection layer.

ff_type: Literal['ffn_gated', 'ffn'] = 'ffn_gated'#
dtype: str = 'bfloat16'#
parallel: xlstm_jax.models.configs.ParallelConfig | None = None#
_num_blocks: int = 1#
property _dtype: jax.numpy.dtype#

Returns the actual dtype object instead of the string stored in the config.

Returns:

The jnp dtype corresponding to the string value.

Return type:

jax.numpy.dtype

round_proj_up_dim_up: bool = True#
round_proj_up_to_multiple_of: int = 64#
_proj_up_dim: int | None = None#
_set_proj_up_dim(embedding_dim)#
Parameters:

embedding_dim (int)

Return type:

None

to_dict()#

Converts the config to a dictionary.

Helpful for saving to disk or logging.

Return type:

dict

class xlstm_jax.models.xlstm_parallel.components.feedforward.GatedFeedForward#

Bases: flax.linen.Module

config: FeedForwardConfig#
class xlstm_jax.models.xlstm_parallel.components.feedforward.FeedForward#

Bases: flax.linen.Module

config: FeedForwardConfig#
xlstm_jax.models.xlstm_parallel.components.feedforward.create_feedforward(config, name='ffn')#
Parameters:

config (FeedForwardConfig)

name (str)

Return type:

flax.linen.Module