Configuring Experiments with Hydra#
This project uses Hydra to manage configurations. Hydra is a framework that simplifies configuring complex applications by letting you compose configurations dynamically. If you have never worked with Hydra, it might seem daunting at first, so you may want to go through the Hydra tutorial before continuing.
We use structured configs and Python dataclasses as the basis for our configuration. There is an additional Hydra tutorial for structured configs, which you should go through if you have never worked with them.
We still give a very short introduction here so that you can quickly get started running your own experiments.
Configuration Structure#
The Config Dataclasses#
The first thing to be aware of is that we use Python dataclasses as foundational configuration objects.
Let us look at the example of the ParallelConfig, located at xlstm_jax.models.configs.ParallelConfig:
@dataclass(kw_only=True, frozen=False)
class ParallelConfig:
    """Configuration for parallelism."""

    data_axis_size: int = -1
    """Size of the data axis. If -1, it will be inferred by the number of available devices."""
    fsdp_axis_size: int = 1
    """Size of the FSDP axis. If -1, it will be inferred by the number of available devices."""
    pipeline_axis_size: int = 1
    """Size of the pipeline axis. If -1, it will be inferred by the number of available devices."""
    model_axis_size: int = 1
    """Size of the model axis. If -1, it will be inferred by the number of available devices."""
    data_axis_name: str = "dp"
    """Name of the data axis."""
    fsdp_axis_name: str = "fsdp"
    """Name of the FSDP axis."""
    pipeline_axis_name: str = "pp"
    """Name of the pipeline axis."""
    model_axis_name: str = "tp"
    """Name of the model axis."""
    remat: list[str] = field(default_factory=lambda: [])
    """Module names on which we apply activation checkpointing / rematerialization."""
    fsdp_modules: list[str] = field(default_factory=lambda: [])
    """Module names on which we apply FSDP sharding."""
    fsdp_min_weight_size: int = 2**18
    """Minimum size of a parameter to be sharded with FSDP."""
    fsdp_gather_dtype: str | None = None
    """The dtype to cast the parameters to before gathering with FSDP. If `None`, no casting is performed and parameters
    are gathered in original precision (for example `float32`)."""
    fsdp_grad_scatter_dtype: str | None = None
    """The dtype to cast the gradients to before scattering. If `None`, the dtype of the parameters is used."""
    tp_async_dense: bool = False
    """Whether to use asynchronous tensor parallelism for dense layers. Default to `False`, as on local hardware,
    ppermute communication introduces large overhead."""

    def __post_init__(self):
        _allowed_fsdp_dtypes = ["float32", "bfloat16", "float16"]
        if self.fsdp_gather_dtype is not None:
            assert self.fsdp_gather_dtype in _allowed_fsdp_dtypes
        if self.fsdp_grad_scatter_dtype is not None:
            assert self.fsdp_grad_scatter_dtype in _allowed_fsdp_dtypes
Here, all attributes and default values of ParallelConfig can be seen. To tell Hydra which
attributes are available in the dataclass we have to add it to Hydra’s config store. This is done
in xlstm_jax/define_hydra_schemas.py and looks like this for the ParallelConfig:
cs.store(name="parallel_schema", group="parallel", node=ParallelConfig)
This tells Hydra that whenever the string "parallel_schema" is used in a configuration yaml file,
ParallelConfig serves as the basis, and all attributes and default values are taken from that class.
The Config yaml files#
To overwrite the default values from the dataclasses, Hydra
uses yaml files, which are
organized in the configs directory.
For now, the structure is as follows:
- configs/: Contains all configuration files.
  - config.yaml: The main configuration file. This is the most important file since the defaults are specified here.
  - The submodules are in subfolders:
    - checkpointing
    - data
    - hydra
    - logger
    - lr_monitor
    - model
    - optimizer
    - parallel
    - profiling
    - scheduler
    - trainer
  - Experiment files are located in the subfolder experiment.
How to Run Experiments#
The principal entry point for Hydra is the script scripts/training/train_with_hydra.py.
The main_train() function in that file is decorated with:
@hydra.main(config_path="../configs", config_name="config", version_base="1.3")
def main_train(cfg: DictConfig):
    ...
The decorator invokes Hydra and the following happens: Hydra looks in the folder ../configs for a file named config.yaml. If it finds the file,
Hydra starts to compile the configuration from the defaults supplied in configs/config.yaml.
For example, if that file looks like this:
defaults:
  - config_schema
  - parallel: synthetic
  - model: mLSTM120M
  - _self_

# General hyperparameters. Will be put in their respective config modules once they are created.
# device. cpu or gpu
device: cpu
the following will happen:
- config_schema is the first entry. It is related to the use of structured configs and tells Hydra which values and types are allowed to be set in this file. You can check define_hydra_schemas.py to see the definition for config_schema.
- parallel: synthetic means that Hydra looks in the parallel subfolder for the synthetic.yaml file and adds all parameters in that file to the compiled configuration under the key parallel.
- model: mLSTM120M means that Hydra looks in the model subfolder for the mLSTM120M.yaml file and adds all parameters in that file to the compiled configuration under the key model.
- _self_ means that all parameters in this file itself (that is, the config.yaml) are inserted into the compiled config.
Note that values from entries that come later in the defaults list overwrite values that earlier entries have already set.
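The "later entries win" rule can be sketched with plain dictionaries. This is not Hydra's implementation, only an illustration of the merge order; `deep_merge`, `compose`, and all values below are hypothetical.

```python
from copy import deepcopy


def deep_merge(base: dict, update: dict) -> dict:
    """Return a copy of `base` in which values from `update` take precedence."""
    merged = deepcopy(base)
    for key, value in update.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Merge nested sections key by key instead of replacing them whole.
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = deepcopy(value)
    return merged


def compose(layers: list[dict]) -> dict:
    """Merge the layers in order, so later layers overwrite earlier ones."""
    cfg: dict = {}
    for layer in layers:
        cfg = deep_merge(cfg, layer)
    return cfg


# Hypothetical values standing in for the entries of a defaults list.
schema_defaults = {"device": "gpu", "parallel": {"data_axis_size": -1, "fsdp_axis_size": 1}}
parallel_synthetic = {"parallel": {"fsdp_axis_size": 2}}
self_values = {"device": "cpu"}  # plays the role of _self_, listed last

cfg = compose([schema_defaults, parallel_synthetic, self_values])
print(cfg)
```

Since `self_values` comes last, its `device` entry replaces the earlier one, while `parallel.data_axis_size` keeps the value from the first layer because no later layer sets it.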
Once Hydra has processed this file, it changes the current working directory as specified in the configs
and executes
main_train(cfg) in that directory, where cfg is the compiled configuration.
All top-level parameters are available directly in cfg (for example, cfg.device), and the
others are in their respective fields (for example, cfg.parallel.data_axis_name).
Remark: Every parameter can be overwritten from the command line. For example, you can execute
PYTHONPATH=. python scripts/training/train_with_hydra.py device=gpu model.num_heads=42
to replace cpu
from the config.yaml with gpu and to set num_heads to 42 instead of the value given in the mLSTM120M.yaml file. It is more convenient, though, to use experiment files for overwriting
parameters.
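How dotted command-line overrides map onto the nested config can be sketched as follows. This is a simplified illustration and not Hydra's override parser, which supports a much richer grammar; `apply_overrides`, `parse_value`, and the starting values are hypothetical.

```python
from copy import deepcopy


def parse_value(raw: str):
    """Convert the text after '=' to int or float where possible, else keep the string."""
    for convert in (int, float):
        try:
            return convert(raw)
        except ValueError:
            continue
    return raw


def apply_overrides(cfg: dict, overrides: list[str]) -> dict:
    """Return a copy of `cfg` with overrides of the form 'a.b.c=value' applied."""
    result = deepcopy(cfg)
    for override in overrides:
        dotted_key, _, raw_value = override.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = result
        for name in parents:
            node = node[name]  # raises KeyError if the path does not exist
        node[leaf] = parse_value(raw_value)
    return result


# Hypothetical starting configuration.
base_cfg = {"device": "cpu", "model": {"num_heads": 4}}
new_cfg = apply_overrides(base_cfg, ["device=gpu", "model.num_heads=42"])
print(new_cfg)
```

Each dot in the key descends one level in the config hierarchy, which is why `model.num_heads=42` ends up in the `model` section while `device=gpu` changes a top-level value.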
Using Experiment Files#
(see https://hydra.cc/docs/patterns/configuring_experiments/)
A more reproducible way to start experiments with Hydra is to use experiment files. In these files
you can specify all parameters and config groups that you want to change. Let us take the experiment/synthetic_experiment.yaml as an example.
The top of that file reads
# @package _global_
defaults:
  - /data@data_train.ds1: synthetic
  - override /parallel: synthetic
  - override /model: mLSTMv1_165M
  - _self_

# specify the deltas from the defaults:
task_name: slurm_tests
batch_size_per_device: 2
context_length: 128
num_train_steps: 10
lr: 1e-3
logger:
  log_every_n_steps: 2
Hydra now checks the defaults list of the experiment file and replaces the respective fields from the general
config.yaml with the new values given here.
That is, no matter what was provided as parallel config in the config.yaml,
we override that with the synthetic configuration instead.
Similarly, no matter what was supplied as model in the
config.yaml, we overwrite that with the mLSTMv1_165M.
In addition, you can
overwrite any parameter with your own values. This goes for all parameters at
every level of the config hierarchy. In this example, the value in
cfg.logger.log_every_n_steps is overwritten with 2.
By using
experiment files, it becomes easy to specify and reproduce experiments.
To run this experiment, you execute
PYTHONPATH=. python scripts/training/train_with_hydra.py +experiment=synthetic_experiment
Note the + before experiment; it is not a typo! Hydra uses the + prefix to add a config group that is not yet part of the defaults list.
Type Checking of the Configurations#
One benefit of using the structured configs and dataclasses as basis for the configuration is that type checking
becomes available. This happens when compiling the cfg.
To test whether you have supplied correctly typed parameters in your experiment files, you can execute
python scripts/check_config.py +experiment={YOUR_EXPERIMENT_FILE}
This script only compiles your config and logs the result. It is a good way to check whether you have supplied a working experiment file before starting a job on a cluster, for example.
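The idea behind the type check can be sketched with the standard library alone. The real validation is performed by Hydra and its structured configs and may differ in details (for example, in how compatible values are converted); `TinySchema` and `find_type_errors` are hypothetical and only handle plain types such as int and str.

```python
from dataclasses import dataclass
from typing import get_type_hints


@dataclass
class TinySchema:
    """Hypothetical schema with two typed fields (illustration only)."""

    num_train_steps: int = 10
    task_name: str = "demo"


def find_type_errors(schema_cls, values: dict) -> list[str]:
    """List keys in `values` that are unknown or whose value has the wrong type."""
    expected = get_type_hints(schema_cls)
    errors = []
    for key, value in values.items():
        if key not in expected:
            errors.append(f"{key}: unknown key")
        elif not isinstance(value, expected[key]):
            wanted = getattr(expected[key], "__name__", str(expected[key]))
            errors.append(f"{key}: expected {wanted}, got {type(value).__name__}")
    return errors


print(find_type_errors(TinySchema, {"num_train_steps": 100}))
print(find_type_errors(TinySchema, {"num_train_steps": "many", "lr": 0.1}))
```

Running such a check before submitting a job surfaces misspelled keys and wrongly typed values early, which is the purpose of the check_config.py step described above.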
How to run experiments on a SLURM cluster#
To run your script on a SLURM cluster you can use the submitit_launcher, which should have been installed
for Hydra if you have used our conda env.
The
default configuration is provided in configs/hydra/launcher/slurm_launcher.yaml:
# @package _global_
defaults:
  - override /hydra/launcher: submitit_slurm

hydra:
  launcher:
    submitit_folder: ${hydra.sweep.dir}/.submitit/%j
    timeout_min: 60
    cpus_per_task: 28
    gpus_per_task: 1 # each GPU must have its separate task for jax
    gres: gpu:${n_gpus}
    tasks_per_node: ${n_gpus}
    mem_gb: ~
    nodes: 1
    name: ${hydra.job.name}
    _target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher
    partition: compute
    qos: null
    comment: testing_slurmit_launcher
    additional_parameters: {
      "gpu-bind": "closest",
      "wait-all-nodes": "1",
      "time": "7-00:00:00",
      "exclusive": "",
    }
    srun_args:
      - "--kill-on-bad-exit=1"
    setup:
      - export NCCL_CROSS_NIC=0
      - export NCCL_SOCKET_NTHREADS=16
      - export NCCL_DEBUG=WARN
      - export NCCL_CUMEM_ENABLE=0
      - export NCCL_IB_SPLIT_DATA_ON_QPS=0
      - export NCCL_IB_QPS_PER_CONNECTION=16
      - export NCCL_IB_GID_INDEX=3
      - export NCCL_IB_TC=41
      - export NCCL_IB_SL=0
      - export NCCL_IB_TIMEOUT=22
      - export NCCL_NET_PLUGIN=none
      - export NCCL_SOCKET_IFNAME=eth0
      - export NCCL_IGNORE_CPU_AFFINITY=1
      - export NCCL_IB_HCA="=mlx5_0,mlx5_1,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9,mlx5_10,mlx5_12,mlx5_13,mlx5_14,mlx5_15,mlx5_16,mlx5_17"
      - export HCOLL_ENABLE_MCAST_ALL=0
      - export coll_hcoll_enable=0
      - export UCX_TLS=tcp
      - export UCX_NET_DEVICES=eth0
      - export RX_QUEUE_LEN=8192
      - export IB_RX_QUEUE_LEN=8192
      - export OMPI_MCA_coll=^hcoll
      - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/nfs-gpu/xlstm/miniforge3/envs/python_3.11_jax_0.4.34_cuda_12.6/cuda-compat
      - export TOKENIZERS_PARALLELISM=0
      - export JAX_COMPILATION_CACHE_DIR=/nfs-gpu/xlstm/data/jax_compilation_cache
Note that the path to the cuda-compat package and the
JAX_COMPILATION_CACHE_DIR must be adjusted
by you!
Here, you can supply (or rather overwrite in an experiment file!) the things you would normally put into a SLURM run
script. To use the submitit_launcher, execute the following with the conda env
activated that you want to use for the experiment:
PYTHONPATH=. python scripts/training/train_with_hydra.py --multirun hydra/launcher=slurm_launcher +experiment={YOUR_EXPERIMENT_FILE}
So the only thing you have to add is --multirun hydra/launcher=slurm_launcher (and to overwrite SLURM-specific parameters, such as the number of nodes, in your experiment file).
Note that you can also add this to your experiment file. See,
for example, the experiment synthetic_experiment_slurm.yaml,
where the lines
- override /hydra/launcher: slurm_launcher
hydra:
  mode: MULTIRUN
will make Hydra use the submitit_launcher.
How to Resume an Experiment?#
To resume an experiment you need to know the path to the output folder of the experiment. You then execute
python scripts/training/get_cli_command_to_resume_training.py --resume_from_folder=PATH_TO_RUN
If you want to use SLURM, set the flag --use_slurm.
Note that this is only required if
the original run was executed with SLURM by way of the CLI override
--multirun hydra/launcher=slurm_launcher and
not by way of an experiment file. If it was specified in the experiment file, SLURM is used anyway.
If you want to use the latest checkpoint, you don't need to supply anything, but if you want to use a
specific checkpoint, use --checkpoint_step=X to use checkpoint X.
New Hydra overrides can be supplied by way of --new_overrides=STRING_OF_OVERRIDES.
Example of such a string for more training steps, a different learning rate and
a different logging frequency would be
--new_overrides="num_train_steps=20000 lr=0.0001 logger.log_every_n_steps=10", that is,
the format is exactly as you would supply for CLI overrides.
Executing get_cli_command_to_resume_training will return a string of the command that you
have to execute to continue the training run.
A full example to call is
python scripts/training/get_cli_command_to_resume_training.py --resume_from_folder=PATH_TO_RUN --use_slurm --checkpoint_step=95000 --new_overrides="num_train_steps=20000 lr=0.0001 logger.log_every_n_steps=10"
What the script does is to look in the specified folder and obtain the overrides (including the experiment file) that were used to start the previous run. It then compiles a new command from these, which, for the current example would look something like this:
PYTHONPATH=. python scripts/training/resume_training_with_hydra.py -m hydra/launcher=slurm_launcher +experiment=synthetic_experiment_slurm +resume_from_folder=PATH_TO_RUN +checkpoint_step=95000 num_train_steps=20000 lr=0.0001 logger.log_every_n_steps=10
This is the actual command you have to run to resume the experiment. Keep in mind that you can still use any other CLI overrides at this point so using them in the previous step is not strictly necessary.