xlstm_jax.utils.model_param_handling.load

Attributes

LOGGER

Functions

load_model_params_and_config_from_checkpoint(...[, ...]) – Load model parameters and config from a JAX checkpoint.

Module Contents

xlstm_jax.utils.model_param_handling.load.LOGGER
xlstm_jax.utils.model_param_handling.load.load_model_params_and_config_from_checkpoint(checkpoint_path, return_config_as_dataclass=False)

Load model parameters and config from a JAX checkpoint.

Parameters:
  • checkpoint_path (str | Path) – The path to the checkpoint file.

  • return_config_as_dataclass (bool) – Whether to return the config as a dataclass instead of a dictionary. Defaults to False.

Returns:

The model parameters and the model config.

Return type:

tuple[dict[str, Any], dict[str, Any]]
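The structure of the returned parameter dictionary is not documented on this page. The sketch below assumes the parameters come back as a nested dict of array-like leaves (as in Flax parameter trees) and shows how a caller might unpack the returned tuple and summarize the parameters. `summarize_params` and `count_params` are hypothetical helpers, not part of xlstm_jax; the loader call appears only in a comment, and a small dummy tree stands in for real checkpoint contents.

```python
from typing import Any, Mapping

import numpy as np


def summarize_params(params: Mapping[str, Any], prefix: str = "") -> dict[str, tuple[int, ...]]:
    """Flatten a nested parameter dict into {"a/b/c": shape} entries.

    Hypothetical helper: assumes every non-dict leaf is array-like (has a shape).
    """
    shapes: dict[str, tuple[int, ...]] = {}
    for key, value in params.items():
        path = f"{prefix}/{key}" if prefix else str(key)
        if isinstance(value, Mapping):
            # Recurse into nested sub-module parameter groups.
            shapes.update(summarize_params(value, path))
        else:
            shapes[path] = tuple(np.shape(value))
    return shapes


def count_params(params: Mapping[str, Any]) -> int:
    """Total number of scalar entries across all parameter leaves."""
    return sum(int(np.prod(shape)) for shape in summarize_params(params).values())


if __name__ == "__main__":
    # With the real loader this would be, e.g.:
    # params, config = load_model_params_and_config_from_checkpoint("/path/to/checkpoint")
    # A dummy nested tree stands in for the loaded parameters here.
    params = {
        "embedding": {"kernel": np.zeros((100, 16))},
        "block_0": {"dense": {"kernel": np.zeros((16, 32)), "bias": np.zeros((32,))}},
    }
    for path, shape in summarize_params(params).items():
        print(path, shape)
    print("total:", count_params(params))  # prints "total: 2144"
```

Flattening to slash-separated paths keeps the summary independent of how deeply the model nests its sub-modules; the same walk works unchanged on JAX arrays, since `np.shape` only needs a `.shape`-bearing object.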