xlstm_jax.models.xlstm_pytorch.blocks.mlstm.layer#

Classes#

mLSTMLayerConfig

mLSTMLayer

Module Contents#

class xlstm_jax.models.xlstm_pytorch.blocks.mlstm.layer.mLSTMLayerConfig#

Bases: xlstm_jax.models.xlstm_pytorch.utils.UpProjConfigMixin

conv1d_kernel_size: int = 4#
qkv_proj_blocksize: int = 4#
num_heads: int = 4#
proj_factor: float = 2.0#
embedding_dim: int = -1#
bias: bool = False#
dropout: float = 0.0#
context_length: int = -1#
_num_blocks: int = 1#
_inner_embedding_dim: int = None#
mlstm_cell: xlstm_jax.models.xlstm_pytorch.blocks.mlstm.cell.mLSTMCellConfig#
round_proj_up_dim_up: bool = True#
round_proj_up_to_multiple_of: int = 64#
_proj_up_dim: int = None#
_set_proj_up_dim(embedding_dim)#
Parameters:

embedding_dim (int)

Return type:

None

class xlstm_jax.models.xlstm_pytorch.blocks.mlstm.layer.mLSTMLayer(config)#

Bases: torch.nn.Module

Parameters:

config (mLSTMLayerConfig)

config_class#
config#
proj_up#
q_proj#
k_proj#
v_proj#
conv1d#
conv_act_fn#
mlstm_cell#
ogate_act_fn#
learnable_skip#
proj_down#
dropout#
forward(x)#
Parameters:

x (torch.Tensor)

Return type:

torch.Tensor

step(x, mlstm_state=None, conv_state=None)#
Parameters:

x

mlstm_state (default: None)

conv_state (default: None)

Return type:

tuple[torch.Tensor, dict[str, tuple[torch.Tensor, ...]]]

reset_parameters()#