xlstm_jax.models.shared.utils#
Functions#
| prepare_module | Remats and shards layer if needed. |
| soft_cap_logits | Soft caps logits to a value. |
Module Contents#
- xlstm_jax.models.shared.utils.prepare_module(layer, layer_name, config)#
Remats and shards layer if needed.
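This function wraps the layer in a remat (rematerialization) wrapper if its layer name is listed in the remat configuration, and in a sharding wrapper if its layer name is listed in the FSDP configuration. If neither applies, the layer is returned unchanged.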
- Parameters:
layer (collections.abc.Callable[Ellipsis, flax.linen.Module]) – The layer to prepare.
layer_name (str) – The name of the layer.
config (xlstm_jax.models.configs.ParallelConfig | None) – The configuration to use.
- Returns:
The layer with remat and sharding applied if needed.
- Return type:
collections.abc.Callable[Ellipsis, flax.linen.Module]
- xlstm_jax.models.shared.utils.soft_cap_logits(logits, cap_value)#
Soft caps logits to a value.
Applies a tanh to the logits and scales the result to the cap value, so the output is smoothly bounded to the range (-cap_value, cap_value). This is a common technique in attention and output language heads to prevent large logits from dominating the softmax. See, for example, Gemma 2: https://arxiv.org/abs/2408.00118
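As a rough illustration of soft capping, the sketch below uses the formulation from the Gemma 2 paper, `cap_value * tanh(logits / cap_value)`. Dividing by the cap before the tanh is an assumption about this function, whose body is not shown here, and `soft_cap_sketch` is a hypothetical name rather than the library function.

```python
import jax
import jax.numpy as jnp


def soft_cap_sketch(logits, cap_value):
    """Soft cap following Gemma 2: cap_value * tanh(logits / cap_value)."""
    # Dividing first keeps small logits nearly unchanged (tanh(x) ~ x near 0),
    # while large logits saturate smoothly towards +/- cap_value.
    return cap_value * jnp.tanh(logits / cap_value)


logits = jnp.array([-100.0, -1.0, 0.0, 1.0, 100.0])
capped = soft_cap_sketch(logits, cap_value=30.0)

# Softmax over the capped logits: the extreme entry can no longer
# outweigh the others by more than the cap allows.
probs = jax.nn.softmax(capped)
```

Unlike hard clipping, the tanh is differentiable everywhere and preserves the ordering of the logits, so gradients still flow for values near the cap.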