xlstm_jax.models.xlstm_parallel.components.normalization


Attributes

LOGGER

NormType

Functions

NormLayer([weight, bias, eps, dtype, norm_type, ...])

Create a norm layer.

MultiHeadNormLayer([weight, bias, eps, dtype, axis, ...])

Create a multi-head norm layer.

resolve_norm(norm_type[, weight, bias, eps, dtype])

Resolve the norm layer based on the norm type.

Module Contents

xlstm_jax.models.xlstm_parallel.components.normalization.LOGGER

Module-level logger.

xlstm_jax.models.xlstm_parallel.components.normalization.NormType

Type of the norm_type argument. Currently supported values are “layernorm” and “rmsnorm”.

xlstm_jax.models.xlstm_parallel.components.normalization.NormLayer(weight=True, bias=False, eps=1e-05, dtype=jnp.float32, norm_type='layernorm', model_axis_name=None, **kwargs)

Create a norm layer.

Parameters:
  • weight (bool) – Whether to use a learnable scaling weight or not.

  • bias (bool) – Whether to use a learnable bias or not.

  • eps (float) – Epsilon value for numerical stability.

  • dtype (jax.numpy.dtype) – Data type used by the norm layer. Regardless of this setting, the reductions that compute the normalization statistics are always performed in float32.

  • norm_type (NormType) – Type of the norm layer. Currently supported types are “layernorm” and “rmsnorm”.

  • model_axis_name (str | None) – Name of the model axis to shard over. If None, no sharding is performed.

  • **kwargs – Additional keyword arguments for the norm layer.

Returns:

Norm layer.

Return type:

flax.linen.Module
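To make the two supported norm types concrete, the following is a minimal functional sketch of what a norm layer computes over the last axis, assuming the standard LayerNorm and RMSNorm definitions. It is not the library's implementation: `apply_norm` is a hypothetical helper, and the learnable weight and bias are passed in as plain arrays (`scale`, `bias`) instead of being created as Flax parameters.

```python
import jax
import jax.numpy as jnp


def apply_norm(x, norm_type="layernorm", scale=None, bias=None, eps=1e-5, dtype=jnp.float32):
    """Hypothetical functional norm over the last axis of `x`.

    Statistics are always reduced in float32; only the result is cast to `dtype`.
    `scale` and `bias` stand in for the learnable weight and bias parameters.
    """
    x32 = x.astype(jnp.float32)
    if norm_type == "layernorm":
        # LayerNorm: subtract the mean, then divide by the standard deviation.
        mean = jnp.mean(x32, axis=-1, keepdims=True)
        var = jnp.var(x32, axis=-1, keepdims=True)
        y = (x32 - mean) * jax.lax.rsqrt(var + eps)
    elif norm_type == "rmsnorm":
        # RMSNorm: no mean-centering; divide by the root mean square instead.
        mean_sq = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
        y = x32 * jax.lax.rsqrt(mean_sq + eps)
    else:
        raise ValueError(f"Unsupported norm type: {norm_type!r}")
    if scale is not None:  # plays the role of weight=True
        y = y * scale
    if bias is not None:  # plays the role of bias=True
        y = y + bias
    return y.astype(dtype)


x = jnp.array([[1.0, 2.0, 3.0, 4.0]])
print(apply_norm(x, "layernorm"))
print(apply_norm(x, "rmsnorm", dtype=jnp.bfloat16).dtype)
```

With `dtype=jnp.bfloat16`, the mean and variance in this sketch are still computed in float32 and only the final result is cast, which mirrors the float32-reduction note in the parameter list above.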

xlstm_jax.models.xlstm_parallel.components.normalization.MultiHeadNormLayer(weight=True, bias=False, eps=1e-05, dtype=jnp.float32, axis=1, norm_type='layernorm', model_axis_name=None, **kwargs)

Create a multi-head norm layer.

Effectively vmaps a norm layer over the specified axis, so the normalization statistics are computed separately for each head.

Parameters:
  • weight (bool) – Whether to use a learnable scaling weight or not.

  • bias (bool) – Whether to use a learnable bias or not.

  • eps (float) – Epsilon value for numerical stability.

  • dtype (jax.numpy.dtype) – Data type used by the norm layer. Regardless of this setting, the reductions that compute the normalization statistics are always performed in float32.

  • axis (int) – Axis to vmap the norm layer over, i.e., the head axis. The normalization itself is always performed over the last axis.

  • norm_type (NormType) – Type of the norm layer. Currently supported types are “layernorm” and “rmsnorm”.

  • model_axis_name (str | None) – Name of the model axis to shard over. If None, no sharding is performed.

  • **kwargs – Additional keyword arguments for the norm layer.

Returns:

Multi-head norm layer.

Return type:

flax.linen.Module
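MultiHeadNormLayer is described as vmapping a norm layer over the head axis while normalizing over the last axis. The sketch below reproduces that pattern with `jax.vmap` on a plain function. `layer_norm` and `multi_head_norm` are hypothetical helpers, and giving each head its own scale vector (`scales` of shape `(num_heads, head_dim)`) is an assumption: this page does not say whether the library shares parameters across heads.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, scale, eps=1e-5):
    """LayerNorm over the last axis with float32 statistics (illustrative only)."""
    x32 = x.astype(jnp.float32)
    mean = jnp.mean(x32, axis=-1, keepdims=True)
    var = jnp.var(x32, axis=-1, keepdims=True)
    y = (x32 - mean) * jax.lax.rsqrt(var + eps)
    return (y * scale).astype(x.dtype)


def multi_head_norm(x, scales, axis=1, eps=1e-5):
    """Apply `layer_norm` independently to every slice of `x` along `axis`.

    Assumption: `scales` has shape (num_heads, head_dim), one scale vector per head.
    """
    per_head = jax.vmap(
        lambda xh, sh: layer_norm(xh, sh, eps),
        in_axes=(axis, 0),  # slice x along the head axis, scales along its first axis
        out_axes=axis,  # put the head axis back where it was in the input
    )
    return per_head(x, scales)


# (batch, heads, time, head_dim)
x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 8, 16))
y = multi_head_norm(x, jnp.ones((4, 16)))
print(y.shape)
```

Because `out_axes=axis` restores the head axis, the output has the same shape as the input. With unit scales the result matches a plain last-axis LayerNorm; the vmap only matters once something, such as the per-head scales assumed here, differs between heads.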

xlstm_jax.models.xlstm_parallel.components.normalization.resolve_norm(norm_type, weight=True, bias=False, eps=1e-05, dtype=jnp.float32, **kwargs)

Resolve the norm layer based on the norm type.

Parameters:
  • norm_type (NormType) – Type of the norm layer. Currently supported types are “layernorm” and “rmsnorm”.

  • weight (bool) – Whether to use a learnable scaling weight or not.

  • bias (bool) – Whether to use a learnable bias or not.

  • eps (float) – Epsilon value for numerical stability.

  • dtype (jax.numpy.dtype) – Data type used by the norm layer. Regardless of this setting, the reductions that compute the normalization statistics are always performed in float32.

  • **kwargs – Additional keyword arguments.

Returns:

Tuple of the norm class and the keyword arguments to instantiate it with.

Return type:

tuple[Any, dict[str, Any]]
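resolve_norm returns a class together with its keyword arguments rather than a ready-made instance. The sketch below shows that dispatch pattern under stated assumptions: `LayerNormSketch`, `RMSNormSketch`, and `resolve_norm_sketch` are hypothetical stand-ins (plain dataclasses, not Flax modules), and the exact keyword names the real function passes on are not documented here.

```python
from dataclasses import dataclass
from typing import Any

import jax.numpy as jnp


@dataclass(frozen=True)
class LayerNormSketch:  # hypothetical stand-in for the library's LayerNorm module
    weight: bool = True
    bias: bool = False
    eps: float = 1e-5
    dtype: Any = jnp.float32


@dataclass(frozen=True)
class RMSNormSketch:  # hypothetical stand-in for the library's RMSNorm module
    weight: bool = True
    bias: bool = False
    eps: float = 1e-5
    dtype: Any = jnp.float32


# Map each supported norm_type string to the class that implements it.
_NORM_CLASSES = {"layernorm": LayerNormSketch, "rmsnorm": RMSNormSketch}


def resolve_norm_sketch(norm_type, weight=True, bias=False, eps=1e-5, dtype=jnp.float32, **kwargs):
    """Return (norm_class, kwargs) without instantiating the class."""
    try:
        norm_cls = _NORM_CLASSES[norm_type]
    except KeyError:
        raise ValueError(f"Unsupported norm type: {norm_type!r}") from None
    norm_kwargs = {"weight": weight, "bias": bias, "eps": eps, "dtype": dtype, **kwargs}
    return norm_cls, norm_kwargs


norm_cls, norm_kwargs = resolve_norm_sketch("rmsnorm", eps=1e-6)
norm = norm_cls(**norm_kwargs)
print(norm)
```

Deferring instantiation presumably lets the caller decide where and how the module is built, for example wrapping the class in a lifted transform such as `flax.linen.vmap` for the multi-head case before creating it.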