xlstm_jax.models.xlstm_pytorch.utils

Classes

UpProjConfigMixin

WeightDecayOptimGroupMixin

Mixin for torch.nn.Module subclasses that splits their parameters into weight-decay and no-weight-decay optimizer groups.

Module Contents

class xlstm_jax.models.xlstm_pytorch.utils.UpProjConfigMixin
proj_factor: float = None
round_proj_up_dim_up: bool = True
round_proj_up_to_multiple_of: int = 64
_proj_up_dim: int = None
_set_proj_up_dim(embedding_dim)
Parameters:

embedding_dim (int)

Return type:

None
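The following is a minimal sketch of how _proj_up_dim might be derived. It assumes the value is proj_factor * embedding_dim, rounded to a multiple of round_proj_up_to_multiple_of: upward when round_proj_up_dim_up is True and downward otherwise. The rounding rule is inferred from the field names and is not confirmed by this page. UpProjConfigSketch is a hypothetical stand-in, not the real class.

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass
class UpProjConfigSketch:
    """Hypothetical stand-in for UpProjConfigMixin; field names follow the docs."""

    proj_factor: Optional[float] = None
    round_proj_up_dim_up: bool = True
    round_proj_up_to_multiple_of: int = 64
    _proj_up_dim: Optional[int] = None

    def _set_proj_up_dim(self, embedding_dim: int) -> None:
        # Assumed rule: scale embedding_dim by proj_factor, then snap the result
        # to a multiple of round_proj_up_to_multiple_of (ceil or floor).
        if self.proj_factor is None or embedding_dim is None:
            return  # nothing to compute; _proj_up_dim stays None
        multiples = self.proj_factor * embedding_dim / self.round_proj_up_to_multiple_of
        if self.round_proj_up_dim_up:
            multiples = math.ceil(multiples)
        else:
            multiples = math.floor(multiples)
        self._proj_up_dim = int(multiples * self.round_proj_up_to_multiple_of)


cfg = UpProjConfigSketch(proj_factor=1.3)
cfg._set_proj_up_dim(128)  # 1.3 * 128 = 166.4, rounded up to a multiple of 64
```

Rounding to a multiple of 64 keeps the projected width hardware-friendly. With round_proj_up_dim_up=False, the same inputs would round down instead.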

class xlstm_jax.models.xlstm_pytorch.utils.WeightDecayOptimGroupMixin

Bases: torch.nn.Module, abc.ABC

Mixin for torch.nn.Module subclasses that splits their parameters into two optimizer groups: parameters that receive weight decay and parameters that do not.

get_weight_decay_optim_groups()

Return a tuple of two sequences, one for parameters with weight decay and one for parameters without weight decay.

Performs checks to ensure that each parameter is assigned to exactly one of the two sequences.

Return type:

tuple[collections.abc.Sequence[torch.nn.Parameter], collections.abc.Sequence[torch.nn.Parameter]]
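The following is a minimal sketch of the kind of consistency check described above. It assumes that "exactly one" means two things: no parameter is in both groups, and no parameter of the module is left out. check_optim_groups is a hypothetical helper, not part of this module. Raising ValueError is this sketch's choice; the real method's failure behavior is not documented here. Requires PyTorch.

```python
from torch import nn


def check_optim_groups(module: nn.Module, decay, no_decay) -> None:
    """Hypothetical helper: verify every parameter of `module` is in exactly one group."""
    # Track parameters by identity; names are only used for error messages.
    decay_ids = {id(p) for p in decay}
    no_decay_ids = {id(p) for p in no_decay}
    names = {id(p): n for n, p in module.named_parameters()}

    # A parameter must not be both decayed and not decayed.
    overlap = decay_ids & no_decay_ids
    if overlap:
        raise ValueError(f"in both groups: {sorted(names.get(i, '?') for i in overlap)}")

    # Every parameter of the module must land in one of the groups.
    missing = set(names) - (decay_ids | no_decay_ids)
    if missing:
        raise ValueError(f"in neither group: {sorted(names[i] for i in missing)}")


layer = nn.Linear(4, 2)
check_optim_groups(layer, [layer.weight], [layer.bias])  # passes silently
```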

get_weight_decay_optim_group_param_names()

Return a tuple of two sequences, one for parameter names with weight decay and one for parameter names without weight decay.

Performs checks to ensure that each parameter is assigned to exactly one of the two sequences.

Return type:

tuple[collections.abc.Sequence[str], collections.abc.Sequence[str]]
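The following sketch maps two parameter groups back to their qualified names via named_parameters(). The hypothetical helper param_names_for_groups compares parameters by identity (id()). A plain `p in seq` test can fall back to Tensor.__eq__, which compares element-wise instead of returning a single bool. Whether the real method matches parameters the same way is an assumption. Requires PyTorch.

```python
from torch import nn


def param_names_for_groups(module: nn.Module, decay, no_decay):
    """Hypothetical helper: return the parameter names belonging to each group."""
    # Identity sets avoid element-wise Tensor comparisons.
    decay_ids = {id(p) for p in decay}
    no_decay_ids = {id(p) for p in no_decay}
    decay_names = [n for n, p in module.named_parameters() if id(p) in decay_ids]
    no_decay_names = [n for n, p in module.named_parameters() if id(p) in no_decay_ids]
    return decay_names, no_decay_names


net = nn.Sequential(nn.Linear(4, 3), nn.LayerNorm(3))
names = param_names_for_groups(
    net, [net[0].weight], [net[0].bias, net[1].weight, net[1].bias]
)
```

Names are convenient for logging which parameters are decayed, or for comparing groupings across checkpoints.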

_create_weight_decay_optim_groups()

Return a tuple of two sequences, one for parameters with weight decay and one for parameters without weight decay.

Default separation:

- weight decay: all parameters with more than one dimension.
- no weight decay: all parameters with exactly one dimension, e.g. biases.

Return type:

tuple[collections.abc.Sequence[torch.nn.Parameter], collections.abc.Sequence[torch.nn.Parameter]]
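The following is a sketch of the documented default separation; split_by_ndim is a hypothetical name. Two choices here are assumptions, because the text above only covers parameters with one or more dimensions: 0-dimensional (scalar) parameters go to the no-decay group, and frozen parameters (requires_grad=False) are not filtered out. Requires PyTorch.

```python
from torch import nn


def split_by_ndim(module: nn.Module):
    """Sketch of the default rule: ndim > 1 -> weight decay, otherwise no weight decay."""
    decay, no_decay = [], []
    for _, param in module.named_parameters():
        if param.ndim > 1:
            decay.append(param)  # e.g. Linear weight matrices
        else:
            no_decay.append(param)  # e.g. biases and LayerNorm scales
    return tuple(decay), tuple(no_decay)


net = nn.Sequential(nn.Linear(4, 3), nn.LayerNorm(3))
decay, no_decay = split_by_ndim(net)
```

A subclass that needs a different split, for example excluding embedding tables from weight decay, would presumably override _create_weight_decay_optim_groups() and return its own pair of sequences.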

static _get_weight_decay_optim_groups_for_modules(modules, **kwargs)
Parameters:

modules (list[WeightDecayOptimGroupMixin])

Return type:

tuple[collections.abc.Sequence[torch.nn.Parameter], collections.abc.Sequence[torch.nn.Parameter]]
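The following sketch shows how per-module groups might be combined and passed to an optimizer. It assumes _get_weight_decay_optim_groups_for_modules concatenates the results of each module's get_weight_decay_optim_groups(). To keep the example self-contained, the hypothetical split_by_ndim stands in for that per-module method, and the hypothetical groups_for_modules stands in for the static method. The weight-decay value 0.1 is arbitrary. Requires PyTorch.

```python
import torch
from torch import nn


def split_by_ndim(module: nn.Module):
    # Stand-in for get_weight_decay_optim_groups(): >1-D params decay, the rest do not.
    decay = tuple(p for p in module.parameters() if p.ndim > 1)
    no_decay = tuple(p for p in module.parameters() if p.ndim <= 1)
    return decay, no_decay


def groups_for_modules(modules):
    """Hypothetical analogue of the static method: concatenate per-module groups."""
    decay, no_decay = (), ()
    for module in modules:
        wd, nwd = split_by_ndim(module)
        decay += wd
        no_decay += nwd
    return decay, no_decay


blocks = [nn.Linear(8, 8), nn.LayerNorm(8)]
decay, no_decay = groups_for_modules(blocks)
optimizer = torch.optim.AdamW(
    [
        {"params": list(decay), "weight_decay": 0.1},   # matrices get weight decay
        {"params": list(no_decay), "weight_decay": 0.0},  # biases/norm scales do not
    ],
    lr=1e-3,
)
```

Plain concatenation does not deduplicate. If two modules share a parameter, it would be listed twice, so deduplicate by identity first in that case.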