xlstm_jax.utils.model_param_handling.handle_mlstm_simple#
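Attributes#
- LOGGER
Functions#
- create_mlstm_simple_config_from_jax_config
- apply_mlstm_param_reshapes
- move_mlstm_jax_state_dict_into_torch_state_dict: Move the mLSTM jax model state dict into the torch model.
- pipeline_convert_mlstm_checkpoint_jax_to_torch_simple
- store_mlstm_simple_to_checkpoint: Stores an mLSTM simple model into a checkpoint directory, using either the huggingface or plain format.
- convert_mlstm_checkpoint_jax_to_torch_simple: Convert a jax mLSTM checkpoint to a torch mLSTM checkpoint.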
Module Contents#
- xlstm_jax.utils.model_param_handling.handle_mlstm_simple.LOGGER#
- xlstm_jax.utils.model_param_handling.handle_mlstm_simple.create_mlstm_simple_config_from_jax_config(model_config_jax, overrides=None)#
- xlstm_jax.utils.model_param_handling.handle_mlstm_simple.apply_mlstm_param_reshapes(state_dict)#
- xlstm_jax.utils.model_param_handling.handle_mlstm_simple.move_mlstm_jax_state_dict_into_torch_state_dict(model_state_dict_torch, model_state_dict_jax_path=None, model_state_dict_jax=None)#
Move the mLSTM jax model state dict into the torch model.
The jax model state dict is either loaded from model_state_dict_jax_path or taken from the provided model_state_dict_jax.
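- Parameters:
model_state_dict_torch – The torch state dict that the jax parameters are moved into.
model_state_dict_jax_path – Path from which the jax model state dict is loaded.
model_state_dict_jax – A jax model state dict to use directly instead of loading one.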
- Returns:
The torch model with the jax model state dict loaded.
- Return type:
mLSTM
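The move step described above can be pictured with a small, library-free sketch. It is an illustration under stated assumptions, not the implementation in this module: the helper names `flatten_params` and `move_into_target`, the dotted key convention, and the strict key check are all hypothetical, and the real function may additionally rename or reshape parameters (see apply_mlstm_param_reshapes).

```python
from typing import Any, Dict


def flatten_params(tree: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Flatten a nested parameter dict into dotted keys (hypothetical helper)."""
    flat: Dict[str, Any] = {}
    for key, value in tree.items():
        name = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            # Recurse into sub-modules and extend the dotted prefix.
            flat.update(flatten_params(value, name))
        else:
            flat[name] = value
    return flat


def move_into_target(target: Dict[str, Any], source_tree: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict with the target's keys and the source's values.

    Assumption: source and target must cover exactly the same parameter names.
    """
    source = flatten_params(source_tree)
    missing = set(target) - set(source)
    unexpected = set(source) - set(target)
    if missing or unexpected:
        raise KeyError(f"missing={sorted(missing)}, unexpected={sorted(unexpected)}")
    # Keep the target's key order so the result lines up with the target model.
    return {name: source[name] for name in target}
```

Failing loudly on missing or unexpected keys is a deliberate choice in this sketch: a silent partial load is usually harder to debug than an early error.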
- xlstm_jax.utils.model_param_handling.handle_mlstm_simple.pipeline_convert_mlstm_checkpoint_jax_to_torch_simple(jax_orbax_model_checkpoint, jax_model_config, torch_model_config_overrides=None)#
- xlstm_jax.utils.model_param_handling.handle_mlstm_simple.store_mlstm_simple_to_checkpoint(mlstm_model, store_torch_model_checkpoint_path, checkpoint_type='plain', max_shard_size=0)#
Stores an mLSTM simple model into a checkpoint directory, using either the huggingface or plain format.
- Parameters:
mlstm_model (mLSTM) – The mLSTM simple model.
store_torch_model_checkpoint_path (pathlib.Path) – Torch checkpoint path to store into.
checkpoint_type (Literal['plain', 'huggingface']) – Type of model checkpoint, either ‘plain’ or ‘huggingface’.
max_shard_size (int) – Largest size of a checkpoint model shard. Zero means no sharding.
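To make the max_shard_size semantics concrete, here is a minimal sketch of how parameters could be grouped into shards. It is not the library's sharding code: `plan_shards` is a hypothetical helper, sizes are assumed to be in bytes, and the greedy in-order grouping is one possible strategy among several.

```python
from typing import Dict, List


def plan_shards(param_sizes: Dict[str, int], max_shard_size: int = 0) -> List[List[str]]:
    """Group parameter names into shards no larger than max_shard_size.

    Zero means no sharding: everything goes into a single shard.
    A parameter larger than the limit is placed in a shard of its own.
    """
    if max_shard_size < 0:
        raise ValueError("max_shard_size must be non-negative")
    names = list(param_sizes)
    if max_shard_size == 0:
        return [names] if names else []

    shards: List[List[str]] = []
    current: List[str] = []
    current_size = 0
    for name in names:
        size = param_sizes[name]
        # Start a new shard when adding this parameter would exceed the limit.
        if current and current_size + size > max_shard_size:
            shards.append(current)
            current, current_size = [], 0
        current.append(name)
        current_size += size
    if current:
        shards.append(current)
    return shards
```

With three parameters of size 4 and a limit of 8, this sketch yields two shards, while a limit of 0 keeps all three together.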
- xlstm_jax.utils.model_param_handling.handle_mlstm_simple.convert_mlstm_checkpoint_jax_to_torch_simple(load_jax_model_checkpoint_path, store_torch_model_checkpoint_path, checkpoint_type='plain', max_shard_size=0)#
Convert a jax mLSTM checkpoint to a torch mLSTM checkpoint.
Loads the jax mLSTM checkpoint, creates a torch mLSTM model, and moves the jax checkpoint parameters into the torch model.
The checkpoint for the torch model is then saved to the store_torch_model_checkpoint_path.
The torch checkpoint is a directory containing the model params as .safetensors file(s) and a config.yaml file.
- Parameters:
load_jax_model_checkpoint_path (pathlib.Path) – Orbax checkpoint path.
store_torch_model_checkpoint_path (pathlib.Path) – Torch checkpoint path to store into.
checkpoint_type (Literal['plain', 'huggingface']) – Type of model checkpoint, either ‘plain’ or ‘huggingface’.
max_shard_size (int) – Largest size of a checkpoint model shard. Zero means no sharding.
- Return type:
None
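After a conversion, it can be useful to check that the output directory has the layout described above: one or more .safetensors files plus a config.yaml. The following is a small sketch using only the standard library. `inspect_torch_checkpoint` is a hypothetical helper, not part of this module, and it assumes the files sit directly in the checkpoint directory, which may not hold for every checkpoint_type.

```python
from pathlib import Path
from typing import Dict, Union


def inspect_torch_checkpoint(checkpoint_dir: Union[str, Path]) -> Dict[str, object]:
    """Report which expected files are present in a converted checkpoint directory."""
    checkpoint_dir = Path(checkpoint_dir)
    if not checkpoint_dir.is_dir():
        raise NotADirectoryError(str(checkpoint_dir))
    # Collect weight files; more than one is expected when sharding is enabled.
    weight_files = sorted(p.name for p in checkpoint_dir.glob("*.safetensors"))
    return {
        "has_config": (checkpoint_dir / "config.yaml").is_file(),
        "weight_files": weight_files,
    }
```

A typical flow would call convert_mlstm_checkpoint_jax_to_torch_simple with the two paths and then pass the same store_torch_model_checkpoint_path to this helper to confirm that both the config and at least one weight file were written.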