xlstm_jax.utils.model_param_handling.handle_mlstm_simple#

Attributes#

LOGGER

Functions#

create_mlstm_simple_config_from_jax_config(...[, ...])

apply_mlstm_param_reshapes(state_dict)

move_mlstm_jax_state_dict_into_torch_state_dict(...[, ...])

Move the mLSTM jax model state dict into the torch model.

pipeline_convert_mlstm_checkpoint_jax_to_torch_simple(...)

store_mlstm_simple_to_checkpoint(mlstm_model, ...[, ...])

Stores an mLSTM simple model into a checkpoint directory, using either the huggingface or plain format.

convert_mlstm_checkpoint_jax_to_torch_simple(...[, ...])

Convert a jax mLSTM checkpoint to a torch mLSTM checkpoint.

Module Contents#

xlstm_jax.utils.model_param_handling.handle_mlstm_simple.LOGGER#
xlstm_jax.utils.model_param_handling.handle_mlstm_simple.create_mlstm_simple_config_from_jax_config(model_config_jax, overrides=None)#
Parameters:
  • model_config_jax

  • overrides – Defaults to None.

Return type:

mlstm_simple_torch.mlstm_simple.model.mLSTMConfig

xlstm_jax.utils.model_param_handling.handle_mlstm_simple.apply_mlstm_param_reshapes(state_dict)#
Parameters:

state_dict (dict[str, Any])

Return type:

dict[str, Any]

xlstm_jax.utils.model_param_handling.handle_mlstm_simple.move_mlstm_jax_state_dict_into_torch_state_dict(model_state_dict_torch, model_state_dict_jax_path=None, model_state_dict_jax=None)#

Move the mLSTM jax model state dict into the torch model.

Either loads the jax model state dict from the model_state_dict_jax_path or uses the provided model_state_dict_jax.

Parameters:
  • model_state_dict_torch (dict[str, Any]) – The torch model state dict.

  • model_state_dict_jax_path (Path) – The path to the jax model state dict. Defaults to None.

  • model_state_dict_jax (dict[str, Any]) – The jax model state dict. Defaults to None.

Returns:

The torch model with the jax model state dict loaded.

Return type:

mLSTM
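The "either a path or a provided state dict" behaviour described above can be sketched as follows. This is a minimal illustration, not the library's code: `resolve_jax_state_dict` and `load_fn` are hypothetical names, and the choice to raise when both or neither source is given is an assumption, since this page does not state how the real function handles those cases.

```python
from pathlib import Path
from typing import Any, Callable, Optional


def resolve_jax_state_dict(
    model_state_dict_jax_path: Optional[Path] = None,
    model_state_dict_jax: Optional[dict[str, Any]] = None,
    load_fn: Optional[Callable[[Path], dict[str, Any]]] = None,
) -> dict[str, Any]:
    """Return the jax state dict from exactly one of the two sources.

    Hypothetical helper: `load_fn` stands in for whatever checkpoint
    loader is used to read the state dict from disk.
    """
    # Assumption: exactly one of the two sources must be provided.
    if (model_state_dict_jax_path is None) == (model_state_dict_jax is None):
        raise ValueError(
            "Provide exactly one of model_state_dict_jax_path or model_state_dict_jax."
        )
    # Use the in-memory state dict directly when it was passed in.
    if model_state_dict_jax is not None:
        return model_state_dict_jax
    # Otherwise load it from the given path with the supplied loader.
    if load_fn is None:
        raise ValueError("load_fn is required when loading from a path.")
    return load_fn(model_state_dict_jax_path)
```

Passing the already-loaded dict avoids a second read from disk when the checkpoint has been restored earlier in a pipeline; passing the path keeps the call site short when it has not.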

xlstm_jax.utils.model_param_handling.handle_mlstm_simple.pipeline_convert_mlstm_checkpoint_jax_to_torch_simple(jax_orbax_model_checkpoint, jax_model_config, torch_model_config_overrides=None)#
Parameters:
  • jax_orbax_model_checkpoint (dict[str, Any])

  • jax_model_config (dict[str, Any])

  • torch_model_config_overrides (dict[str, Any])

Return type:

mlstm_simple_torch.mlstm_simple.model.mLSTM

xlstm_jax.utils.model_param_handling.handle_mlstm_simple.store_mlstm_simple_to_checkpoint(mlstm_model, store_torch_model_checkpoint_path, checkpoint_type='plain', max_shard_size=0)#

Stores an mLSTM simple model into a checkpoint directory, using either the huggingface or plain format.

Parameters:
  • mlstm_model (mLSTM) – The mLSTM simple model.

  • store_torch_model_checkpoint_path (pathlib.Path) – Torch checkpoint path to store into.

  • checkpoint_type (Literal['plain', 'huggingface']) – Type of model checkpoint, either ‘plain’ or ‘huggingface’.

  • max_shard_size (int) – Largest size of a checkpoint model shard. Zero means no sharding.
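To make the max_shard_size semantics concrete, here is a minimal sketch of how a size limit with "zero means no sharding" could be applied when grouping parameters into shard files. The function `plan_shards`, the greedy grouping strategy, and the size unit are all assumptions for illustration; this page does not specify the unit of max_shard_size or how the library actually splits parameters.

```python
def plan_shards(param_sizes: dict[str, int], max_shard_size: int = 0) -> list[list[str]]:
    """Group parameter names into shards no larger than max_shard_size.

    Hypothetical helper: sizes are in an arbitrary unit, and a value of 0
    puts every parameter into a single shard (no sharding).
    """
    names = list(param_sizes)
    if max_shard_size <= 0:
        return [names]

    shards: list[list[str]] = []
    current: list[str] = []
    current_size = 0
    for name in names:
        size = param_sizes[name]
        # Start a new shard when adding this parameter would exceed the limit.
        # A single parameter larger than the limit still gets its own shard.
        if current and current_size + size > max_shard_size:
            shards.append(current)
            current, current_size = [], 0
        current.append(name)
        current_size += size
    if current:
        shards.append(current)
    return shards
```

With three parameters of size 4 each, a limit of 0 yields one shard holding all three, while a limit of 8 yields two shards: the first two parameters together, then the third.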

xlstm_jax.utils.model_param_handling.handle_mlstm_simple.convert_mlstm_checkpoint_jax_to_torch_simple(load_jax_model_checkpoint_path, store_torch_model_checkpoint_path, checkpoint_type='plain', max_shard_size=0)#

Convert a jax mLSTM checkpoint to a torch mLSTM checkpoint.

Loads the jax mLSTM checkpoint, creates a torch mLSTM model, and moves the jax checkpoint parameters into the torch model.

The checkpoint for the torch model is then saved to the store_torch_model_checkpoint_path.

The torch checkpoint is a directory containing the model params as .safetensors file(s) and a config.yaml file.

Parameters:
  • load_jax_model_checkpoint_path (pathlib.Path) – Orbax checkpoint path.

  • store_torch_model_checkpoint_path (pathlib.Path) – Torch checkpoint path to store into.

  • checkpoint_type (Literal['plain', 'huggingface']) – Type of model checkpoint, either ‘plain’ or ‘huggingface’.

  • max_shard_size (int) – Largest size of a checkpoint model shard. Zero means no sharding.

Return type:

None
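A usage sketch for the conversion, assuming xlstm_jax is installed. The call uses only the signature documented above. `convert_checkpoint` and `check_torch_checkpoint_dir` are hypothetical helpers added for illustration: the second one checks the output layout described above (.safetensors file(s) plus a config.yaml). That layout is assumed here to apply to the 'plain' format; the 'huggingface' format may differ.

```python
from pathlib import Path


def convert_checkpoint(jax_path: Path, torch_path: Path) -> None:
    """Call the converter with its documented signature (needs xlstm_jax installed)."""
    # Imported lazily so this sketch can be loaded without the library present.
    from xlstm_jax.utils.model_param_handling.handle_mlstm_simple import (
        convert_mlstm_checkpoint_jax_to_torch_simple,
    )

    convert_mlstm_checkpoint_jax_to_torch_simple(
        load_jax_model_checkpoint_path=jax_path,
        store_torch_model_checkpoint_path=torch_path,
        checkpoint_type="plain",
        max_shard_size=0,  # zero means no sharding
    )


def check_torch_checkpoint_dir(torch_path: Path) -> list[str]:
    """Return the sorted .safetensors file names found in the checkpoint directory.

    Hypothetical helper: raises FileNotFoundError if the directory, the
    config.yaml file, or the .safetensors file(s) are missing.
    """
    if not torch_path.is_dir():
        raise FileNotFoundError(f"Not a directory: {torch_path}")
    if not (torch_path / "config.yaml").is_file():
        raise FileNotFoundError(f"Missing config.yaml in {torch_path}")
    shards = sorted(p.name for p in torch_path.glob("*.safetensors"))
    if not shards:
        raise FileNotFoundError(f"No .safetensors files in {torch_path}")
    return shards
```

Running `check_torch_checkpoint_dir` right after `convert_checkpoint` gives an early failure if the output directory is incomplete, before the checkpoint is handed to a loader.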