xlstm_jax.trainer.optimizer.scheduler#

Classes#

SchedulerConfig

Configuration for learning rate scheduler.

Functions#

build_lr_scheduler(scheduler_config)

Build learning rate schedule from config.

Module Contents#

class xlstm_jax.trainer.optimizer.scheduler.SchedulerConfig#

Bases: xlstm_jax.configs.ConfigDict

Configuration for learning rate scheduler.

lr#

Initial/peak learning rate of the main scheduler.

Type:

float

name#

Name of the learning rate schedule. The supported schedules are “constant”, “cosine_decay”, “exponential_decay”, and “linear”.

Type:

Literal

decay_steps#

Number of steps for the learning rate schedule, including warmup and cooldown. If not provided, it is defined at runtime in the start script.

Type:

int | None

end_lr#

Final learning rate before the cooldown. This is mutually exclusive with end_lr_factor.

Type:

float | None

end_lr_factor#

Factor by which the initial learning rate is multiplied to obtain the final learning rate before the cooldown. This is mutually exclusive with end_lr.

Type:

float | None

cooldown_steps#

Number of steps for cooldown.

Type:

int

warmup_steps#

Number of steps for warmup.

Type:

int

cooldown_lr#

Final learning rate for cooldown.

Type:

float
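
The relationship between these fields can be sketched in plain Python. This is a minimal illustration, not the library's implementation: the helper names resolve_end_lr and split_steps are hypothetical, and returning None when neither end_lr nor end_lr_factor is set is an assumption, since this reference does not state what happens in that case.

```python
def resolve_end_lr(lr, end_lr=None, end_lr_factor=None):
    """Return the learning rate reached before cooldown, or None if unspecified.

    Hypothetical helper: mirrors the documented rule that end_lr and
    end_lr_factor are mutually exclusive.
    """
    if end_lr is not None and end_lr_factor is not None:
        raise ValueError("end_lr and end_lr_factor are mutually exclusive")
    if end_lr_factor is not None:
        # The factor scales the initial/peak learning rate.
        return lr * end_lr_factor
    # Assumption: pass end_lr through unchanged (None if neither field is set).
    return end_lr


def split_steps(decay_steps, warmup_steps=0, cooldown_steps=0):
    """Split the total step budget into (warmup, main, cooldown) phases.

    decay_steps is documented as including warmup and cooldown, so the
    main phase gets whatever remains.
    """
    main_steps = decay_steps - warmup_steps - cooldown_steps
    if main_steps < 0:
        raise ValueError("warmup_steps + cooldown_steps exceeds decay_steps")
    return warmup_steps, main_steps, cooldown_steps


print(resolve_end_lr(1e-3, end_lr_factor=0.1))
print(split_steps(10_000, warmup_steps=500, cooldown_steps=1_000))
```

With decay_steps=10000, warmup_steps=500, and cooldown_steps=1000, the main schedule named by name would cover the remaining 8500 steps.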

lr: float#
name: str = 'constant'#
decay_steps: int | None = 0#
end_lr: float | None = None#
end_lr_factor: float | None = None#
cooldown_steps: int = 0#
warmup_steps: int = 0#
cooldown_lr: float = 0.0#
get(key, default=None)#
Parameters:

key (str)

to_dict()#

Converts the config to a dictionary.

Helpful for saving to disk or logging.

static from_dict(config_class, data, strict_classname_parsing=False, ignore_extensive_attributes=True, none_to_zero_for_ints=False)#

Utility for parsing dictionaries back into a nested dataclass structure, including arbitrary classes and types.

This is currently tailored to the existing logging system, whose to_dict output is difficult to invert.

Parameters:
  • config_class (type) – Typically a dataclass, but can be any other type. For a non-dataclass type, the parser tries to create an object via config_class(**data) if data is a dictionary, or config_class(data) otherwise.

  • data (Any) – Typically a dictionary that contains attributes of the dataclass. Can be any other kind of data.

  • strict_classname_parsing (bool) – Parse class names strictly.

  • ignore_extensive_attributes (bool) – Ignore attributes that are not defined in the dataclass.

  • none_to_zero_for_ints (bool) – Convert None to 0 for integer types.

Returns:

An object of type config_class that contains the data as attributes.

Return type:

Any
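
The construction rule described for config_class can be illustrated with a small, non-recursive sketch. Everything here is hypothetical: ToyConfig and parse are stand-ins, not part of xlstm_jax, and the real from_dict also handles nested structures and class-name parsing that this sketch omits. Raising an error when ignore_extensive_attributes is False is an assumption.

```python
from dataclasses import dataclass, fields, is_dataclass


@dataclass
class ToyConfig:  # hypothetical stand-in, not part of xlstm_jax
    lr: float
    warmup_steps: int = 0


def parse(config_class, data, ignore_extensive_attributes=True):
    """Simplified, flat version of the documented parsing rule."""
    if is_dataclass(config_class) and isinstance(data, dict):
        known = {f.name for f in fields(config_class)}
        extra = set(data) - known
        if extra and not ignore_extensive_attributes:
            # Assumption: unknown keys are an error when not ignored.
            raise ValueError(f"unknown attributes: {sorted(extra)}")
        # Keep only keys that the dataclass actually defines.
        return config_class(**{k: v for k, v in data.items() if k in known})
    if isinstance(data, dict):
        # Non-dataclass type with dictionary data: config_class(**data).
        return config_class(**data)
    # Any other data: config_class(data).
    return config_class(data)


print(parse(ToyConfig, {"lr": 0.1, "not_a_field": 1}))
print(parse(int, "3"))
```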

xlstm_jax.trainer.optimizer.scheduler.build_lr_scheduler(scheduler_config)#

Build learning rate schedule from config.

By default, it supports constant, linear, cosine decay, and exponential decay schedules, each of which can be combined with warmup and cooldown phases.

Parameters:

scheduler_config (ConfigDict) – ConfigDict for learning rate schedule.

Returns:

Learning rate schedule function.

Return type:

Callable
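
To make the shape of such a schedule concrete, the following is a minimal pure-Python sketch of a cosine decay schedule with warmup and cooldown. It is not the library's implementation. The function name make_schedule is hypothetical, and several details are assumptions: warmup rises linearly from 0 to lr, the cooldown is linear from end_lr to cooldown_lr, and the rate is held constant after decay_steps.

```python
import math


def make_schedule(lr, decay_steps, warmup_steps=0, cooldown_steps=0,
                  end_lr=0.0, cooldown_lr=0.0):
    """Return a function mapping a step index to a learning rate."""
    # decay_steps covers warmup, main phase, and cooldown.
    main_steps = decay_steps - warmup_steps - cooldown_steps
    if main_steps < 0:
        raise ValueError("warmup_steps + cooldown_steps exceeds decay_steps")

    def schedule(step):
        if step < warmup_steps:
            # Assumed linear warmup from 0 up to the peak learning rate.
            return lr * step / warmup_steps
        if step < warmup_steps + main_steps:
            # Cosine decay from lr down to end_lr over the main phase.
            progress = (step - warmup_steps) / main_steps
            return end_lr + 0.5 * (lr - end_lr) * (1.0 + math.cos(math.pi * progress))
        if cooldown_steps > 0 and step < decay_steps:
            # Assumed linear cooldown from end_lr to cooldown_lr.
            progress = (step - warmup_steps - main_steps) / cooldown_steps
            return end_lr + (cooldown_lr - end_lr) * progress
        # Past the end of the schedule: hold the final value.
        return cooldown_lr if cooldown_steps > 0 else end_lr

    return schedule


schedule = make_schedule(1e-3, decay_steps=1000, warmup_steps=100,
                         cooldown_steps=200, end_lr=1e-4, cooldown_lr=0.0)
for step in (0, 50, 100, 800, 900, 1000):
    print(step, schedule(step))
```

This sketch uses Python branching, so it only works with concrete integer steps. A schedule that must run under JAX tracing would need array operations in place of the if statements.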