xlstm_jax.trainer.optimizer.ademamix#

Adapted from Apple’s official AdEMAMix implementation: apple/ml-ademamix

TODO: Check license and add it here.

Classes#

ScaleByAdemamixState

State for the AdEMAMix algorithm.

Functions#

alpha_scheduler(alpha[, alpha_start, warmup])

Linear scheduler for the mixing coefficient alpha in AdEMAMix.

beta3_scheduler(beta_end[, beta_start, warmup])

Linear scheduler for the EMA parameter beta3 in AdEMAMix.

ademamix(lr[, b1, b2, b3, alpha, b3_scheduler, ...])

AdEMAMix.

scale_by_ademamix(b1, b2, b3, alpha, b3_scheduler, ...)

Scales updates by the AdEMAMix algorithm.

tree_cast(tree, dtype)

Cast tree to given dtype, skip if None.

tree_zeros_like(tree[, dtype])

Creates an all-zeros tree with the same structure.

tree_update_moment(updates, moments, decay, order)

Compute the exponential moving average of the order-th moment.

tree_update_moment_per_elem_norm(updates, moments, ...)

Compute the EMA of the order-th moment of the element-wise norm.

tree_bias_correction(moment, decay, count)

Performs bias correction. It becomes a no-op as count goes to infinity.

Module Contents#

xlstm_jax.trainer.optimizer.ademamix.alpha_scheduler(alpha, alpha_start=0.0, warmup=0)#

Linear scheduler for the mixing coefficient alpha in AdEMAMix.

Parameters:
  • alpha (float) – Final value of alpha.

  • alpha_start (float) – Initial value of alpha.

  • warmup (int) – Number of steps for the warmup phase. Often set equal to the number of training steps.

Returns:

A scheduler function that takes a step and returns the value of alpha.

Return type:

optax.Schedule
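
As a rough illustration of a linear warmup for alpha, the dependency-free sketch below interpolates from alpha_start to alpha over warmup steps and then holds alpha constant. The name `linear_alpha_schedule` and the behaviour after the warmup phase are assumptions made for illustration; the module's own function works on JAX values and may differ in detail.

```python
def linear_alpha_schedule(alpha, alpha_start=0.0, warmup=0):
    """Hypothetical pure-Python stand-in for a linear alpha warmup."""

    def schedule(step):
        # After the warmup phase (or with no warmup), return the final value.
        if warmup <= 0 or step >= warmup:
            return alpha
        # Fraction of the warmup completed, in [0, 1).
        frac = step / warmup
        return (1.0 - frac) * alpha_start + frac * alpha

    return schedule


schedule = linear_alpha_schedule(alpha=5.0, alpha_start=0.0, warmup=1000)
values = [schedule(step) for step in (0, 250, 500, 1000, 2000)]
print(values)  # [0.0, 1.25, 2.5, 5.0, 5.0]
```

With warmup set to the number of training steps, alpha only reaches its final value on the last step.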

xlstm_jax.trainer.optimizer.ademamix.beta3_scheduler(beta_end, beta_start=0.0, warmup=0)#

Linear scheduler for the EMA parameter beta3 in AdEMAMix.

Parameters:
  • beta_end (float) – Final value of beta3.

  • beta_start (float) – Initial value of beta3. Often set equal to beta1.

  • warmup (int) – Number of steps for the warmup phase. Often set equal to the number of training steps.

Returns:

A scheduler function that takes a step and returns the value of beta3.

Return type:

optax.Schedule
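
The AdEMAMix paper describes a beta3 warmup that is linear in the half-life of the EMA rather than in beta3 itself. The sketch below assumes that formulation. It is not taken from this module, all helper names are hypothetical, and it requires beta_start strictly between 0 and 1.

```python
import math


def half_life(beta):
    # Number of steps t such that beta ** (t + 1) == 0.5.
    return math.log(0.5) / math.log(beta) - 1.0


def beta_from_half_life(t):
    # Inverse of half_life: the beta whose half-life is t steps.
    return 0.5 ** (1.0 / (t + 1.0))


def half_life_beta3_schedule(beta_end, beta_start, warmup):
    """Hypothetical warmup that interpolates linearly in half-life space."""

    def schedule(step):
        if warmup <= 0 or step >= warmup:
            return beta_end
        frac = step / warmup
        t = (1.0 - frac) * half_life(beta_start) + frac * half_life(beta_end)
        return beta_from_half_life(t)

    return schedule


schedule = half_life_beta3_schedule(beta_end=0.9999, beta_start=0.9, warmup=1000)
start, middle, end = schedule(0), schedule(500), schedule(1000)
print(start, middle, end)
```

Under this assumption the value halfway through the warmup is roughly 0.9998, much closer to beta_end than the midpoint of beta_start and beta_end would be.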

class xlstm_jax.trainer.optimizer.ademamix.ScaleByAdemamixState#

Bases: NamedTuple

State for the AdEMAMix algorithm.

count: chex.Array#

Step counter for the first momentum and adaptive learning rate.

count_m2: chex.Array#

Step counter for the slower momentum.

m1: optax._src.base.Updates#

Fast EMA.

m2: optax._src.base.Updates#

Slow EMA.

nu: optax._src.base.Updates#

Second moment estimate.
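
To make the layout of the state concrete, here is a minimal sketch of a tuple with the same five fields. `AdemamixStateSketch` and `init_state` are hypothetical names, plain Python floats stand in for arrays, and the zero initialisation is an assumption based on the usual convention for EMA accumulators.

```python
from typing import Any, NamedTuple


class AdemamixStateSketch(NamedTuple):
    """Hypothetical stand-in mirroring the documented fields."""

    count: int     # step counter for m1 and nu
    count_m2: int  # step counter for the slow EMA m2
    m1: Any        # fast EMA, same structure as the parameters
    m2: Any        # slow EMA, same structure as the parameters
    nu: Any        # second moment estimate, same structure as the parameters


def init_state(params):
    # Assumption: params is a flat dict of floats; each accumulator starts at zero.
    def zeros():
        return {name: 0.0 for name in params}

    return AdemamixStateSketch(count=0, count_m2=0, m1=zeros(), m2=zeros(), nu=zeros())


state = init_state({"w": 0.5, "b": -1.0})
print(state)
```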

xlstm_jax.trainer.optimizer.ademamix.ademamix(lr, b1=0.9, b2=0.999, b3=0.9999, alpha=5.0, b3_scheduler=None, alpha_scheduler=None, eps=1e-08, eps_root=0.0, weight_decay=0.0, mu_dtype=None, mask=None)#

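AdEMAMix optimizer: an Adam-style update that mixes a fast and a slow EMA of the gradients, with optional weight decay.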

Parameters:
  • lr (float | optax.Schedule) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • b1 (float) – Exponential decay rate to track the fast EMA.

  • b2 (float) – Exponential decay rate to track the second moment of past gradients.

  • b3 (float) – Exponential decay rate to track the slow EMA.

  • alpha (float) – Mixing coefficient used for the linear combination of the fast and slow EMAs.

  • b3_scheduler (optax.Schedule | None) – An optional scheduler function that, given a timestep, returns the value of b3. Use beta3_scheduler(b3, b1, T_b3), where T_b3 is the warmup length, to follow the AdEMAMix paper.

  • alpha_scheduler (optax.Schedule | None) – An optional scheduler function that, given a timestep, returns the value of alpha. Use alpha_scheduler(alpha, 0, T_alpha), where T_alpha is the warmup length, to follow the AdEMAMix paper.

  • eps (float) – A small constant applied to denominator outside the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root (float) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.

  • mu_dtype (jax.numpy.dtype | None) – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.

  • weight_decay (float) – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al., 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.

  • mask (collections.abc.Callable[[xlstm_jax.common_types.PyTree], xlstm_jax.common_types.PyTree] | None) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adam gradient transformations are applied to all parameters.

Returns:

The corresponding GradientTransformation.

Return type:

optax.GradientTransformation
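
The note on weight_decay can be made concrete with a small arithmetic sketch. It assumes the decay term is added to the AdEMAMix-scaled update before the sum is multiplied by the learning rate, as in AdamW-style compositions. The function name and the numbers are illustrative only.

```python
def apply_step(param, scaled_update, lr, weight_decay):
    # The decay term is scaled by lr together with the update,
    # so the effective per-step decay is lr * weight_decay.
    return param - lr * (scaled_update + weight_decay * param)


new_param = apply_step(param=2.0, scaled_update=0.5, lr=0.1, weight_decay=0.01)
print(new_param)  # approximately 1.948
```

Because of this coupling, halving the learning rate also halves the effective weight decay per step.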

xlstm_jax.trainer.optimizer.ademamix.scale_by_ademamix(b1, b2, b3, alpha, b3_scheduler, alpha_scheduler, eps=1e-08, eps_root=0.0, mu_dtype=None)#

Scales updates by the AdEMAMix algorithm.

Parameters:
  • b1 (float) – Exponential decay rate to track the fast EMA.

  • b2 (float) – Exponential decay rate to track the second moment of past gradients.

  • b3 (float) – Exponential decay rate to track the slow EMA.

  • alpha (float) – Mixing coefficient used for the linear combination of the fast and slow EMAs.

  • b3_scheduler (optax.Schedule | None) – An optional scheduler function that, given a timestep, returns the value of b3. Use beta3_scheduler(b3, b1, T_b3), where T_b3 is the warmup length, to follow the AdEMAMix paper.

  • alpha_scheduler (optax.Schedule | None) – An optional scheduler function that, given a timestep, returns the value of alpha. Use alpha_scheduler(alpha, 0, T_alpha), where T_alpha is the warmup length, to follow the AdEMAMix paper.

  • eps (float) – A small constant applied to denominator outside the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root (float) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.

  • mu_dtype (jax.numpy.dtype | None) – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.

Returns:

The corresponding GradientTransformation.

Return type:

optax.GradientTransformation
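
How the parameters above interact is easiest to see on a single scalar. The sketch below follows the update rule from the AdEMAMix paper with fixed b3 and alpha (no schedulers) and a single step counter. It is a simplified illustration with a hypothetical function name, not this module's code, which operates on pytrees and keeps a separate counter for the slow EMA.

```python
import math


def ademamix_scalar_step(grad, m1, m2, nu, count,
                         b1=0.9, b2=0.999, b3=0.9999, alpha=5.0,
                         eps=1e-8, eps_root=0.0):
    """One AdEMAMix scaling step for a single scalar (illustrative only)."""
    count += 1
    m1 = b1 * m1 + (1.0 - b1) * grad         # fast EMA
    m2 = b3 * m2 + (1.0 - b3) * grad         # slow EMA, not bias-corrected
    nu = b2 * nu + (1.0 - b2) * grad * grad  # second moment estimate
    m1_hat = m1 / (1.0 - b1 ** count)        # bias-corrected fast EMA
    nu_hat = nu / (1.0 - b2 ** count)        # bias-corrected second moment
    # eps_root sits inside the square root, eps outside it.
    update = (m1_hat + alpha * m2) / (math.sqrt(nu_hat + eps_root) + eps)
    return update, (m1, m2, nu, count)


update, state = ademamix_scalar_step(grad=1.0, m1=0.0, m2=0.0, nu=0.0, count=0)
print(update)  # close to 1.0005 on the first step
```

On the first step the slow EMA contributes only alpha * (1 - b3) to the numerator, which is why warmup schedules for alpha and b3 have little effect early on.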

xlstm_jax.trainer.optimizer.ademamix.tree_cast(tree, dtype)#

Cast tree to given dtype, skip if None.

Parameters:
  • tree – The pytree to cast.

  • dtype – Target dtype; if None, the cast is skipped.

Return type:

xlstm_jax.common_types.PyTree

xlstm_jax.trainer.optimizer.ademamix.tree_zeros_like(tree, dtype=None)#

Creates an all-zeros tree with the same structure.

Parameters:
  • tree (xlstm_jax.common_types.PyTree) – Input pytree.

  • dtype (jax.numpy.dtype | None) – Optional dtype to use for the tree of zeros.

Returns:

An all-zeros tree with the same structure as tree.

xlstm_jax.trainer.optimizer.ademamix.tree_update_moment(updates, moments, decay, order)#

Compute the exponential moving average of the order-th moment.

Parameters:
  • updates (xlstm_jax.common_types.PyTree) – Gradients.

  • moments (xlstm_jax.common_types.PyTree) – Moments.

  • decay (float | jax.Array) – Decay rate.

  • order (float | jax.Array) – Order of the moment.

Returns:

The updated moments.

Return type:

xlstm_jax.common_types.PyTree
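
A minimal sketch of this kind of moment update, assuming the standard form decay * moment + (1 - decay) * gradient ** order applied leaf by leaf. Nested dicts of floats stand in for pytrees, and `map_leaves` and `update_moment` are hypothetical helper names.

```python
def map_leaves(fn, a, b):
    # Hypothetical stand-in for a pytree map over two identically nested dicts.
    if isinstance(a, dict):
        return {key: map_leaves(fn, a[key], b[key]) for key in a}
    return fn(a, b)


def update_moment(updates, moments, decay, order):
    # EMA of the order-th power of each gradient leaf.
    return map_leaves(
        lambda g, m: decay * m + (1.0 - decay) * g ** order, updates, moments
    )


grads = {"dense": {"w": 2.0, "b": -1.0}}
nu = {"dense": {"w": 0.0, "b": 0.0}}
nu = update_moment(grads, nu, decay=0.5, order=2)
print(nu)  # {'dense': {'w': 2.0, 'b': 0.5}}
```

With order=1 the same helper tracks a first-moment EMA such as m1 or m2; with order=2 it tracks nu.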

xlstm_jax.trainer.optimizer.ademamix.tree_update_moment_per_elem_norm(updates, moments, decay, order)#

Compute the EMA of the order-th moment of the element-wise norm.

Parameters:
  • updates (xlstm_jax.common_types.PyTree) – Gradients.

  • moments (xlstm_jax.common_types.PyTree) – Moments.

  • decay (float | jax.Array) – Decay rate.

  • order (float | jax.Array) – Order of the moment.

Returns:

The updated moments.

Return type:

xlstm_jax.common_types.PyTree
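
The difference from the plain moment update only shows up for complex-valued gradients. The sketch below assumes the element-wise norm is the absolute value of each element, as in the optax helper of the same name. It works on a single leaf, and the function names are hypothetical.

```python
def update_moment_plain(g, m, decay, order):
    # Raises the (possibly complex) gradient itself to the given power.
    return decay * m + (1.0 - decay) * g ** order


def update_moment_per_elem_norm(g, m, decay, order):
    # Raises the magnitude of the gradient to the given power,
    # so the result stays real and non-negative.
    return decay * m + (1.0 - decay) * abs(g) ** order


g = 3 + 4j
plain = update_moment_plain(g, 0.0, decay=0.5, order=2)
normed = update_moment_per_elem_norm(g, 0.0, decay=0.5, order=2)
print(plain, normed)  # (-3.5+12j) 12.5
```

For real-valued gradients and order=2 the two variants agree, since g ** 2 equals abs(g) ** 2.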

xlstm_jax.trainer.optimizer.ademamix.tree_bias_correction(moment, decay, count)#

Performs bias correction. It becomes a no-op as count goes to infinity.

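Parameters:
  • moment – Moments to correct.

  • decay – Decay rate of the corresponding EMA.

  • count – Number of steps taken so far.

Returns: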

The bias-corrected moments.

Return type:

xlstm_jax.common_types.PyTree
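
A short sketch of why the correction fades out, assuming the usual Adam-style formula moment / (1 - decay ** count) applied to a single leaf. The function name is hypothetical.

```python
def bias_correction(moment, decay, count):
    # Divides out the weight missing from an EMA that was initialised at zero.
    return moment / (1.0 - decay ** count)


decay = 0.9
# Multiplicative correction factor applied to the moment at different step counts.
factors = [bias_correction(1.0, decay, count) for count in (1, 10, 100, 1000)]
print(factors)
```

The factor starts near 10 at the first step for decay=0.9 and shrinks towards 1 as decay ** count vanishes, at which point the moment is returned essentially unchanged.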