xlstm_jax.models.xlstm_pytorch.xlstm_lm_model
Classes
- xLSTMLMModelConfig
- xLSTMLMModel: Helper class that provides a standard way to create an ABC using inheritance.
Module Contents
- class xlstm_jax.models.xlstm_pytorch.xlstm_lm_model.xLSTMLMModelConfig
Bases:
xlstm_jax.models.xlstm_pytorch.xlstm_block_stack.xLSTMBlockStackConfig
- mlstm_block: xlstm_jax.models.xlstm_pytorch.blocks.mlstm.block.mLSTMBlockConfig | None = None
- slstm_block: xlstm_jax.models.xlstm_pytorch.blocks.slstm.block.sLSTMBlockConfig | None = None
- class xlstm_jax.models.xlstm_pytorch.xlstm_lm_model.xLSTMLMModel(config, **kwargs)
Bases:
xlstm_jax.models.xlstm_pytorch.utils.WeightDecayOptimGroupMixin, torch.nn.Module
Helper class that provides a standard way to create an ABC using inheritance.
- Parameters:
config (xLSTMLMModelConfig)
- config_class
- config
- xlstm_block_stack
- token_embedding
- emb_dropout
- lm_head
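The attribute names listed above (`token_embedding`, `emb_dropout`, `xlstm_block_stack`, `lm_head`) suggest the usual language-model layout. Token ids are embedded, dropout is applied, the block stack runs, and a final projection produces vocabulary logits. The sketch below is a hypothetical stand-in, not the library's implementation. `ToyLMSketch`, its sizes, the bias-free head, and the `nn.Identity` placeholder for the xLSTM block stack are all assumptions made for illustration.

```python
import torch
from torch import nn


class ToyLMSketch(nn.Module):
    """Hypothetical stand-in that mirrors the attribute names of xLSTMLMModel."""

    def __init__(self, vocab_size: int = 32, embedding_dim: int = 16, dropout: float = 0.0):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.emb_dropout = nn.Dropout(dropout)
        # Placeholder (assumption): the real model uses an xLSTM block stack here.
        self.xlstm_block_stack = nn.Identity()
        self.lm_head = nn.Linear(embedding_dim, vocab_size, bias=False)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        # idx: (batch, seq_len) integer token ids
        x = self.token_embedding(idx)  # (batch, seq_len, embedding_dim)
        x = self.emb_dropout(x)
        x = self.xlstm_block_stack(x)  # assumed to preserve the shape
        return self.lm_head(x)  # (batch, seq_len, vocab_size) logits


model = ToyLMSketch()
idx = torch.randint(0, 32, (2, 5))
logits = model(idx)
print(logits.shape)  # torch.Size([2, 5, 32])
```

Replacing the placeholder with a real sequence-mixing stack leaves the rest of the pipeline unchanged, provided the stack keeps the `(batch, seq_len, embedding_dim)` shape.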
- reset_parameters()
- forward(idx)
- Parameters:
idx (torch.Tensor)
- Return type:
- step(idx, state=None, **kwargs)
- Parameters:
idx (torch.Tensor)
- Return type:
tuple[torch.Tensor, dict[str, dict[str, tuple[torch.Tensor, ...]]]]
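`step` accepts a recurrent `state`. Its return type shows that it returns logits together with an updated nested state dictionary, which is what token-by-token generation needs. The greedy-decoding loop below is a sketch that rests on several assumptions. It assumes `step` is called with `(batch, 1)` token ids and returns logits shaped `(batch, 1, vocab_size)`. `ToyStepModel` is a hypothetical GRU-based stand-in, included only so the example runs on its own.

```python
import torch
from torch import nn


class ToyStepModel(nn.Module):
    """Hypothetical recurrent LM exposing step(idx, state) -> (logits, state)."""

    def __init__(self, vocab_size: int = 10, dim: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.cell = nn.GRUCell(dim, dim)
        self.head = nn.Linear(dim, vocab_size)

    def step(self, idx, state=None, **kwargs):
        # idx: (batch, 1) token ids for a single time step
        x = self.embed(idx[:, 0])  # (batch, dim)
        h = None if state is None else state["block_0"]["gru"][0]
        h = self.cell(x, h)  # GRUCell treats hx=None as zeros
        logits = self.head(h).unsqueeze(1)  # (batch, 1, vocab_size)
        # Nested state shaped like dict[str, dict[str, tuple[Tensor, ...]]]
        return logits, {"block_0": {"gru": (h,)}}


@torch.no_grad()
def greedy_generate(model, prompt: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    assert prompt.shape[1] >= 1, "prompt must contain at least one token"
    state, logits = None, None
    # Feed the prompt one token at a time so the recurrent state absorbs it.
    for t in range(prompt.shape[1]):
        logits, state = model.step(prompt[:, t : t + 1], state=state)
    generated = [prompt]
    for _ in range(max_new_tokens):
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (batch, 1)
        generated.append(next_token)
        logits, state = model.step(next_token, state=state)
    return torch.cat(generated, dim=1)


torch.manual_seed(0)
toy = ToyStepModel()
prompt = torch.tensor([[1, 2, 3]])
out = greedy_generate(toy, prompt, max_new_tokens=4)
print(out.shape)  # torch.Size([1, 7])
```

Each call costs one recurrent update, so the loop never reprocesses the prefix. This constant per-token cost is the usual reason to expose a `step` method alongside `forward`.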
- _create_weight_decay_optim_groups(**kwargs)
Return a tuple of two sequences, one for parameters with weight decay and one for parameters without weight decay.
Default separation:
  - weight decay: all parameters with more than one dimension.
  - no weight decay: all parameters with exactly one dimension, e.g. biases.
- Return type:
tuple[collections.abc.Sequence[torch.nn.Parameter], collections.abc.Sequence[torch.nn.Parameter]]
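The default rule above can be written as a single pass over a module's parameters. This is a minimal sketch of that rule, not the mixin's actual code. The helper name `split_by_ndim` is hypothetical. Routing 0-dimensional parameters to the no-decay group is also an assumption, because the stated rule covers only one-dimensional parameters.

```python
import torch
from torch import nn


def split_by_ndim(module: nn.Module):
    """Sketch: >1-dim params get weight decay, the rest do not."""
    decay, no_decay = [], []
    for param in module.parameters():
        if param.dim() > 1:
            decay.append(param)  # e.g. weight matrices, embedding tables
        else:
            # Biases and norm scales; 0-dim scalars also land here (assumption).
            no_decay.append(param)
    return decay, no_decay


net = nn.Sequential(nn.Linear(4, 8), nn.LayerNorm(8), nn.Linear(8, 2))
decay, no_decay = split_by_ndim(net)
print(len(decay), len(no_decay))  # 2 4
```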
- get_weight_decay_optim_groups()
Return a tuple of two sequences, one for parameters with weight decay and one for parameters without weight decay.
Performs checks to ensure that each parameter is only in one of the two sequences.
- Return type:
tuple[collections.abc.Sequence[torch.nn.Parameter], collections.abc.Sequence[torch.nn.Parameter]]
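The two sequences map directly onto PyTorch optimizer parameter groups. The sketch below checks that the groups partition a module's parameters, in the spirit of the checks described above. It then builds `torch.optim.AdamW` so that decay applies to only one group. `check_partition`, the decay value `0.1`, and the learning rate are illustrative assumptions.

```python
import torch
from torch import nn


def check_partition(module: nn.Module, decay, no_decay) -> None:
    """Hypothetical check: every parameter is in exactly one of the two groups."""
    decay_ids = {id(p) for p in decay}
    no_decay_ids = {id(p) for p in no_decay}
    all_ids = {id(p) for p in module.parameters()}
    if decay_ids & no_decay_ids:
        raise ValueError("a parameter appears in both groups")
    if decay_ids | no_decay_ids != all_ids:
        raise ValueError("groups do not cover exactly the module's parameters")


net = nn.Sequential(nn.Linear(4, 8), nn.LayerNorm(8), nn.Linear(8, 2))
decay = [p for p in net.parameters() if p.dim() > 1]
no_decay = [p for p in net.parameters() if p.dim() <= 1]
check_partition(net, decay, no_decay)

# Per-group options override the optimizer defaults.
optimizer = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.1},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=1e-3,
)
```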
- get_weight_decay_optim_group_param_names()
Return a tuple of two sequences, one for parameter names with weight decay and one for parameter names without weight decay.
Performs checks to ensure that each parameter is only in one of the two sequences.
- Return type:
tuple[collections.abc.Sequence[str], collections.abc.Sequence[str]]
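A name-based variant makes it easy to log or inspect which parameters escape weight decay. This hypothetical helper applies the same dimensionality rule to `named_parameters()`. It is an illustration, not the library's implementation.

```python
from torch import nn


def split_names_by_ndim(module: nn.Module):
    """Hypothetical helper: parameter names split by the >1-dim rule."""
    decay_names, no_decay_names = [], []
    for name, param in module.named_parameters():
        (decay_names if param.dim() > 1 else no_decay_names).append(name)
    return decay_names, no_decay_names


net = nn.Sequential(nn.Linear(4, 8), nn.LayerNorm(8), nn.Linear(8, 2))
decay_names, no_decay_names = split_names_by_ndim(net)
print(decay_names)     # ['0.weight', '2.weight']
print(no_decay_names)  # ['0.bias', '1.weight', '1.bias', '2.bias']
```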
- static _get_weight_decay_optim_groups_for_modules(modules, **kwargs)
- Parameters:
modules (list[WeightDecayOptimGroupMixin])
- Return type:
tuple[collections.abc.Sequence[torch.nn.Parameter], collections.abc.Sequence[torch.nn.Parameter]]
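The static helper takes a list of modules and returns combined groups. The following is a plausible sketch under two assumptions. First, each module is split with the same dimensionality rule. Second, a parameter shared between modules (for example, tied embedding and head weights) is counted only once. `groups_for_modules` and the de-duplication by object identity are assumptions, not documented behavior.

```python
from torch import nn


def split_by_ndim(module: nn.Module):
    decay, no_decay = [], []
    for p in module.parameters():
        (decay if p.dim() > 1 else no_decay).append(p)
    return decay, no_decay


def groups_for_modules(modules):
    """Concatenate per-module groups, skipping parameters already seen (assumption)."""
    decay, no_decay, seen = [], [], set()
    for module in modules:
        module_decay, module_no_decay = split_by_ndim(module)
        for src, dst in ((module_decay, decay), (module_no_decay, no_decay)):
            for p in src:
                if id(p) not in seen:
                    seen.add(id(p))
                    dst.append(p)
    return decay, no_decay


embedding = nn.Embedding(10, 4)
head = nn.Linear(4, 10, bias=False)
head.weight = embedding.weight  # tie weights: both modules share one tensor
decay, no_decay = groups_for_modules([embedding, head])
print(len(decay), len(no_decay))  # 1 0
```

Without the identity check, the tied weight would appear twice in the decay group. PyTorch optimizers reject a parameter that appears in more than one group.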