xlstm_jax.models.configs#

Classes#

ParallelConfig

Configuration for parallelism.

ModelConfig

Base class for model configurations.

SubModelConfig

Sub-model configuration.

Module Contents#

class xlstm_jax.models.configs.ParallelConfig#

Configuration for parallelism.

data_axis_size: int = -1#

Size of the data axis. If -1, it is inferred from the number of available devices.

fsdp_axis_size: int = 1#

Size of the FSDP axis. If -1, it is inferred from the number of available devices.

pipeline_axis_size: int = 1#

Size of the pipeline axis. If -1, it is inferred from the number of available devices.

model_axis_size: int = 1#

Size of the model axis. If -1, it is inferred from the number of available devices.
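As an illustration of how a -1 axis size could be resolved, the sketch below divides the device count by the product of the fixed axis sizes. The helper resolve_axis_sizes is hypothetical and not part of xlstm_jax, the 8-device count is an assumed example, and the library's actual inference logic may differ. The dictionary keys are the default axis names documented for this class.

```python
import math


def resolve_axis_sizes(axis_sizes: dict[str, int], num_devices: int) -> dict[str, int]:
    """Hypothetical helper: replace a single -1 entry with the remaining device count."""
    unknown = [name for name, size in axis_sizes.items() if size == -1]
    if len(unknown) > 1:
        raise ValueError(f"At most one axis size may be -1, got {unknown}")
    # Product of all explicitly specified axis sizes.
    fixed = math.prod(size for size in axis_sizes.values() if size != -1)
    if num_devices % fixed != 0:
        raise ValueError(f"{num_devices} devices are not divisible by {fixed}")
    resolved = dict(axis_sizes)
    if unknown:
        resolved[unknown[0]] = num_devices // fixed
    return resolved


# Default ParallelConfig sizes, resolved for an assumed 8-device host.
sizes = resolve_axis_sizes({"dp": -1, "fsdp": 1, "pp": 1, "tp": 1}, num_devices=8)
print(sizes)
```

With the defaults, all devices end up on the data axis, which corresponds to plain data parallelism.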

data_axis_name: str = 'dp'#

Name of the data axis.

fsdp_axis_name: str = 'fsdp'#

Name of the FSDP axis.

pipeline_axis_name: str = 'pp'#

Name of the pipeline axis.

model_axis_name: str = 'tp'#

Name of the model axis.

remat: list[str] = []#

Module names on which we apply activation checkpointing / rematerialization.

fsdp_modules: list[str] = []#

Module names on which we apply FSDP sharding.

fsdp_min_weight_size: int = 262144#

Minimum size of a parameter to be sharded with FSDP.
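The following sketch shows one way such a threshold could be applied. It assumes that "size" means the number of elements in the parameter and that the threshold is inclusive; neither detail is confirmed here, and should_shard is a hypothetical helper rather than an xlstm_jax function.

```python
import math

# Default value of ParallelConfig.fsdp_min_weight_size (2**18).
FSDP_MIN_WEIGHT_SIZE = 262144


def should_shard(shape: tuple[int, ...], min_weight_size: int = FSDP_MIN_WEIGHT_SIZE) -> bool:
    """Hypothetical check: shard only parameters with at least `min_weight_size` elements."""
    return math.prod(shape) >= min_weight_size


print(should_shard((512, 512)))  # 512 * 512 = 262144 elements, meets the threshold
print(should_shard((512,)))      # small bias-like vector, below the threshold
```

Under these assumptions, small parameters such as biases or normalization scales stay replicated, which avoids communication whose cost would outweigh the memory saved.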

fsdp_gather_dtype: str | None = None#

The dtype to cast the parameters to before gathering with FSDP. If None, no casting is performed and parameters are gathered in their original precision (e.g. float32).

fsdp_grad_scatter_dtype: str | None = None#

The dtype to cast the gradients to before scattering. If None, the dtype of the parameters is used.

tp_async_dense: bool = False#

Whether to use asynchronous tensor parallelism for dense layers. Defaults to False because, on local hardware, ppermute communication introduces a large overhead.

class xlstm_jax.models.configs.ModelConfig#

Bases: xlstm_jax.configs.ConfigDict

Base class for model configurations.

model_class: callable#

Model class.

parallel: ParallelConfig#

Parallelism configuration.

model_config: xlstm_jax.configs.ConfigDict | None = None#

Model configuration.

static from_metadata(metadata_content)#

Creates a model config from a metadata file content.

Parameters:

metadata_content (str) – Content of the metadata file, currently in JSON format.

Returns:

Tuple of the model_class and the model configuration parsed into a nested ModelConfig format.

Return type:

ModelConfig

get(key, default=None)#

Parameters:

key (str)

to_dict()#

Converts the config to a dictionary.

Helpful for saving to disk or logging.

static from_dict(config_class, data, strict_classname_parsing=False, ignore_extensive_attributes=True, none_to_zero_for_ints=False)#

Utility for parsing dictionaries back into a nested dataclass structure, including arbitrary classes and types.

Currently, this is tailored to the existing logging system, whose to_dict output is hard to invert.

Parameters:
  • config_class (type) – Typically a dataclass, but can be any other type as well. If it is another type, the parser tries to create an object via config_class(**data) if data is a dictionary, or via config_class(data) otherwise.

  • data (Any) – Typically a dictionary that contains attributes of the dataclass. Can be any other kind of data.

  • strict_classname_parsing (bool) – Parse class names strictly.

  • ignore_extensive_attributes (bool) – Ignore attributes that are not defined in the dataclass.

  • none_to_zero_for_ints (bool) – Convert None to 0 for integer types.

Returns:

An object of type config_class that contains the data as attributes.

Return type:

Any
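The behaviour described above can be sketched with plain dataclasses. This is a simplified illustration, not the xlstm_jax implementation: InnerConfig, OuterConfig, and parse are hypothetical stand-ins, class-name parsing is omitted, and only the ignore_extensive_attributes and none_to_zero_for_ints options are mimicked.

```python
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Any


@dataclass
class InnerConfig:  # stand-in class, not part of xlstm_jax
    hidden_dim: int = 0


@dataclass
class OuterConfig:  # stand-in class, not part of xlstm_jax
    name: str = ""
    inner: InnerConfig = field(default_factory=InnerConfig)


def parse(config_class: type, data: Any, ignore_extra: bool = True,
          none_to_zero_for_ints: bool = False) -> Any:
    """Hypothetical parser: rebuild a nested dataclass from a plain dictionary."""
    if not is_dataclass(config_class):
        # Non-dataclass types are constructed directly from the data.
        return config_class(**data) if isinstance(data, dict) else config_class(data)
    known = {f.name: f.type for f in fields(config_class)}
    kwargs = {}
    for key, value in data.items():
        if key not in known:
            if ignore_extra:
                continue  # drop attributes the dataclass does not define
            raise KeyError(f"Unknown attribute {key!r} for {config_class.__name__}")
        field_type = known[key]
        if is_dataclass(field_type) and isinstance(value, dict):
            # Recurse into nested dataclass fields.
            value = parse(field_type, value, ignore_extra, none_to_zero_for_ints)
        elif field_type is int and value is None and none_to_zero_for_ints:
            value = 0
        kwargs[key] = value
    return config_class(**kwargs)


cfg = parse(OuterConfig, {"name": "demo", "inner": {"hidden_dim": 64}, "unused": 1})
print(cfg)
```

The unknown key "unused" is silently dropped here because ignore_extra defaults to True, mirroring the default of ignore_extensive_attributes in the signature above.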

class xlstm_jax.models.configs.SubModelConfig#

Sub-model configuration.

This class is currently a quick fix to allow for post-init style model configs, like the xlstm-clean config we ported from the original xlstm codebase. Once the config system is more mature, we should remove this class and make everything a subclass of ModelConfig.

to_dict()#

Converts the config to a dictionary.

Helpful for saving to disk or logging.

Return type:

dict