xlstm_jax.models.shared.utils

Functions#

prepare_module(layer, layer_name, config)

Remats and shards layer if needed.

soft_cap_logits(logits, cap_value)

Soft caps logits to a value.

Module Contents#

xlstm_jax.models.shared.utils.prepare_module(layer, layer_name, config)#

Remats and shards layer if needed.

This function wraps the layer in a remat (rematerialization, i.e. activation checkpointing) function if its layer name is present in the remat configuration, and in a sharding function if its layer name is present in the FSDP configuration.

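Parameters:
  • layer – The layer to wrap.

  • layer_name – The name of the layer, checked against the remat and FSDP configuration.

  • config – The configuration containing the remat and FSDP settings.

Returns: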

The layer with remat and sharding applied if needed.

Return type:

collections.abc.Callable[..., flax.linen.Module]

xlstm_jax.models.shared.utils.soft_cap_logits(logits, cap_value)#

Soft caps logits to a value.

Applies a tanh to the logits and scales the result to the cap value, which smoothly bounds the magnitude of the logits. This is a common technique in attention and language-model output heads to prevent large logits from dominating the softmax. See, for example, Gemma 2: https://arxiv.org/abs/2408.00118

Parameters:
  • logits (jax.Array) – The logits to cap.

  • cap_value (float | jax.Array) – The value to cap logits to. If None, no cap is applied.

Returns:

The capped logits.

Return type:

jax.Array
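
As a rough illustration of the operation described above, the following sketch uses the soft-capping form given in the Gemma 2 report, `cap_value * tanh(logits / cap_value)`. It is an assumption that xlstm_jax uses exactly this expression; the function name `soft_cap_sketch` is hypothetical and is not the library function.

```python
import jax.numpy as jnp


def soft_cap_sketch(logits, cap_value=None):
    """Illustrative soft cap, assuming the form cap_value * tanh(logits / cap_value)."""
    if cap_value is None:
        # No cap requested: return the logits unchanged.
        return logits
    return cap_value * jnp.tanh(logits / cap_value)


logits = jnp.array([-100.0, -1.0, 0.0, 1.0, 100.0])
capped = soft_cap_sketch(logits, cap_value=30.0)
# Small logits pass through almost unchanged; large ones saturate
# smoothly and never exceed cap_value in magnitude.
```

Unlike a hard clip, the tanh form stays differentiable everywhere, so gradients still flow for logits near the cap.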