xlstm_jax.trainer.base.param_utils
Functions
- is_partitioned – Check if an object is a Partitioned.
- get_num_params – Calculates the number of parameters in a PyTree.
- tabulate_params – Prints a summary of the parameters represented as a table.
- get_grad_norms – Determine the gradient norms.
- get_param_norms – Determine the parameter norms.
- get_sharded_norm_logits – Calculate the norm of a sharded parameter or gradient.
- get_sharded_global_norm – Calculate the norm of a sharded PyTree.
- get_param_mask_fn – Returns a function that generates a mask, which can for instance be used for weight decay.
Module Contents
- xlstm_jax.trainer.base.param_utils.is_partitioned(x)
Check if an object is a Partitioned.
Parameters that are sharded via FSDP, PP, or TP are represented as Partitioned objects, while replicated parameters are represented as regular jax.Array objects. This function can be passed as the is_leaf argument of a tree map over a PyTree, so that Partitioned objects are treated as leaves instead of being traversed. In that case, the JAX arrays of replicated parameters and all other ordinary leaves are still treated as leaves.
- Parameters:
x (Any) – The object to check.
- Returns:
Whether the object is a Partitioned.
- Return type:
bool
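The effect of an is_leaf predicate can be sketched with plain JAX. The example below is only an illustration: Boxed and is_boxed are hypothetical stand-ins for a Partitioned wrapper and for is_partitioned, and the axis name "fsdp" is made up. It shows that a default traversal descends into the wrapper, while passing the predicate as is_leaf hands the wrapper itself to the mapped function.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Boxed:
    """Hypothetical stand-in for a Partitioned wrapper: a value plus axis names."""

    def __init__(self, value, names):
        self.value = value
        self.names = names

    def tree_flatten(self):
        # The wrapped array is the only child; the axis names are static metadata.
        return (self.value,), self.names

    @classmethod
    def tree_unflatten(cls, names, children):
        return cls(children[0], names)


def is_boxed(x):
    """Hypothetical analogue of is_partitioned for the stand-in class."""
    return isinstance(x, Boxed)


params = {
    "dense": {"kernel": Boxed(jnp.ones((4, 8)), names=("fsdp", None))},
    "norm": {"scale": jnp.ones((8,))},  # replicated parameter: a plain array
}

# Default traversal descends into the wrapper and yields only raw arrays.
leaves_default = jax.tree_util.tree_leaves(params)

# With is_leaf, the wrapper itself is a leaf; plain arrays remain leaves too.
leaves_boxed = jax.tree_util.tree_leaves(params, is_leaf=is_boxed)

# A tree map can then read metadata that is lost in the default traversal.
axis_names = jax.tree_util.tree_map(
    lambda x: x.names if is_boxed(x) else None, params, is_leaf=is_boxed
)
```

Under these assumptions, axis_names holds the axis-name tuple for the wrapped kernel and None for the replicated scale.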
- xlstm_jax.trainer.base.param_utils.get_num_params(params)
Calculates the number of parameters in a PyTree.
- Parameters:
params (xlstm_jax.common_types.PyTree)
- Return type:
int
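Counting the parameters of a PyTree amounts to summing the element counts of its leaves. The following is a minimal sketch of that idea in plain JAX, not the module's implementation; count_params is a hypothetical name and the example tree is made up.

```python
import math

import jax
import jax.numpy as jnp


def count_params(params):
    """Hypothetical helper: sum the number of elements over all array leaves."""
    return sum(math.prod(leaf.shape) for leaf in jax.tree_util.tree_leaves(params))


params = {
    "dense": {"kernel": jnp.zeros((4, 8)), "bias": jnp.zeros((8,))},
    "norm": {"scale": jnp.zeros((8,))},
}

total = count_params(params)  # 4 * 8 + 8 + 8 elements
```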
- xlstm_jax.trainer.base.param_utils.tabulate_params(state, show_weight_decay=False, weight_decay_exclude=None, weight_decay_include=None)
Prints a summary of the parameters represented as a table.
- Parameters:
state (xlstm_jax.common_types.TrainState | dict[str, Any]) – The TrainState or the parameters as a dictionary.
show_weight_decay (bool) – Whether to show the weight decay mask.
weight_decay_exclude (collections.abc.Sequence[re.Pattern] | None) – List of regex patterns to exclude from weight decay. See optimizer config for more information.
weight_decay_include (collections.abc.Sequence[re.Pattern] | None) – List of regex patterns to include in weight decay. See optimizer config for more information.
- Returns:
The summary table as a string.
- Return type:
str
- xlstm_jax.trainer.base.param_utils.get_grad_norms(grads, return_per_param=False)
Determine the gradient norms.
- xlstm_jax.trainer.base.param_utils.get_param_norms(params, return_per_param=False)
Determine the parameter norms.
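The return format of get_grad_norms and get_param_norms is not documented here. As a rough sketch of what a global norm with optional per-parameter norms could look like, the example below computes L2 norms over a PyTree in plain JAX. The helper name l2_norms, the "global" key, and the path-based key names are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def l2_norms(tree, return_per_param=False):
    """Hypothetical helper: global L2 norm of a PyTree, optionally per leaf."""
    leaves_with_path, _ = jax.tree_util.tree_flatten_with_path(tree)
    # Squared L2 norm of every leaf, keyed by its path in the tree.
    squared = {
        jax.tree_util.keystr(path): jnp.sum(jnp.square(leaf))
        for path, leaf in leaves_with_path
    }
    # The global norm is the square root of the summed squared leaf norms.
    norms = {"global": jnp.sqrt(sum(squared.values()))}
    if return_per_param:
        norms.update({name: jnp.sqrt(value) for name, value in squared.items()})
    return norms


grads = {"w": jnp.array([3.0, 4.0]), "b": jnp.array([12.0])}
norms = l2_norms(grads, return_per_param=True)
```

With these made-up gradients, the per-leaf norms are 5 and 12 and the global norm is 13.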
- xlstm_jax.trainer.base.param_utils.get_sharded_norm_logits(x)
Calculate the norm of a sharded parameter or gradient.
- xlstm_jax.trainer.base.param_utils.get_sharded_global_norm(x)
Calculate the norm of a sharded PyTree.
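How these functions combine values across devices is not described here. The arithmetic behind any sharded L2 norm can still be illustrated on a single device: each shard contributes its local sum of squares, the contributions are added, and the square root is taken once at the end. The sketch below simulates shards with jnp.split. In a real multi-device setting the addition would presumably be a collective such as jax.lax.psum over the sharded mesh axes, which is an assumption and not taken from the source.

```python
import jax.numpy as jnp

# A "global" parameter and a simulated 4-way sharding of it.
full = jnp.arange(1.0, 9.0)
shards = jnp.split(full, 4)

# Each shard computes only its local sum of squares.
local_squared_sums = [jnp.sum(jnp.square(shard)) for shard in shards]

# Summing the local contributions and taking one square root recovers the
# norm of the full array. Taking the square root per shard first would not.
norm_from_shards = jnp.sqrt(sum(local_squared_sums))
norm_direct = jnp.linalg.norm(full)
```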
- xlstm_jax.trainer.base.param_utils.get_param_mask_fn(exclude, include=None)
Returns a function that generates a mask, which can for instance be used for weight decay.
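A mask function of this kind can be sketched with the standard library alone. The example below is hypothetical and not the module's implementation: make_mask_fn, the dot-joined parameter paths, and the rule that include takes over whenever it is given are all assumptions. It returns a function that maps a nested parameter dictionary to a tree of booleans with the same structure.

```python
import re


def make_mask_fn(exclude, include=None):
    """Hypothetical sketch: build a function mapping params to a boolean mask."""

    def selected(path):
        # Assumption: if include patterns are given, only matching paths are
        # selected; otherwise every path not matched by exclude is selected.
        if include is not None:
            return any(pattern.search(path) for pattern in include)
        return not any(pattern.search(path) for pattern in exclude or [])

    def build(tree, prefix):
        return {
            key: build(value, f"{prefix}{key}.")
            if isinstance(value, dict)
            else selected(f"{prefix}{key}")
            for key, value in tree.items()
        }

    def mask_fn(params):
        return build(params, "")

    return mask_fn


params = {"dense": {"kernel": 1.0, "bias": 2.0}, "norm": {"scale": 3.0}}

# Exclude biases and everything under "norm" from weight decay.
mask_fn = make_mask_fn(exclude=[re.compile(r"bias"), re.compile(r"^norm")])
mask = mask_fn(params)
```

Such a boolean tree, or the function that produces it, is the kind of value that an optimizer transformation like optax.add_decayed_weights accepts through its mask argument.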