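xlstm_jax.utils.model_param_handling.convert_state_dict
Attributes
- LOGGER
Functions
- find_parameter_match_key(from_key, to_keys, match_dict): Finds the matching parameter key for the target state dict.
- create_full_state_dict_key_mapping(from_state_dict, to_state_dict, match_dict): Creates a full state dict key mapping from the source state dict to the target state dict.
- move_state_dict_params_(from_state_dict, to_state_dict, match_dict): Moves the params of a model from one state dict to another state dict.
- move_safetensors_state_dict_params_(from_state_dict_path, to_state_dict, match_dict): Moves the params of a model from one state dict to another state dict, loading the source from a file.
- convert_state_dict_keys_(state_dict, full_key_mapping): Converts the keys of the state dict according to the key mapping.
- apply_weight_transforms_(state_dict, apply_transforms_to_keys): Applies weight transforms to the weights of the state dict.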
Module Contents
- xlstm_jax.utils.model_param_handling.convert_state_dict.LOGGER
- xlstm_jax.utils.model_param_handling.convert_state_dict.find_parameter_match_key(from_key, to_keys, match_dict)
Finds the matching parameter key for the target state dict.
- Parameters:
- Returns:
The target state dict key that matches the source state dict key.
- Return type:
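The matching step can be pictured with a minimal sketch. It is not the library's implementation: `sketch_find_match` and the example keys are hypothetical, and it assumes that exactly one `match_dict` substring occurs in the source key and that exactly one target key contains the mapped substring.

```python
def sketch_find_match(from_key, to_keys, match_dict):
    """Illustrative only: resolve one source key to one target key via substring rules."""
    # Source-side substrings from match_dict that occur in the source key.
    hits = [sub for sub in match_dict if sub in from_key]
    if len(hits) != 1:
        raise KeyError(f"expected exactly one rule for {from_key!r}, found {hits}")
    target_sub = match_dict[hits[0]]
    # Target keys that contain the mapped substring.
    candidates = [key for key in to_keys if target_sub in key]
    if len(candidates) != 1:
        raise KeyError(f"expected exactly one target for {from_key!r}, found {candidates}")
    return candidates[0]


# Hypothetical keys, chosen only for illustration.
match_dict = {
    "embed.weight": "token_embedding.embedding",
    "head.weight": "lm_head.kernel",
}
to_keys = ["params.token_embedding.embedding", "params.lm_head.kernel"]
result = sketch_find_match("model.embed.weight", to_keys, match_dict)
```

Raising on zero or multiple matches is a design choice of this sketch: ambiguous substrings surface as errors instead of silently picking one target.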
- xlstm_jax.utils.model_param_handling.convert_state_dict.create_full_state_dict_key_mapping(from_state_dict, to_state_dict, match_dict)
Creates a full state dict key mapping from the source state dict to the target state dict.
- Parameters:
- Returns:
The full state dict key mapping from the source state dict to the target state dict.
- Return type:
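A full mapping can be thought of as the per-key match applied to every source key. The sketch below is an assumption-laden illustration, not the library code: `sketch_full_key_mapping` and the example keys are hypothetical, and the extra one-to-one check is a choice of this sketch, not documented behavior.

```python
def sketch_full_key_mapping(from_state_dict, to_state_dict, match_dict):
    """Illustrative only: map every source key to exactly one target key."""
    mapping = {}
    for from_key in from_state_dict:
        rules = [sub for sub in match_dict if sub in from_key]
        if len(rules) != 1:
            raise KeyError(f"no unique rule for source key {from_key!r}")
        targets = [key for key in to_state_dict if match_dict[rules[0]] in key]
        if len(targets) != 1:
            raise KeyError(f"no unique target for source key {from_key!r}")
        mapping[from_key] = targets[0]
    # Assumption of this sketch: two source keys must not land on the same target key.
    if len(set(mapping.values())) != len(mapping):
        raise ValueError("mapping is not one-to-one")
    return mapping


# Hypothetical state dicts; only the keys matter here.
source = {"model.embed.weight": 1, "model.head.weight": 2}
target = {"params.token_embedding.embedding": None, "params.lm_head.kernel": None}
rules = {"embed.weight": "token_embedding.embedding", "head.weight": "lm_head.kernel"}
full_mapping = sketch_full_key_mapping(source, target, rules)
```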
- xlstm_jax.utils.model_param_handling.convert_state_dict.move_state_dict_params_(from_state_dict, to_state_dict, match_dict)
Moves the params of a model from one state dict to another state dict. Modifies to_state_dict in place.
- Parameters:
- Returns:
The target state dict (the modified to_state_dict) with the converted parameters.
- Return type:
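The in-place behavior can be sketched as follows. This is a hedged illustration, not the library's code: `sketch_move_params_` and the keys are hypothetical, and the sketch assumes every source key resolves to exactly one existing target key.

```python
def sketch_move_params_(from_state_dict, to_state_dict, match_dict):
    """Illustrative only: overwrite target entries with matched source values, in place."""
    for from_key, value in from_state_dict.items():
        # Tuple unpacking raises ValueError unless exactly one element is found.
        (rule,) = [sub for sub in match_dict if sub in from_key]
        (to_key,) = [key for key in to_state_dict if match_dict[rule] in key]
        to_state_dict[to_key] = value
    # The same dict object is returned, mirroring the trailing-underscore convention.
    return to_state_dict


source = {"model.embed.weight": [0.1, 0.2], "model.head.weight": [0.3]}
target = {"params.token_embedding.embedding": None, "params.lm_head.kernel": None}
rules = {"embed.weight": "token_embedding.embedding", "head.weight": "lm_head.kernel"}
moved = sketch_move_params_(source, target, rules)
```

Because the target is mutated, a caller who still needs the original target values would have to copy them beforehand.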
- xlstm_jax.utils.model_param_handling.convert_state_dict.move_safetensors_state_dict_params_(from_state_dict_path, to_state_dict, match_dict)
Moves the params of a model from one state dict to another state dict. The source state dict is loaded from a file on the fly, so only to_state_dict is held in memory in full. Modifies to_state_dict in place.
- Parameters:
from_state_dict_path (Path) – The path to the source state dict.
match_dict (dict[str, str]) – The dict that maps the source state dict keys to the target state dict keys. Should contain unique substrings of the source state dict keys as keys and the corresponding target state dict keys substrings as values.
- Returns:
The target state dict (the modified to_state_dict) with the converted parameters.
- Return type:
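The memory benefit comes from fetching one tensor at a time. The sketch below illustrates that idea with a hypothetical helper, `sketch_move_lazy_`, which takes an iterable of keys and a `get_tensor` callable; a handle opened with `safetensors.safe_open` exposes `keys()` and `get_tensor(name)` and could play that role, but whether the library does it this way is an assumption. A plain dict stands in for the file here.

```python
def sketch_move_lazy_(keys, get_tensor, to_state_dict, match_dict):
    """Illustrative only: pull source tensors one by one into the target dict."""
    for from_key in keys:
        # Tuple unpacking raises ValueError unless exactly one element is found.
        (rule,) = [sub for sub in match_dict if sub in from_key]
        (to_key,) = [key for key in to_state_dict if match_dict[rule] in key]
        # Only the tensor for this key is materialized at this point.
        to_state_dict[to_key] = get_tensor(from_key)
    return to_state_dict


# Stand-in for a file on disk; records which tensors were requested.
fake_file = {"model.embed.weight": [0.1, 0.2], "model.head.weight": [0.3]}
requested = []


def fake_get_tensor(name):
    requested.append(name)
    return fake_file[name]


target = {"params.token_embedding.embedding": None, "params.lm_head.kernel": None}
rules = {"embed.weight": "token_embedding.embedding", "head.weight": "lm_head.kernel"}
out = sketch_move_lazy_(fake_file.keys(), fake_get_tensor, target, rules)
```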
- xlstm_jax.utils.model_param_handling.convert_state_dict.convert_state_dict_keys_(state_dict, full_key_mapping)
Converts the keys of the state dict according to the key mapping.
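Renaming keys in place might look like the following sketch. `sketch_convert_keys_` is hypothetical, and two behaviors are assumptions of the sketch, not documented facts: keys missing from the mapping are kept unchanged, and a collision between renamed keys raises an error.

```python
def sketch_convert_keys_(state_dict, full_key_mapping):
    """Illustrative only: rename keys of state_dict in place."""
    # Build the renamed view first so that chained renames (a -> b, b -> c) stay safe.
    renamed = {full_key_mapping.get(key, key): value for key, value in state_dict.items()}
    if len(renamed) != len(state_dict):
        raise ValueError("key mapping produced colliding keys")
    # Mutate the original dict object so existing references see the new keys.
    state_dict.clear()
    state_dict.update(renamed)
    return state_dict


state = {"model.embed.weight": 1, "model.head.weight": 2, "step": 0}
mapping = {
    "model.embed.weight": "params.token_embedding.embedding",
    "model.head.weight": "params.lm_head.kernel",
}
converted = sketch_convert_keys_(state, mapping)
```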
- xlstm_jax.utils.model_param_handling.convert_state_dict.apply_weight_transforms_(state_dict, apply_transforms_to_keys)
Applies weight transforms to the weights of the state dict.
The following transforms are currently supported:
- “transpose”: Transposes the weight tensor. Accepts only 2D tensors.
- “squeeze-XXX”: Squeezes the XXX dimension of the weight tensor. If XXX is not given, squeezes all dimensions of size 1.
- “flatten”: Flattens the weight tensor.
Where possible, the transforms are applied in place on the tensors. The state_dict is also modified in place.
- Parameters:
- Returns:
The state dict with the transformed weights.
- Return type:
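The three transforms can be illustrated with NumPy. This is a sketch under stated assumptions, not the library's implementation: the helper names are hypothetical, `apply_transforms_to_keys` is assumed to map a full key to a list of transform names, XXX is assumed to be an integer axis, and a bare "squeeze" is assumed to be the spelling for "XXX not given".

```python
import numpy as np


def sketch_apply_transform(weight, name):
    """Illustrative only: apply one named transform to a NumPy array."""
    if name == "transpose":
        if weight.ndim != 2:
            raise ValueError("transpose accepts only 2D tensors")
        return weight.T
    if name == "flatten":
        return weight.reshape(-1)
    if name == "squeeze":
        # No axis given: drop every dimension of size 1.
        return np.squeeze(weight)
    if name.startswith("squeeze-"):
        # Assumption: the suffix is an integer axis index.
        return np.squeeze(weight, axis=int(name.split("-", 1)[1]))
    raise ValueError(f"unknown transform {name!r}")


def sketch_apply_weight_transforms_(state_dict, transforms_per_key):
    """Illustrative only: rebind transformed weights in state_dict, in place."""
    for key, names in transforms_per_key.items():
        for name in names:
            state_dict[key] = sketch_apply_transform(state_dict[key], name)
    return state_dict


# Hypothetical weights, chosen so each transform changes the shape.
state = {
    "proj.kernel": np.arange(6).reshape(2, 3),
    "norm.scale": np.ones((1, 4, 1)),
    "bias": np.zeros((2, 2)),
}
transforms = {"proj.kernel": ["transpose"], "norm.scale": ["squeeze-0"], "bias": ["flatten"]}
transformed = sketch_apply_weight_transforms_(state, transforms)
```

In this sketch the arrays themselves are not mutated; only the dict entries are rebound, which is one way to read "the state_dict is modified in place".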