xlstm_jax.utils.model_param_handling.store#
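Attributes#
LOGGER |
Functions#
store_checkpoint_sharded(state_dict, checkpoint_path, max_shard_size=1 << 30, metadata=None) | Save model parameters in sharded fashion into multiple safetensors files. |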
Module Contents#
- xlstm_jax.utils.model_param_handling.store.LOGGER#
- xlstm_jax.utils.model_param_handling.store.store_checkpoint_sharded(state_dict, checkpoint_path, max_shard_size=1 << 30, metadata=None)#
Save model parameters in sharded fashion into multiple safetensors files.
- Parameters:
state_dict (dict[str, torch.Tensor]) – Model state dict.
checkpoint_path (Path) – Path where the model checkpoint is stored.
max_shard_size (int, optional) – Maximum shard size in bytes. Defaults to 1<<30 (1 GiB).
metadata (dict[str, Any], optional) – Additional metadata for the checkpoint. Defaults to None, which the docstring describes as an empty dict ({}).
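To illustrate what a byte limit such as max_shard_size implies, the following is a minimal sketch of one way to group parameters into shards. It is not the implementation used by store_checkpoint_sharded: the helper name plan_shards, the greedy strategy, and the example parameter names and sizes are all assumptions made for illustration. It works on plain byte counts so that it runs without torch or safetensors.

```python
def plan_shards(param_sizes, max_shard_size=1 << 30):
    """Greedily group parameter names into shards (hypothetical helper).

    param_sizes maps a parameter name to its size in bytes. Parameters are
    visited in insertion order; a new shard starts whenever adding the next
    parameter would push the current shard past max_shard_size. A single
    parameter larger than the limit still gets a shard of its own.
    """
    shards = []
    current = []
    current_bytes = 0
    for name, nbytes in param_sizes.items():
        # Close the current shard if this parameter would overflow it.
        if current and current_bytes + nbytes > max_shard_size:
            shards.append(current)
            current, current_bytes = [], 0
        current.append(name)
        current_bytes += nbytes
    if current:
        shards.append(current)
    return shards


# Illustrative sizes in bytes (made up for this example).
sizes = {
    "embed.weight": 600,
    "block0.weight": 500,
    "block1.weight": 300,
    "head.weight": 1200,
}
shards = plan_shards(sizes, max_shard_size=1000)
print(shards)
```

With a limit of 1000 bytes, the sketch places embed.weight alone, groups block0.weight with block1.weight, and gives the oversized head.weight its own shard. Going by the signature above, an actual call would take the form store_checkpoint_sharded(state_dict, checkpoint_path, max_shard_size=..., metadata=...).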