xlstm_jax.define_hydra_schemas

Contents

xlstm_jax.define_hydra_schemas#

Register all config dataclasses in the project with Hydra’s ConfigStore and define the Hydra config schemas.

Classes#

CombinedModelConfig

This class is a flat config that combines several sub-configs.

BaseLoggerConfig

DataEvalConfig

Supports a maximum of 5 data evaluation configurations.

DataTrainConfig

Supports a maximum of 10 data training configurations.

Config

The base config class.

Functions#

register_configs()

Module Contents#

class xlstm_jax.define_hydra_schemas.CombinedModelConfig#

This class is a flat config that combines several sub-configs.

name: str#
vocab_size: int#
embedding_dim: int#
num_blocks: int#
context_length: int#
tie_weights: bool#
add_embedding_dropout: bool#
add_post_blocks_norm: bool#
scan_blocks: bool#
norm_eps: float#
norm_type: str#
init_distribution_embed: str#
logits_soft_cap: float#
lm_head_dtype: str#
dtype: str#
add_post_norm: bool#
layer_type: str#
num_heads: int#
output_init_fn: str#
init_distribution: str#
qk_dim_factor: float#
v_dim_factor: float#
gate_dtype: str#
backend: str#
backend_name: str#
igate_bias_init_range: float#
add_qk_norm: bool#
cell_norm_type: str#
cell_norm_type_v1: str#
cell_norm_eps: float#
gate_soft_cap: float#
reset_at_document_boundaries: bool#
proj_factor: float#
act_fn: str#
ff_type: str#
ff_dtype: str#
head_dim: int#
attention_backend: str#
theta: float#
class xlstm_jax.define_hydra_schemas.BaseLoggerConfig#

Bases: xlstm_jax.trainer.logger.base_logger.LoggerConfig

loggers_to_use: list[str]#
file_logger_log_dir: str#
file_logger_config_format: str#
tb_log_dir: str#
tb_flush_secs: int#
wb_project: str#
wb_entity: str#
wb_name: str#
wb_tags: list[str]#
class xlstm_jax.define_hydra_schemas.DataEvalConfig#

Supports a maximum of 5 data evaluation configurations.

ds1: xlstm_jax.dataset.configs.DataConfig | None = None#
ds2: xlstm_jax.dataset.configs.DataConfig | None = None#
ds3: xlstm_jax.dataset.configs.DataConfig | None = None#
ds4: xlstm_jax.dataset.configs.DataConfig | None = None#
ds5: xlstm_jax.dataset.configs.DataConfig | None = None#
class xlstm_jax.define_hydra_schemas.DataTrainConfig#

Supports a maximum of 10 data training configurations.

ds1: xlstm_jax.dataset.configs.DataConfig | None = None#
weight1: float = 1.0#
ds2: xlstm_jax.dataset.configs.DataConfig | None = None#
weight2: float = 1.0#
ds3: xlstm_jax.dataset.configs.DataConfig | None = None#
weight3: float = 1.0#
ds4: xlstm_jax.dataset.configs.DataConfig | None = None#
weight4: float = 1.0#
ds5: xlstm_jax.dataset.configs.DataConfig | None = None#
weight5: float = 1.0#
ds6: xlstm_jax.dataset.configs.DataConfig | None = None#
weight6: float = 1.0#
ds7: xlstm_jax.dataset.configs.DataConfig | None = None#
weight7: float = 1.0#
ds8: xlstm_jax.dataset.configs.DataConfig | None = None#
weight8: float = 1.0#
ds9: xlstm_jax.dataset.configs.DataConfig | None = None#
weight9: float = 1.0#
ds10: xlstm_jax.dataset.configs.DataConfig | None = None#
weight10: float = 1.0#
class xlstm_jax.define_hydra_schemas.Config#

The base config class.

parallel: xlstm_jax.models.configs.ParallelConfig#
model: CombinedModelConfig#
scheduler: xlstm_jax.trainer.optimizer.scheduler.SchedulerConfig#
optimizer: xlstm_jax.trainer.optimizer.optimizer.OptimizerConfig#
checkpointing: xlstm_jax.trainer.callbacks.checkpointing.ModelCheckpointConfig#
lr_monitor: xlstm_jax.trainer.callbacks.lr_monitor.LearningRateMonitorConfig#
profiling: xlstm_jax.trainer.callbacks.profiler.JaxProfilerConfig#
logger: xlstm_jax.trainer.logger.base_logger.LoggerConfig#
trainer: xlstm_jax.trainer.base.trainer.TrainerConfig#
device: str#
device_count: int#
n_gpus: int#
batch_size_per_device: int#
global_batch_size: int#
lr: float#
context_length: int#
num_epochs: int#
num_train_steps: int#
log_path: pathlib.Path#
base_dir: pathlib.Path#
task_name: str#
logging_name: str#
data_train: DataTrainConfig | None = None#
data_eval: DataEvalConfig | None = None#
xlstm_jax.define_hydra_schemas.register_configs()#
Return type:

None