xlstm_jax.common_types
Attributes
- LogMode: Mode for logging. Describes how to aggregate metrics over steps.
Classes
- TrainState: TrainState with additional mutable variables and RNG.
Module Contents
- xlstm_jax.common_types.PyTree
- xlstm_jax.common_types.Parameter
- xlstm_jax.common_types.PRNGKeyArray
- xlstm_jax.common_types.LogMode
  Mode for logging. Describes how to aggregate metrics over steps.
  - mean: Mean of the metric.
  - mean_nopostfix: Mean of the metric, without adding a mean postfix to the key.
  - single: Single value of the metric, i.e. only the last value is tracked.
  - max: Maximum value of the metric.
  - std: Standard deviation of the metric.
  - single_noreduce: Concatenate the metric values of multiple steps instead of reducing them.
  - single_noreduce_wcount: Concatenate the metric values and their counts of multiple steps instead of reducing them.
- xlstm_jax.common_types.ImmutableMetricElement
- xlstm_jax.common_types.ImmutableMetrics
- xlstm_jax.common_types.MutableMetricElement
- xlstm_jax.common_types.MutableMetrics
- xlstm_jax.common_types.StepMetricsElement
- xlstm_jax.common_types.StepMetrics
- xlstm_jax.common_types.MetricElement
- xlstm_jax.common_types.Metrics
- xlstm_jax.common_types.HostMetricElement
- xlstm_jax.common_types.HostMetrics
- class xlstm_jax.common_types.TrainState
  Bases: flax.training.train_state.TrainState
  TrainState with additional mutable variables and RNG.
  - mutable_variables: Any = None
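To make the LogMode values listed above concrete, here is a minimal pure-Python sketch of how the per-step values of one metric could be aggregated under each mode. The function `aggregate_metric`, the alias `LogModeSketch`, the `_mean` key postfix, the use of the population standard deviation, and returning a `(values, counts)` tuple for `single_noreduce_wcount` are all assumptions for illustration; this is not xlstm_jax's implementation.

```python
import statistics
from typing import Literal, Optional, Sequence

# Hypothetical alias: the mode names are taken from the LogMode entry above.
LogModeSketch = Literal[
    "mean",
    "mean_nopostfix",
    "single",
    "max",
    "std",
    "single_noreduce",
    "single_noreduce_wcount",
]


def aggregate_metric(
    key: str,
    values: Sequence[float],
    mode: LogModeSketch,
    counts: Optional[Sequence[int]] = None,
) -> dict:
    """Aggregate the per-step values of one metric (illustrative sketch only)."""
    if mode == "mean":
        # Assumed postfix; the docs only say that a mean postfix is added to the key.
        return {f"{key}_mean": statistics.fmean(values)}
    if mode == "mean_nopostfix":
        return {key: statistics.fmean(values)}
    if mode == "single":
        # Only the most recent value is kept.
        return {key: values[-1]}
    if mode == "max":
        return {key: max(values)}
    if mode == "std":
        # Population standard deviation (an assumption; it could also be the sample std).
        return {key: statistics.pstdev(values)}
    if mode == "single_noreduce":
        # No reduction: every logged value is kept.
        return {key: list(values)}
    if mode == "single_noreduce_wcount":
        if counts is None or len(counts) != len(values):
            raise ValueError("single_noreduce_wcount needs one count per value")
        return {key: (list(values), list(counts))}
    raise ValueError(f"Unknown log mode: {mode!r}")


losses = [1.0, 3.0, 2.0]
print(aggregate_metric("loss", losses, "mean"))  # {'loss_mean': 2.0}
print(aggregate_metric("loss", losses, "single"))  # {'loss': 2.0}
print(aggregate_metric("loss", losses, "max"))  # {'loss': 3.0}
```

The reducing modes collapse a metric to one number per logging interval, while the `single_noreduce*` modes keep every value so that a later stage can decide how to combine them.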