xlstm_jax.trainer.base.param_utils

Functions

is_partitioned(x)

Check if an object is a Partitioned.

get_num_params(params)

Calculates the number of parameters in a PyTree.

tabulate_params(state[, show_weight_decay, ...])

Prints a summary of the parameters as a table.

get_grad_norms(grads[, return_per_param])

Determine the gradient norms.

get_param_norms(params[, return_per_param])

Determine the parameter norms.

get_sharded_norm_logits(x)

Calculate the norm of a sharded parameter or gradient.

get_sharded_global_norm(x)

Calculate the norm of a sharded PyTree.

get_param_mask_fn(exclude[, include])

Returns a function that generates a mask PyTree from a parameter PyTree, which can be used, for instance, to select the parameters that receive weight decay.

Module Contents

xlstm_jax.trainer.base.param_utils.is_partitioned(x)

Check whether an object is a Partitioned.

Parameters that are sharded via FSDP (fully sharded data parallelism), PP (pipeline parallelism), or TP (tensor parallelism) are represented as Partitioned objects, while replicated parameters are represented as regular jax.Array objects. In PyTree operations, this function can be passed as the is_leaf argument of a tree map so that Partitioned objects are treated as leaves instead of being traversed. In that case, JAX arrays of replicated parameters and all other regular leaves are still treated as leaves.

Parameters:

x (Any) – The object to check.

Returns:

Whether the object is a Partitioned.

Return type:

bool

xlstm_jax.trainer.base.param_utils.get_num_params(params)

Calculates the number of parameters in a PyTree.

Parameters:

params (xlstm_jax.common_types.PyTree) – The parameters as a PyTree.

Returns:

The total number of parameters.

Return type:

int
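A minimal sketch of parameter counting, assuming the count is the total number of array elements over all leaves; num_params_sketch and the example tree are hypothetical. Because jax.tree_util.tree_leaves descends into flax.linen.Partitioned boxes, the same sum would also cover partitioned parameters.

```python
import jax
import jax.numpy as jnp
import numpy as np


def num_params_sketch(params) -> int:
    # Hypothetical helper: sum the element counts of all leaves in the PyTree.
    return sum(int(np.size(leaf)) for leaf in jax.tree_util.tree_leaves(params))


params = {"dense": {"kernel": jnp.ones((4, 8)), "bias": jnp.zeros((8,))}}
print(num_params_sketch(params))  # 4 * 8 + 8 = 40
```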

xlstm_jax.trainer.base.param_utils.tabulate_params(state, show_weight_decay=False, weight_decay_exclude=None, weight_decay_include=None)

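Prints a summary of the parameters as a table.

Parameters:
  • state – The state containing the parameters to summarize.

  • show_weight_decay (bool) – Whether to include weight decay information in the table. Defaults to False.

  • weight_decay_exclude – Defaults to None.

  • weight_decay_include – Defaults to None.

Returns: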

The summary table as a string.

Return type:

str
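The documented function works on the training state and can add weight decay information; the sketch below only illustrates the basic idea of a per-parameter table (key path, shape, element count) built from a params PyTree. tabulate_params_sketch and its column layout are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def tabulate_params_sketch(params) -> str:
    # Hypothetical helper: one row per leaf with its key path, shape and element count.
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(params)
    rows = [
        (jax.tree_util.keystr(path), str(tuple(np.shape(leaf))), int(np.size(leaf)))
        for path, leaf in leaves_with_paths
    ]
    name_width = max([len("Name")] + [len(name) for name, _, _ in rows])
    lines = [f"{'Name':<{name_width}}  {'Shape':<12}  {'Count':>8}"]
    for name, shape, count in rows:
        lines.append(f"{name:<{name_width}}  {shape:<12}  {count:>8}")
    lines.append(f"Total parameters: {sum(count for _, _, count in rows)}")
    return "\n".join(lines)


params = {"dense": {"kernel": jnp.ones((4, 8)), "bias": jnp.zeros((8,))}}
table = tabulate_params_sketch(params)
print(table)
```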

xlstm_jax.trainer.base.param_utils.get_grad_norms(grads, return_per_param=False)

Determine the gradient norms.

Parameters:
  • grads (Any) – The gradients as a PyTree.

  • return_per_param (bool) – Whether to return the gradient norms per parameter or only the global norm.

Returns:

A dictionary containing the gradient norms.

Return type:

dict
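A minimal single-device sketch of what the returned dictionary could contain, assuming the global norm is the square root of the summed squared leaf norms. The function grad_norms_sketch and the keys "grad_norm" and "grad_norm_per_param" are hypothetical, not the library's actual key names.

```python
import jax
import jax.numpy as jnp


def grad_norms_sketch(grads, return_per_param: bool = False) -> dict:
    # Squared L2 norm of every leaf; squared norms add up across leaves.
    sq_norms = jax.tree_util.tree_map(lambda g: jnp.sum(jnp.square(g)), grads)
    # Hypothetical key name for the global norm.
    out = {"grad_norm": jnp.sqrt(sum(jax.tree_util.tree_leaves(sq_norms)))}
    if return_per_param:
        # Hypothetical key name; same tree structure as `grads`.
        out["grad_norm_per_param"] = jax.tree_util.tree_map(jnp.sqrt, sq_norms)
    return out


grads = {"w": jnp.array([3.0, 0.0]), "b": jnp.array([4.0])}
norms = grad_norms_sketch(grads, return_per_param=True)
print(norms["grad_norm"])  # sqrt(3**2 + 4**2) = 5.0
```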

xlstm_jax.trainer.base.param_utils.get_param_norms(params, return_per_param=False)

Determine the parameter norms.

Parameters:
  • params (Any) – The parameters as a PyTree.

  • return_per_param (bool) – Whether to return the parameter norms per parameter or only the global norm.

Returns:

A dictionary containing the parameter norms.

Return type:

dict
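One way to sanity-check a global parameter norm on a single device: it should equal the L2 norm of all parameters concatenated into one vector, which jax.flatten_util.ravel_pytree produces. The example tree is hypothetical. A leaf-wise reduction avoids materializing that concatenated copy, which presumably matters for large, sharded models.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical parameters: squared norms are 4 * 2**2 = 16 and 2 * 1**2 = 2.
params = {"dense": {"kernel": jnp.full((2, 2), 2.0), "bias": jnp.array([1.0, 1.0])}}

# Reference: flatten the whole tree into one vector and take its norm.
flat, _ = ravel_pytree(params)
reference_norm = jnp.linalg.norm(flat)

# Leaf-wise: sum the squared norms of the leaves, then take one square root.
leafwise_norm = jnp.sqrt(
    sum(jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(params))
)
print(reference_norm, leafwise_norm)  # both equal sqrt(18)
```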

xlstm_jax.trainer.base.param_utils.get_sharded_norm_logits(x)

Calculate the norm of a sharded parameter or gradient.

Parameters:

x (jax.Array | flax.linen.Partitioned) – The parameter or gradient.

Returns:

The norm logit, i.e. the squared norm.

Return type:

jax.Array
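The following illustrates why squared norms ("norm logits") are the natural quantity to reduce across shards: they add up, whereas plain norms do not. The sketch simulates two shards on one device with jnp.split and uses a hypothetical parameter. In multi-device code the sum would typically be a collective such as jax.lax.psum over the mesh axes the parameter is sharded along; that is an assumption, not a description of this function's implementation.

```python
import jax.numpy as jnp

# Hypothetical 4x3 parameter, split into two shards along axis 0 (as FSDP might do).
x = jnp.arange(12.0).reshape(4, 3)
shards = jnp.split(x, 2, axis=0)

# Each "device" computes the squared norm (norm logit) of its local shard.
local_logits = [jnp.sum(jnp.square(s)) for s in shards]

# Summing the logits gives the squared norm of the full parameter.
global_logit = sum(local_logits)  # 55 + 451 = 506
global_norm = jnp.sqrt(global_logit)

# Summing per-shard norms instead overestimates the true norm.
naive_norm = sum(jnp.sqrt(logit) for logit in local_logits)
print(global_norm, naive_norm)
```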

xlstm_jax.trainer.base.param_utils.get_sharded_global_norm(x)

Calculate the norm of a sharded PyTree.

Parameters:

x (xlstm_jax.common_types.PyTree) – The PyTree. Each leaf should be a jax.Array or nn.Partitioned.

Returns:

The global norm and the norm per leaf.

Return type:

tuple[jax.Array, PyTree]

xlstm_jax.trainer.base.param_utils.get_param_mask_fn(exclude, include=None)

Returns a function that generates a mask PyTree from a parameter PyTree, which can be used, for instance, to select the parameters that receive weight decay.

Parameters:
  • exclude (Sequence[str]) – List of strings to exclude.

  • include (Sequence[str] | None) – List of strings to include. If None, all parameters except those in exclude are included.

Returns:

Function that generates a mask.

Return type:

Callable[[PyTree], PyTree]
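A minimal sketch of such a mask factory, assuming each string is matched as a substring of the leaf's "/"-joined key path and that include is applied before exclude; the actual matching rules (substring, regex, ordering) are not specified above. make_param_mask_fn_sketch and the example tree are hypothetical.

```python
from typing import Callable, Optional, Sequence

import jax
import jax.numpy as jnp


def make_param_mask_fn_sketch(
    exclude: Sequence[str], include: Optional[Sequence[str]] = None
) -> Callable:
    # Hypothetical helper; substring matching on key paths is an assumption.
    def path_to_str(path) -> str:
        # DictKey entries expose .key; fall back to str() for other key types.
        return "/".join(str(getattr(k, "key", k)) for k in path)

    def leaf_mask(path, _leaf) -> bool:
        name = path_to_str(path)
        if include is not None and not any(s in name for s in include):
            return False
        return not any(s in name for s in exclude)

    def mask_fn(params):
        # Same tree structure as params, with a bool at every leaf.
        return jax.tree_util.tree_map_with_path(leaf_mask, params)

    return mask_fn


params = {
    "dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros((2,))},
    "norm": {"scale": jnp.ones((2,))},
}
mask = make_param_mask_fn_sketch(exclude=["bias", "norm"])(params)
print(mask)
```

Since the returned callable maps a PyTree to a PyTree of booleans, a function of this shape could be passed, for example, as the mask argument of optax.adamw.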