xlstm_jax.utils.model_param_handling.store

Attributes

LOGGER

Functions

store_checkpoint_sharded(state_dict, checkpoint_path)

Save model parameters as shards split across multiple safetensors files.

Module Contents

xlstm_jax.utils.model_param_handling.store.LOGGER
xlstm_jax.utils.model_param_handling.store.store_checkpoint_sharded(state_dict, checkpoint_path, max_shard_size=1 << 30, metadata=None)

Save model parameters as shards split across multiple safetensors files.

Parameters:
  • state_dict (dict[str, torch.Tensor]) – Model state dict.

  • checkpoint_path (Path) – Path under which the model checkpoint is stored.

  • max_shard_size (int, optional) – Maximum shard size in bytes. Defaults to 1 << 30 (1 GiB).

  • metadata (dict[str, Any], optional) – Additional metadata for the checkpoint. Defaults to None, which means no additional metadata (equivalent to {}).
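How a byte limit such as max_shard_size can split a state dict into shards is illustrated by the minimal pure-Python sketch below. It assumes a greedy, in-order packing of tensors and a Hugging Face-style index (a "weight_map" from tensor name to file plus a "metadata" entry) with "model-XXXXX-of-YYYYY.safetensors" file names. The helpers plan_shards and build_index and the file-name pattern are hypothetical and are not taken from the actual implementation of store_checkpoint_sharded. Tensor sizes are given as plain byte counts so the example runs without torch or safetensors.

```python
from typing import Any, Dict, List, Optional


def plan_shards(
    tensor_sizes: Dict[str, int], max_shard_size: int = 1 << 30
) -> List[Dict[str, int]]:
    """Greedily pack tensors (name -> size in bytes) into shards.

    Hypothetical helper: tensors are visited in insertion order, and a new
    shard is started whenever the next tensor would push the current shard
    past max_shard_size. A tensor larger than the limit gets its own shard.
    """
    shards: List[Dict[str, int]] = []
    current: Dict[str, int] = {}
    current_size = 0
    for name, size in tensor_sizes.items():
        # Close the current shard if adding this tensor would exceed the limit.
        if current and current_size + size > max_shard_size:
            shards.append(current)
            current, current_size = {}, 0
        current[name] = size
        current_size += size
    if current:
        shards.append(current)
    return shards


def build_index(
    shards: List[Dict[str, int]], metadata: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Map each tensor name to its (hypothetical) shard file name."""
    n = len(shards)
    weight_map: Dict[str, str] = {}
    for i, shard in enumerate(shards, start=1):
        filename = f"model-{i:05d}-of-{n:05d}.safetensors"
        for name in shard:
            weight_map[name] = filename
    # None is treated as "no additional metadata".
    return {"metadata": metadata or {}, "weight_map": weight_map}


# Usage with toy byte counts and a 1000-byte limit.
sizes = {"embed": 600, "layer0": 300, "layer1": 300, "head": 1200}
shards = plan_shards(sizes, max_shard_size=1000)
# shards == [{'embed': 600, 'layer0': 300}, {'layer1': 300}, {'head': 1200}]
index = build_index(shards, metadata={"format": "pt"})
```

Greedy in-order packing keeps tensors that are adjacent in the state dict in the same file and never splits a single tensor, at the cost of some shards ending up smaller than the limit (here, "layer1" sits alone because "head" alone exceeds 1000 bytes).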