xlstm_jax.models.xlstm_parallel.training

Functions

print_metrics(metrics[, title])

Prints metrics with an optional title.

get_num_params(state)

Calculate the number of parameters in the model.

loss_fn(params, apply_fn, batch, rng, config)

train_step(state, metrics, batch, config[, ...])

get_train_step_fn(state, batch, mesh, config[, ...])

init_xlstm(config, mesh, rng, input_array, optimizer)

flatten_dict(d)

Flattens a nested dictionary.

tabulate_params(state)

Prints a summary of the parameters represented as a table.

Module Contents

xlstm_jax.models.xlstm_parallel.training.print_metrics(metrics, title=None)

Prints metrics with an optional title.

Parameters:
  • metrics (xlstm_jax.common_types.Metrics) – A dictionary with metric names as keys and a tuple of (sum, count) as values.

  • title (str | None) – An optional title for the metrics.

Return type:

None
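
As a rough illustration of the (sum, count) convention described above, the sketch below averages each metric before printing it. `format_metrics` is a hypothetical stand-in written for this example, not the library's `print_metrics`; the real function's output format is not documented here, so the formatting shown is an assumption.

```python
def format_metrics(metrics, title=None):
    """Hypothetical helper: turn {name: (sum, count)} into printable lines."""
    lines = []
    if title is not None:
        lines.append(title)
    for name, (total, count) in metrics.items():
        # Guard against an empty accumulator to avoid dividing by zero.
        mean = total / count if count else float("nan")
        lines.append(f"{name}: {mean:.4f}")
    return lines


# Each metric stores a running sum and the number of accumulated elements.
metrics = {"loss": (12.0, 4), "accuracy": (3.0, 4)}
for line in format_metrics(metrics, title="Train metrics"):
    print(line)
```

Storing sums and counts rather than means lets metrics from several steps or devices be combined by simple addition before the final division.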

xlstm_jax.models.xlstm_parallel.training.get_num_params(state)

Calculate the number of parameters in the model.

Parameters:

state (xlstm_jax.common_types.TrainState) – The current training state.

Returns:

The number of parameters in the model.

Return type:

int
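
One common way to count parameters in JAX is to sum the sizes of all leaves of the parameter pytree. The sketch below shows that pattern on a toy dictionary. `count_params` is a hypothetical helper, and it assumes the parameters are available as a pytree of arrays (for example via an attribute such as `state.params`); how `get_num_params` reads them from the training state is not shown in this reference.

```python
import jax
import jax.numpy as jnp


def count_params(params):
    """Hypothetical helper: total number of scalar entries in a pytree."""
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))


# Toy parameter tree standing in for a real model's parameters.
params = {
    "dense": {"kernel": jnp.zeros((4, 8)), "bias": jnp.zeros((8,))},
    "norm": {"scale": jnp.ones((8,))},
}
num_params = count_params(params)
print(num_params)
```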

xlstm_jax.models.xlstm_parallel.training.loss_fn(params, apply_fn, batch, rng, config)

Parameters:
  • params

  • apply_fn

  • batch

  • rng

  • config

Return type:

tuple[jax.Array, tuple[dict[str, Any], xlstm_jax.common_types.PyTree]]
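
A return type of the form (array, auxiliary data) matches what `jax.value_and_grad(..., has_aux=True)` expects: the first element is differentiated and the rest is passed through untouched. The sketch below demonstrates that calling pattern with a toy loss. `toy_loss_fn` and its two-argument signature are hypothetical, and reading the auxiliary tuple as (metrics, mutable variables) is an assumption based only on the type annotation above.

```python
import jax
import jax.numpy as jnp


def toy_loss_fn(params, batch):
    """Hypothetical loss returning (loss, (metrics, mutable_variables))."""
    preds = batch["x"] @ params["w"]
    sq_err = (preds - batch["y"]) ** 2
    loss = sq_err.mean()
    # Metrics stored as (sum, count), mirroring the Metrics description.
    metrics = {"loss": (sq_err.sum(), jnp.array(sq_err.size))}
    mutable_variables = {}  # placeholder for non-parameter state
    return loss, (metrics, mutable_variables)


params = {"w": jnp.array([1.0, 2.0])}
batch = {
    "x": jnp.array([[1.0, 0.0], [0.0, 1.0]]),
    "y": jnp.array([0.0, 0.0]),
}

# has_aux=True tells JAX that only the first output is differentiated.
(loss, (metrics, mutable_variables)), grads = jax.value_and_grad(
    toy_loss_fn, has_aux=True
)(params, batch)
print(loss, grads)
```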

xlstm_jax.models.xlstm_parallel.training.train_step(state, metrics, batch, config, gradient_accumulate_steps=1)

Parameters:
  • state

  • metrics

  • batch

  • config

  • gradient_accumulate_steps

Return type:

tuple[xlstm_jax.common_types.TrainState, xlstm_jax.common_types.Metrics]
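
The `gradient_accumulate_steps` argument suggests that gradients can be accumulated over several minibatches before an optimizer update. The sketch below shows one generic form of that technique: the batch is split into equal slices, gradients are computed per slice, and the results are averaged. All names here are hypothetical, and whether `train_step` uses a Python loop, a scan, or a different reduction is not stated in this reference.

```python
import jax
import jax.numpy as jnp


def toy_loss(params, batch):
    """Hypothetical mean-squared-error loss on a scalar weight."""
    preds = batch["x"] * params["w"]
    return ((preds - batch["y"]) ** 2).mean()


def accumulate_gradients(params, batch, num_steps):
    """Average gradients over `num_steps` equal slices of the batch."""
    batch_size = batch["x"].shape[0]
    assert batch_size % num_steps == 0, "batch must split evenly"
    minibatch_size = batch_size // num_steps
    grads = jax.tree_util.tree_map(jnp.zeros_like, params)
    for step in range(num_steps):
        start = step * minibatch_size
        end = start + minibatch_size
        # Slice every array in the batch along its leading axis.
        minibatch = jax.tree_util.tree_map(lambda x: x[start:end], batch)
        step_grads = jax.grad(toy_loss)(params, minibatch)
        grads = jax.tree_util.tree_map(jnp.add, grads, step_grads)
    # Dividing by the number of slices recovers the full-batch mean gradient.
    return jax.tree_util.tree_map(lambda g: g / num_steps, grads)


params = {"w": jnp.array(2.0)}
batch = {"x": jnp.array([1.0, 2.0, 3.0, 4.0]), "y": jnp.zeros(4)}

accumulated = accumulate_gradients(params, batch, num_steps=2)
full_batch = jax.grad(toy_loss)(params, batch)
print(accumulated["w"], full_batch["w"])
```

For a mean-reduced loss and equally sized slices, the averaged gradient equals the full-batch gradient, which is why accumulation can trade memory for extra forward and backward passes without changing the update.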

xlstm_jax.models.xlstm_parallel.training.get_train_step_fn(state, batch, mesh, config, gradient_accumulate_steps=1)

Parameters:
  • state

  • batch

  • mesh

  • config

  • gradient_accumulate_steps

Return type:

tuple[callable, xlstm_jax.common_types.PyTree]

xlstm_jax.models.xlstm_parallel.training.init_xlstm(config, mesh, rng, input_array, optimizer)

Parameters:
  • config

  • mesh

  • rng

  • input_array

  • optimizer

xlstm_jax.models.xlstm_parallel.training.flatten_dict(d)

Flattens a nested dictionary.

Parameters:

d (dict)

Return type:

dict
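
To make the idea of flattening concrete, the sketch below collapses a nested dictionary into a single level by joining keys along each path. `flatten_nested` is a hypothetical helper; the `"."` separator and the resulting key format are assumptions, since this reference does not say how `flatten_dict` builds its keys.

```python
def flatten_nested(d, parent_key="", sep="."):
    """Hypothetical helper: flatten nested dicts into a single level."""
    flat = {}
    for key, value in d.items():
        # Prefix the key with its parent path, if there is one.
        full_key = f"{parent_key}{sep}{key}" if parent_key else str(key)
        if isinstance(value, dict):
            # Recurse into sub-dictionaries, carrying the path along.
            flat.update(flatten_nested(value, full_key, sep))
        else:
            flat[full_key] = value
    return flat


nested = {"block_0": {"kernel": 1, "bias": 2}, "scale": 3}
flat = flatten_nested(nested)
print(flat)
```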

xlstm_jax.models.xlstm_parallel.training.tabulate_params(state)

Prints a summary of the parameters represented as a table.

Parameters:

state (xlstm_jax.common_types.TrainState) – The current training state.

Return type:

str
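
A parameter table of this kind can be built by walking the parameter pytree and recording each array's path, shape, and size. The sketch below returns such a table as a string. `summarize_params` is a hypothetical helper, and its columns and layout are assumptions; the actual output of `tabulate_params` is not shown in this reference.

```python
import jax
import jax.numpy as jnp


def summarize_params(params):
    """Hypothetical helper: one row per array with its path, shape and size."""
    leaves_with_path, _ = jax.tree_util.tree_flatten_with_path(params)
    rows = [("Name", "Shape", "Size")]
    total = 0
    for path, leaf in leaves_with_path:
        rows.append((jax.tree_util.keystr(path), str(leaf.shape), str(leaf.size)))
        total += leaf.size
    # Pad each column to its widest entry so the rows line up.
    widths = [max(len(row[i]) for row in rows) for i in range(3)]
    lines = ["  ".join(cell.ljust(w) for cell, w in zip(row, widths)) for row in rows]
    lines.append(f"Total: {total}")
    return "\n".join(lines)


# Toy parameter tree standing in for a real model's parameters.
params = {
    "dense": {"kernel": jnp.zeros((4, 8)), "bias": jnp.zeros((8,))},
    "norm": {"scale": jnp.ones((8,))},
}
table = summarize_params(params)
print(table)
```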