xlstm_jax.utils.model_param_handling.convert_checkpoint

Functions

convert_orbax_checkpoint_to_torch_state_dict(orbax_pytree)

Convert orbax pytree params to a (flat) torch state dict.

Module Contents

xlstm_jax.utils.model_param_handling.convert_checkpoint.convert_orbax_checkpoint_to_torch_state_dict(orbax_pytree, split_blocks=True, blocks_layer_name='blocks')

Convert orbax pytree params to a (flat) torch state dict.

Parameters:
  • orbax_pytree (dict[str, Any]) – The orbax pytree params.

  • split_blocks (bool) – Whether to split the block parameters into one tensor per block. Defaults to True. JAX stores the weights of all blocks in a single array whose first dimension is the number of blocks, whereas PyTorch expects the weights of each block as a separate tensor. The stacked array is therefore split along its first dimension.

  • blocks_layer_name (str) – The blocks layer name to split parameters by. Defaults to “blocks”.

Returns:

The torch state dict.

Return type:

dict[str, torch.Tensor]
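
Example:

The splitting described for split_blocks can be illustrated with a minimal sketch. This is not the library's implementation: the helper flatten_and_split, the dot-separated key format, and the toy parameter names ("embedding", "mlp", "kernel") are all assumptions made for illustration, and NumPy arrays stand in for both the JAX arrays and the torch tensors.

```python
import numpy as np


def flatten_and_split(tree, split_blocks=True, blocks_layer_name="blocks", prefix=""):
    """Flatten a nested dict of arrays into dot-separated keys (hypothetical helper).

    Arrays found below `blocks_layer_name` are assumed to be stacked along
    axis 0 (one slice per block) and are split into per-block entries.
    """
    flat = {}
    for name, value in tree.items():
        key = f"{prefix}.{name}" if prefix else name
        if isinstance(value, dict):
            if split_blocks and name == blocks_layer_name:
                # Flatten the stacked subtree first, then emit one entry per block.
                stacked = flatten_and_split(value, split_blocks=False)
                for sub_key, array in stacked.items():
                    for i in range(array.shape[0]):
                        flat[f"{key}.{i}.{sub_key}"] = array[i]
            else:
                flat.update(flatten_and_split(value, split_blocks, blocks_layer_name, key))
        else:
            flat[key] = np.asarray(value)
    return flat


# Toy pytree (assumed layout): 3 blocks whose kernels are stacked along axis 0.
params = {
    "embedding": np.zeros((10, 4)),
    "blocks": {"mlp": {"kernel": np.ones((3, 4, 4))}},
}
state = flatten_and_split(params)
for key in sorted(state):
    print(key, state[key].shape)
```

With split_blocks=True, the single (3, 4, 4) array becomes three (4, 4) entries, one per block index. In an actual conversion each entry would additionally be turned into a torch.Tensor (for instance with torch.from_numpy); the key names the real function produces may differ from this sketch.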