xlstm_jax.models.xlstm_parallel.training#
Functions#
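print_metrics(metrics, title=None) | Prints metrics with an optional title.
get_num_params(state) | Calculates the number of parameters in the model.
loss_fn(params, apply_fn, batch, rng, config) |
train_step(state, metrics, batch, config, gradient_accumulate_steps=1) |
get_train_step_fn(state, batch, mesh, config, gradient_accumulate_steps=1) |
init_xlstm(config, mesh, rng, input_array, optimizer) |
flatten_dict(d) | Flattens a nested dictionary.
tabulate_params(state) | Prints a summary of the parameters as a table.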
Module Contents#
- xlstm_jax.models.xlstm_parallel.training.print_metrics(metrics, title=None)#
Prints metrics with an optional title.
- Parameters:
metrics (xlstm_jax.common_types.Metrics) – A dictionary with metric names as keys and a tuple of (sum, count) as values.
title (str | None) – An optional title for the metrics.
- Return type:
None
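The (sum, count) layout of each metric suggests that values are reported as a mean, sum / count. The following is a minimal sketch of that reduction, not the library's implementation: `average_metrics` is a hypothetical helper, and plain Python floats stand in for whatever array types the library actually stores.

```python
def average_metrics(metrics):
    """Reduce {name: (sum, count)} to {name: mean}, skipping empty counts."""
    averaged = {}
    for name, (total, count) in metrics.items():
        if count > 0:  # avoid dividing by zero for metrics with no samples
            averaged[name] = total / count
    return averaged


# Hypothetical metric values, for illustration only.
example_metrics = {"loss": (12.0, 4), "accuracy": (3.0, 4)}
for name, value in average_metrics(example_metrics).items():
    print(f"{name}: {value:.4f}")
```

Keeping the sum and the count separate, rather than a running mean, lets metrics from several steps be combined by simple addition before the division happens.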
- xlstm_jax.models.xlstm_parallel.training.get_num_params(state)#
Calculates the number of parameters in the model.
- Parameters:
state (xlstm_jax.common_types.TrainState) – The current training state.
- Returns:
The number of parameters in the model.
- Return type:
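As a rough illustration of what counting parameters involves, the sketch below sums the element counts of every leaf in a nested parameter structure. It is an assumption-laden stand-in, not the body of get_num_params: `count_params` and `param_shapes` are hypothetical, and shape tuples replace real arrays so the example runs without JAX.

```python
import math


def count_params(shapes):
    """Sum the element counts of every leaf shape in a nested dict."""
    total = 0
    for value in shapes.values():
        if isinstance(value, dict):
            total += count_params(value)  # recurse into sub-modules
        else:
            total += math.prod(value)  # number of elements for this shape
    return total


# Hypothetical parameter shapes, for illustration only.
param_shapes = {
    "embedding": {"kernel": (1000, 64)},
    "block_0": {"kernel": (64, 64), "bias": (64,)},
}
print(count_params(param_shapes))
```

With real JAX arrays, the same idea is usually expressed by summing the `size` attribute over the leaves returned by `jax.tree_util.tree_leaves`.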
- xlstm_jax.models.xlstm_parallel.training.loss_fn(params, apply_fn, batch, rng, config)#
- xlstm_jax.models.xlstm_parallel.training.train_step(state, metrics, batch, config, gradient_accumulate_steps=1)#
- Parameters:
metrics (xlstm_jax.common_types.Metrics | None)
batch (xlstm_jax.distributed.single_gpu.Batch)
gradient_accumulate_steps (int)
- Return type:
tuple[xlstm_jax.common_types.TrainState, xlstm_jax.common_types.Metrics]
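The source does not describe what gradient_accumulate_steps does. Assuming it follows the common gradient-accumulation pattern (split the batch into that many equally sized minibatches, average their gradients, then apply one optimizer update), the idea can be sketched as below. The scalar least-squares model and the helpers `grad_mse` and `accumulated_grad` are hypothetical and unrelated to the xLSTM code.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to the scalar w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, gradient_accumulate_steps=1):
    """Average the gradients of equally sized minibatches of one batch."""
    if len(xs) % gradient_accumulate_steps != 0:
        raise ValueError("batch size must be divisible by gradient_accumulate_steps")
    size = len(xs) // gradient_accumulate_steps
    total = 0.0
    for step in range(gradient_accumulate_steps):
        lo, hi = step * size, (step + 1) * size
        total += grad_mse(w, xs[lo:hi], ys[lo:hi])
    return total / gradient_accumulate_steps


# Toy data, for illustration only.
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
w = 0.5

full = grad_mse(w, xs, ys)
accumulated = accumulated_grad(w, xs, ys, gradient_accumulate_steps=2)
w_new = w - 0.01 * accumulated  # single update after all minibatches
```

For a mean-reduced loss and equally sized minibatches, the averaged gradient matches the full-batch gradient, which is why accumulation can trade memory for extra forward and backward passes without changing the update.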
- xlstm_jax.models.xlstm_parallel.training.get_train_step_fn(state, batch, mesh, config, gradient_accumulate_steps=1)#
- Parameters:
batch (xlstm_jax.distributed.single_gpu.Batch)
mesh (jax.sharding.Mesh)
gradient_accumulate_steps (int)
- Return type:
tuple[callable, xlstm_jax.common_types.PyTree]
- xlstm_jax.models.xlstm_parallel.training.init_xlstm(config, mesh, rng, input_array, optimizer)#
- Parameters:
config (xlstm_jax.models.xlstm_parallel.xlstm_lm_model.xLSTMLMModelConfig)
mesh (jax.sharding.Mesh)
rng (jax.Array)
input_array (jax.Array)
optimizer (callable)
- xlstm_jax.models.xlstm_parallel.training.flatten_dict(d)#
Flattens a nested dictionary.
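A minimal sketch of flattening a nested dictionary is shown below. The key separator and the exact output format of flatten_dict are not stated in the source; the "." separator and the helper name `flatten_nested` are assumptions made for illustration.

```python
def flatten_nested(d, parent_key="", sep="."):
    """Flatten a nested dict into one level with joined keys (assumed sep)."""
    flat = {}
    for key, value in d.items():
        full_key = f"{parent_key}{sep}{key}" if parent_key else str(key)
        if isinstance(value, dict):
            flat.update(flatten_nested(value, full_key, sep))  # descend one level
        else:
            flat[full_key] = value
    return flat


# Hypothetical nested structure, for illustration only.
nested = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
print(flatten_nested(nested))
```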
- xlstm_jax.models.xlstm_parallel.training.tabulate_params(state)#
Prints a summary of the parameters as a table.
- Parameters:
state (xlstm_jax.common_types.TrainState) – The current training state.
- Return type: