xlstm_jax.models.xlstm_pytorch.components.feedforward

Attributes

_act_fn_registry

Classes

FeedForwardConfig

GatedFeedForward(config)

Functions

get_act_fn(act_fn_name)

create_feedforward(config)

Module Contents

xlstm_jax.models.xlstm_pytorch.components.feedforward._act_fn_registry
xlstm_jax.models.xlstm_pytorch.components.feedforward.get_act_fn(act_fn_name)
Parameters:

act_fn_name (str)

Return type:

collections.abc.Callable[[torch.Tensor], torch.Tensor]

class xlstm_jax.models.xlstm_pytorch.components.feedforward.FeedForwardConfig

Bases: xlstm_jax.models.xlstm_pytorch.utils.UpProjConfigMixin

proj_factor: float = 1.3
act_fn: str = 'gelu'
embedding_dim: int = -1
dropout: float = 0.0
bias: bool = False
ff_type: Literal['ffn_gated'] = 'ffn_gated'
_num_blocks: int = 1
round_proj_up_dim_up: bool = True
round_proj_up_to_multiple_of: int = 64
_proj_up_dim: int = None
_set_proj_up_dim(embedding_dim)
Parameters:

embedding_dim (int)

Return type:

None

class xlstm_jax.models.xlstm_pytorch.components.feedforward.GatedFeedForward(config)

Bases: torch.nn.Module

Parameters:

config (FeedForwardConfig)

config_class
config
proj_up
proj_down
act_fn
dropout
forward(x)
Parameters:

x (torch.Tensor)

Return type:

torch.Tensor

reset_parameters()
xlstm_jax.models.xlstm_pytorch.components.feedforward.create_feedforward(config)
Parameters:

config (FeedForwardConfig)

Return type:

torch.nn.Module