xlstm_jax.models.configs
Classes
ParallelConfig | Configuration for parallelism.
ModelConfig | Base class for model configurations.
SubModelConfig | Sub-model configuration.
Module Contents
- class xlstm_jax.models.configs.ParallelConfig
Configuration for parallelism.
- data_axis_size: int = -1
Size of the data axis. If -1, it is inferred from the number of available devices.
- fsdp_axis_size: int = 1
Size of the FSDP axis. If -1, it is inferred from the number of available devices.
- pipeline_axis_size: int = 1
Size of the pipeline axis. If -1, it is inferred from the number of available devices.
- model_axis_size: int = 1
Size of the model axis. If -1, it is inferred from the number of available devices.
- fsdp_gather_dtype: str | None = None
The dtype to cast the parameters to before gathering them with FSDP. If None, no casting is performed and parameters are gathered in their original precision (e.g. float32).
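The -1 convention for the axis sizes can be illustrated with a small helper that resolves a mesh shape. This is a minimal sketch under stated assumptions, not the library's logic: it assumes at most one axis is -1 and that this axis receives all devices left after dividing by the product of the fixed axis sizes. `ParallelSizes` and `resolve_axis_sizes` are hypothetical names, and `num_devices` is passed in explicitly so the sketch runs without JAX (in a JAX program it could come from `jax.device_count()`).

```python
import math
from dataclasses import dataclass
from typing import Dict


@dataclass
class ParallelSizes:
    """Hypothetical stand-in for the four axis-size fields of ParallelConfig."""

    data_axis_size: int = -1
    fsdp_axis_size: int = 1
    pipeline_axis_size: int = 1
    model_axis_size: int = 1


def resolve_axis_sizes(cfg: ParallelSizes, num_devices: int) -> Dict[str, int]:
    """Replace a single -1 axis size so that the product of all sizes equals num_devices."""
    sizes = {
        "data": cfg.data_axis_size,
        "fsdp": cfg.fsdp_axis_size,
        "pipeline": cfg.pipeline_axis_size,
        "model": cfg.model_axis_size,
    }
    inferred = [name for name, size in sizes.items() if size == -1]
    if len(inferred) > 1:
        raise ValueError(f"At most one axis may be -1 in this sketch, got {inferred}.")
    # Product of the axes whose size is fixed (1 if none are fixed).
    known = math.prod(size for size in sizes.values() if size != -1)
    if inferred:
        if num_devices % known != 0:
            raise ValueError(f"{num_devices} devices cannot be split over fixed axes of product {known}.")
        # Assumption: the -1 axis absorbs all devices not used by the fixed axes.
        sizes[inferred[0]] = num_devices // known
    if math.prod(sizes.values()) != num_devices:
        raise ValueError(f"Mesh shape {sizes} does not match {num_devices} devices.")
    return sizes


# 8 devices, FSDP over 4 of them, data axis left at its default of -1.
print(resolve_axis_sizes(ParallelSizes(fsdp_axis_size=4), num_devices=8))
# {'data': 2, 'fsdp': 4, 'pipeline': 1, 'model': 1}
```

The final product check also catches fully specified shapes that do not match the device count, for example `fsdp_axis_size=4` with `data_axis_size=4` on 8 devices.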
- class xlstm_jax.models.configs.ModelConfig
Bases: xlstm_jax.configs.ConfigDict
Base class for model configurations.
- model_class: callable
Model class.
- parallel: ParallelConfig
Parallelism configuration.
- model_config: xlstm_jax.configs.ConfigDict | None = None
Model-specific configuration, stored as a nested ConfigDict.
- static from_metadata(metadata_content)
Creates a model config from the content of a metadata file.
- Parameters:
metadata_content (str) – Content of the metadata file, currently in JSON format.
- Returns:
Tuple of the model_class and the model configuration parsed into a nested ModelConfig format.
- Return type:
- to_dict()
Converts the config to a dictionary.
Helpful for saving to disk or logging.
- static from_dict(config_class, data, strict_classname_parsing=False, ignore_extensive_attributes=True, none_to_zero_for_ints=False)
Utility for parsing dictionaries back into a nested dataclass structure, including arbitrary classes and types.
It is currently tailored to the logging system, whose to_dict output is hard to invert exactly.
- Parameters:
config_class (type) – Typically a dataclass, but can be any other type as well. If it is another type, the parser tries to create an object via config_class(**data) if data is a dictionary, or via config_class(data) otherwise.
data (Any) – Typically a dictionary that contains attributes of the dataclass. Can be any other kind of data.
strict_classname_parsing (bool) – Parse class names strictly.
ignore_extensive_attributes (bool) – Ignore attributes that are not defined in the dataclass.
none_to_zero_for_ints (bool) – Convert None to 0 for integer types.
- Returns:
An object of type config_class that contains the data as attributes.
- Return type:
Any
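The to_dict / from_dict pair can be sketched as a round trip over plain dataclasses. This is a minimal illustration, not the library code: `ParallelConfig` and `ModelConfig` below are trimmed hypothetical mirrors (the real classes derive from xlstm_jax.configs.ConfigDict), and `to_dict_sketch`, `from_dict_sketch`, and `ToyModel` are illustrative names. It assumes that classes are serialized as dotted-path strings, which is one plausible reason to_dict is hard to invert, and that the flags behave as their one-line descriptions suggest; strict_classname_parsing is omitted because its behavior is not specified here.

```python
import dataclasses
import json
from typing import Any, Optional


@dataclasses.dataclass
class ParallelConfig:
    # Hypothetical mirror of part of the documented ParallelConfig.
    data_axis_size: int = -1
    fsdp_axis_size: int = 1
    fsdp_gather_dtype: Optional[str] = None


@dataclasses.dataclass
class ModelConfig:
    # Hypothetical mirror; not the real xlstm_jax ModelConfig.
    model_class: Any
    parallel: ParallelConfig
    num_blocks: int = 1


def to_dict_sketch(obj: Any) -> Any:
    """Recursively turn dataclass instances into dicts; callables become dotted-path strings."""
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        return {f.name: to_dict_sketch(getattr(obj, f.name)) for f in dataclasses.fields(obj)}
    if isinstance(obj, dict):
        return {k: to_dict_sketch(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_dict_sketch(v) for v in obj]  # tuples come back as lists
    if callable(obj):
        return f"{obj.__module__}.{obj.__qualname__}"  # the class object itself is lost
    return obj


def from_dict_sketch(config_class: Any, data: Any, ignore_extensive_attributes: bool = True,
                     none_to_zero_for_ints: bool = False) -> Any:
    """Rebuild config_class from data, recursing into dataclass fields by their declared type."""
    if dataclasses.is_dataclass(config_class) and isinstance(data, dict):
        fields = dataclasses.fields(config_class)
        extra = set(data) - {f.name for f in fields}
        if extra and not ignore_extensive_attributes:
            raise ValueError(f"Unknown attributes for {config_class.__name__}: {sorted(extra)}")
        kwargs = {
            f.name: from_dict_sketch(f.type, data[f.name], ignore_extensive_attributes, none_to_zero_for_ints)
            for f in fields
            if f.name in data
        }
        return config_class(**kwargs)
    if config_class is int and data is None and none_to_zero_for_ints:
        return 0
    if config_class is Any or data is None or not isinstance(config_class, type):
        # Any, None values, and typing constructs such as Optional[str] pass through unchanged.
        return data
    if isinstance(data, config_class):
        return data
    # Documented fallback: config_class(**data) for dicts, config_class(data) otherwise.
    return config_class(**data) if isinstance(data, dict) else config_class(data)


class ToyModel:  # hypothetical model class
    pass


cfg = ModelConfig(model_class=ToyModel, parallel=ParallelConfig(fsdp_axis_size=4))
serialized = json.dumps(to_dict_sketch(cfg))
restored = from_dict_sketch(ModelConfig, json.loads(serialized))
print(restored.parallel == cfg.parallel)  # True
print(type(restored.model_class).__name__)  # str
```

In this sketch the nested ParallelConfig survives the round trip exactly, while model_class comes back as a string because its field is typed Any; turning it back into a class would need an extra import step.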
- class xlstm_jax.models.configs.SubModelConfig
Sub-model configuration.
This class is currently a quick fix to allow post-init style model configs, such as those of xlstm-clean, which was ported from the original xlstm codebase. Once the config system is more mature, this class should be removed and all model configs should become subclasses of ModelConfig.
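One reading of "post-init style" is a dataclass whose derived fields are computed in `__post_init__` instead of being passed by the caller. The sketch below assumes that reading; `BlockConfig` and its fields are hypothetical and are not taken from xlstm-clean.

```python
from dataclasses import dataclass, field


@dataclass
class BlockConfig:
    """Hypothetical post-init style config; head_dim is derived, not passed in."""

    embedding_dim: int = 128
    num_heads: int = 4
    head_dim: int = field(init=False)  # excluded from __init__, filled in below

    def __post_init__(self):
        # Validate user-set values, then compute the derived field.
        if self.embedding_dim % self.num_heads != 0:
            raise ValueError("embedding_dim must be divisible by num_heads.")
        self.head_dim = self.embedding_dim // self.num_heads


cfg = BlockConfig(embedding_dim=256, num_heads=8)
print(cfg.head_dim)  # 32
```

Because head_dim is not an `__init__` argument in this sketch, calling `BlockConfig(**data)` on a dictionary that still contains head_dim raises a TypeError.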