xlstm_jax.models.xlstm_pytorch.xlstm_lm_model

Classes

xLSTMLMModelConfig

xLSTMLMModel

xLSTM language model: token embedding, xLSTM block stack, and language-model head.

Module Contents

class xlstm_jax.models.xlstm_pytorch.xlstm_lm_model.xLSTMLMModelConfig

Bases: xlstm_jax.models.xlstm_pytorch.xlstm_block_stack.xLSTMBlockStackConfig

vocab_size: int = -1
tie_weights: bool = False
weight_decay_on_embedding: bool = False
add_embedding_dropout: bool = False
mlstm_block: xlstm_jax.models.xlstm_pytorch.blocks.mlstm.block.mLSTMBlockConfig | None = None
slstm_block: xlstm_jax.models.xlstm_pytorch.blocks.slstm.block.sLSTMBlockConfig | None = None
context_length: int = -1
num_blocks: int = 1
embedding_dim: int = 128
add_post_blocks_norm: bool = True
bias: bool = False
dropout: float = 0.0
slstm_at: list[int] | Literal['all'] = []
_block_map: str = None
property block_map: list[int]
Return type:

list[int]

_create_block_map()

Create the block map, which specifies which block type is used at each position in the stack.

Return type:

str
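The block map can be illustrated with a small standalone sketch. It assumes an encoding consistent with the documented types (a `str` returned by `_create_block_map`, a `list[int]` exposed by `block_map`): one comma-separated entry per position in `num_blocks`, with `0` for an mLSTM block and `1` for an sLSTM block at the positions listed in `slstm_at`. The functions `create_block_map` and `parse_block_map` are hypothetical stand-ins, not the library's code.

```python
from typing import Literal, Union


def create_block_map(num_blocks: int, slstm_at: Union[list[int], Literal["all"]]) -> str:
    """Hypothetical stand-in for xLSTMLMModelConfig._create_block_map.

    Assumed encoding: one entry per block position, 0 = mLSTM, 1 = sLSTM.
    """
    if slstm_at == "all":
        # Every position uses an sLSTM block.
        slstm_at = list(range(num_blocks))
    block_map = [0] * num_blocks
    for idx in slstm_at:
        if not 0 <= idx < num_blocks:
            raise ValueError(f"Invalid sLSTM position {idx} for {num_blocks} blocks")
        block_map[idx] = 1
    # Serialize as a comma-separated string, matching the `str` return type.
    return ",".join(map(str, block_map))


def parse_block_map(block_map: str) -> list[int]:
    """Decode the string form into the list[int] shape of the `block_map` property."""
    return [int(x) for x in block_map.split(",")]


print(create_block_map(4, [1, 3]))                   # 0,1,0,1
print(parse_block_map(create_block_map(3, "all")))   # [1, 1, 1]
```

Storing the map as a string keeps the config a flat, easily serializable record, while the property turns it back into a list that is convenient for indexing.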

class xlstm_jax.models.xlstm_pytorch.xlstm_lm_model.xLSTMLMModel(config, **kwargs)

Bases: xlstm_jax.models.xlstm_pytorch.utils.WeightDecayOptimGroupMixin, torch.nn.Module

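xLSTM language model. It embeds input token indices (token_embedding, followed by emb_dropout when add_embedding_dropout is set), processes them with the xLSTM block stack (xlstm_block_stack), and maps the result to vocabulary logits with lm_head.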

Parameters:

config (xLSTMLMModelConfig)

config_class
config
xlstm_block_stack
token_embedding
emb_dropout
lm_head
reset_parameters()
forward(idx)
Parameters:

idx (torch.Tensor)

Return type:

torch.Tensor

step(idx, state=None, **kwargs)
Return type:

tuple[torch.Tensor, dict[str, dict[str, tuple[torch.Tensor, ...]]]]
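The return type of step suggests a recurrent interface: each call returns logits together with a nested state dict that is passed back in on the next call, starting from the default state=None. The sketch below shows that state-threading pattern for greedy decoding. To run without PyTorch or the library, it replaces tensors with plain Python ints and lists; `greedy_decode`, `toy_step`, and the one-token-per-call usage are hypothetical and only illustrate the pattern.

```python
from typing import Callable, Optional

# Hypothetical aliases mirroring the documented return type of `step`:
# logits plus a nested state dict {block_name: {state_name: tuple_of_values}}.
State = dict[str, dict[str, tuple]]
StepFn = Callable[[int, Optional[State]], tuple[list[float], State]]


def greedy_decode(step: StepFn, prompt: list[int], max_new_tokens: int) -> list[int]:
    """Feed one token per call, threading the recurrent state between calls."""
    if not prompt:
        raise ValueError("prompt must contain at least one token")
    state: Optional[State] = None  # first call uses the default empty state
    logits: list[float] = []
    # Prefill: consume the prompt token by token, keeping only the last logits.
    for token in prompt:
        logits, state = step(token, state)
    generated = list(prompt)
    for _ in range(max_new_tokens):
        next_token = max(range(len(logits)), key=logits.__getitem__)  # greedy argmax
        generated.append(next_token)
        logits, state = step(next_token, state)
    return generated


def toy_step(token: int, state: Optional[State]) -> tuple[list[float], State]:
    """Toy model over a 3-token vocabulary that always predicts (token + 1) % 3."""
    count = 0 if state is None else state["block_0"]["count"][0]
    logits = [0.0, 0.0, 0.0]
    logits[(token + 1) % 3] = 1.0
    return logits, {"block_0": {"count": (count + 1,)}}


print(greedy_decode(toy_step, [0], 4))  # [0, 1, 2, 0, 1]
```

Because the state carries everything the model needs from earlier tokens, each call processes only the newest token instead of re-running the whole sequence.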

_create_weight_decay_optim_groups(**kwargs)

Return a tuple of two sequences, one for parameters with weight decay and one for parameters without weight decay.

Default separation:

- weight decay: all parameters with more than one dimension.
- no weight decay: all parameters with one dimension, e.g. biases.

Return type:

tuple[collections.abc.Sequence[torch.nn.Parameter], collections.abc.Sequence[torch.nn.Parameter]]
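The default separation rule can be sketched in plain Python. To keep it runnable without PyTorch, each parameter is represented by a hypothetical `(name, shape)` pair instead of a `torch.nn.Parameter`; with PyTorch you would apply the same test to `p.ndim` while iterating over `model.named_parameters()`. `split_by_ndim` and the parameter names in the example are illustrative, not part of the library.

```python
from collections.abc import Sequence


def split_by_ndim(
    params: Sequence[tuple[str, tuple[int, ...]]],
) -> tuple[list[str], list[str]]:
    """Split parameter names into (weight_decay, no_weight_decay) by dimensionality.

    Parameters with more than one dimension (weight matrices, embeddings) get weight
    decay; everything else (1-D biases and norm scales, any 0-D scalars) does not.
    """
    weight_decay: list[str] = []
    no_weight_decay: list[str] = []
    for name, shape in params:
        if len(shape) > 1:
            weight_decay.append(name)
        else:
            no_weight_decay.append(name)
    return weight_decay, no_weight_decay


decay, no_decay = split_by_ndim([
    ("token_embedding.weight", (50304, 128)),  # hypothetical shapes
    ("proj.weight", (128, 128)),
    ("proj.bias", (128,)),
    ("norm.weight", (128,)),
])
print(decay)     # ['token_embedding.weight', 'proj.weight']
print(no_decay)  # ['proj.bias', 'norm.weight']
```

The two groups are typically handed to an optimizer as separate parameter groups, for example `torch.optim.AdamW([{"params": decay, "weight_decay": 0.1}, {"params": no_decay, "weight_decay": 0.0}])`, where 0.1 is only an example value. Note that under this default rule the 2-D embedding lands in the decay group; the weight_decay_on_embedding config flag presumably lets the model override that.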

get_weight_decay_optim_groups()

Return a tuple of two sequences, one for parameters with weight decay and one for parameters without weight decay.

Checks that each parameter appears in only one of the two sequences.

Return type:

tuple[collections.abc.Sequence[torch.nn.Parameter], collections.abc.Sequence[torch.nn.Parameter]]

get_weight_decay_optim_group_param_names()

Return a tuple of two sequences, one for parameter names with weight decay and one for parameter names without weight decay.

Checks that each parameter appears in only one of the two sequences.

Return type:

tuple[collections.abc.Sequence[str], collections.abc.Sequence[str]]
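The consistency check on the two groups can be sketched over parameter names. The overlap test corresponds to the documented check that each parameter is in only one sequence; the coverage test (every parameter lands in some group) is an additional check this sketch adds and is not stated in the documentation above. `check_optim_groups` is a hypothetical helper.

```python
from collections.abc import Sequence


def check_optim_groups(
    decay: Sequence[str], no_decay: Sequence[str], all_names: Sequence[str]
) -> None:
    """Raise ValueError if the groups overlap or (extra check) miss a parameter."""
    # Documented behavior: no parameter may appear in both groups.
    overlap = set(decay) & set(no_decay)
    if overlap:
        raise ValueError(f"Parameters in both groups: {sorted(overlap)}")
    # Assumed extra check: every parameter is assigned to some group.
    missing = set(all_names) - (set(decay) | set(no_decay))
    if missing:
        raise ValueError(f"Parameters in neither group: {sorted(missing)}")


check_optim_groups(["proj.weight"], ["proj.bias"], ["proj.weight", "proj.bias"])  # passes
```

Comparing names rather than tensor objects makes the error messages readable, which is one reason a name-based variant of the groups is useful.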

static _get_weight_decay_optim_groups_for_modules(modules, **kwargs)
Parameters:

modules (list[WeightDecayOptimGroupMixin])

Return type:

tuple[collections.abc.Sequence[torch.nn.Parameter], collections.abc.Sequence[torch.nn.Parameter]]
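Given its signature, _get_weight_decay_optim_groups_for_modules plausibly collects the two groups from each submodule and concatenates them; that behavior is an assumption, since the entry above documents only the types. The sketch below shows such a merge using names instead of parameters so it runs without PyTorch. `HasOptimGroups` and `merge_optim_groups` are hypothetical.

```python
from collections.abc import Sequence
from typing import Protocol


class HasOptimGroups(Protocol):
    """Anything exposing the (decay, no_decay) pair, like WeightDecayOptimGroupMixin."""

    def get_weight_decay_optim_groups(self) -> tuple[Sequence[str], Sequence[str]]: ...


def merge_optim_groups(
    modules: Sequence[HasOptimGroups],
) -> tuple[tuple[str, ...], tuple[str, ...]]:
    """Concatenate each module's groups, preserving module order (assumed behavior)."""
    decay: tuple[str, ...] = ()
    no_decay: tuple[str, ...] = ()
    for module in modules:
        wd, nwd = module.get_weight_decay_optim_groups()
        decay += tuple(wd)
        no_decay += tuple(nwd)
    return decay, no_decay
```

Delegating to each submodule lets every block decide its own split, while the parent only has to combine the results.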