xlstm_jax.models.xlstm_parallel.components.normalization
Attributes
- LOGGER
- NormType
Functions
| NormLayer | Create a norm layer. |
| MultiHeadNormLayer | Create a multi-head norm layer. |
| resolve_norm | Resolve the norm layer based on the norm type. |
Module Contents
- xlstm_jax.models.xlstm_parallel.components.normalization.LOGGER
- xlstm_jax.models.xlstm_parallel.components.normalization.NormType
- xlstm_jax.models.xlstm_parallel.components.normalization.NormLayer(weight=True, bias=False, eps=1e-05, dtype=jnp.float32, norm_type='layernorm', model_axis_name=None, **kwargs)
Create a norm layer.
- Parameters:
weight (bool) – Whether to use a learnable scaling weight or not.
bias (bool) – Whether to use a learnable bias or not.
eps (float) – Epsilon value for numerical stability.
dtype (jax.numpy.dtype) – Data type of the norm. Note that the statistic reductions in the norms are forced to be float32.
norm_type (NormType) – Type of the norm layer. Currently supported types are “layernorm” and “rmsnorm”.
model_axis_name (str | None) – Name of the model axis to shard over. If None, no sharding is performed.
**kwargs – Additional keyword arguments for the norm layer.
- Returns:
Norm layer.
- Return type:
flax.linen.Module
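The difference between the two supported norm types, and the note about float32 statistics, can be illustrated with a minimal sketch in plain JAX. This is not the implementation behind `NormLayer`; the function names `layernorm` and `rmsnorm` below are hypothetical helpers written for illustration, and the way `weight` and `bias` are cast is an assumption.

```python
import jax.numpy as jnp


def layernorm(x, weight=None, bias=None, eps=1e-5):
    """Hypothetical LayerNorm over the last axis with float32 statistics."""
    x32 = x.astype(jnp.float32)  # reductions are done in float32
    mean = jnp.mean(x32, axis=-1, keepdims=True)
    var = jnp.mean(jnp.square(x32 - mean), axis=-1, keepdims=True)
    y = ((x32 - mean) / jnp.sqrt(var + eps)).astype(x.dtype)
    if weight is not None:  # optional learnable scale
        y = y * weight.astype(x.dtype)
    if bias is not None:  # optional learnable shift
        y = y + bias.astype(x.dtype)
    return y


def rmsnorm(x, weight=None, eps=1e-5):
    """Hypothetical RMSNorm: rescales by the root mean square, no centering."""
    x32 = x.astype(jnp.float32)
    mean_sq = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
    y = (x32 / jnp.sqrt(mean_sq + eps)).astype(x.dtype)
    if weight is not None:
        y = y * weight.astype(x.dtype)
    return y


# Low-precision input: statistics are still reduced in float32,
# and the output is cast back to the input dtype.
x = jnp.arange(8.0).reshape(2, 4).astype(jnp.bfloat16)
y_ln = layernorm(x, weight=jnp.ones(4))
y_rms = rmsnorm(x, weight=jnp.ones(4))
```

In this sketch the output keeps the input dtype while the mean and variance are accumulated in float32, which is one plausible reading of the `dtype` note above.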
- xlstm_jax.models.xlstm_parallel.components.normalization.MultiHeadNormLayer(weight=True, bias=False, eps=1e-05, dtype=jnp.float32, axis=1, norm_type='layernorm', model_axis_name=None, **kwargs)
Create a multi-head norm layer.
Effectively vmaps a norm layer over the specified axis.
- Parameters:
weight (bool) – Whether to use a learnable scaling weight or not.
bias (bool) – Whether to use a learnable bias or not.
eps (float) – Epsilon value for numerical stability.
dtype (jax.numpy.dtype) – Data type of the norm. Note that the statistic reductions in the norms are forced to be float32.
axis (int) – Axis to vmap the norm layer over, i.e. the head axis. The normalization is always performed over the last axis.
norm_type (NormType) – Type of the norm layer. Currently supported types are “layernorm” and “rmsnorm”.
model_axis_name (str | None) – Name of the model axis to shard over. If None, no sharding is performed.
**kwargs – Additional keyword arguments for the norm layer.
- Returns:
Multi-head norm layer.
- Return type:
flax.linen.Module
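What "vmapping a norm layer over the head axis" means can be sketched with `jax.vmap` on a plain function. This is an illustration only, not the code behind `MultiHeadNormLayer`: the helpers `rmsnorm` and `multi_head_rmsnorm` are hypothetical, and the assumption that each head gets its own weight vector of shape `(head_dim,)` is ours.

```python
import jax
import jax.numpy as jnp


def rmsnorm(x, weight, eps=1e-5):
    """Hypothetical RMSNorm over the last axis with float32 statistics."""
    x32 = x.astype(jnp.float32)
    mean_sq = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
    return (x32 / jnp.sqrt(mean_sq + eps)).astype(x.dtype) * weight.astype(x.dtype)


def multi_head_rmsnorm(x, weight, axis=1, eps=1e-5):
    """Apply an independent norm per head.

    x:      array with the head dimension at `axis` and features last.
    weight: (num_heads, head_dim), one scale vector per head (assumed layout).
    """
    per_head = jax.vmap(
        lambda x_head, w_head: rmsnorm(x_head, w_head, eps),
        in_axes=(axis, 0),  # map over the head axis of x and the rows of weight
        out_axes=axis,      # put the head axis back where it was
    )
    return per_head(x, weight)


batch, heads, seq, head_dim = 2, 3, 5, 4
x = jax.random.normal(jax.random.PRNGKey(0), (batch, heads, seq, head_dim))
weight = jnp.ones((heads, head_dim))
y = multi_head_rmsnorm(x, weight, axis=1)
```

Each head is normalized over its own last axis of size `head_dim`, so statistics are never shared across heads, and the output has the same shape as the input.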
- xlstm_jax.models.xlstm_parallel.components.normalization.resolve_norm(norm_type, weight=True, bias=False, eps=1e-05, dtype=jnp.float32, **kwargs)
Resolve the norm layer based on the norm type.
- Parameters:
norm_type (NormType) – Type of the norm layer. Currently supported types are “layernorm” and “rmsnorm”.
weight (bool) – Whether to use a learnable scaling weight or not.
bias (bool) – Whether to use a learnable bias or not.
eps (float) – Epsilon value for numerical stability.
dtype (jax.numpy.dtype) – Data type of the norm. Note that the statistic reductions in the norms are forced to be float32.
**kwargs – Additional keyword arguments.
- Returns:
Tuple of the norm class and the keyword arguments.
- Return type: