xlstm_jax.models.shared.lm_head

Classes

TPLMHead

Language model head with Tensor Parallelism.

Module Contents

class xlstm_jax.models.shared.lm_head.TPLMHead

Bases: flax.linen.Module

Language model head with Tensor Parallelism.

Parameters:
  • parallel – Configuration for parallelism.

  • vocab_size – Size of the vocabulary.

  • kernel_init – Initializer for the output layer.

  • norm_fn – Normalization function to apply before the output layer. If None, no normalization is applied.

  • lm_head_dtype – Data type for the output layer.

  • logits_soft_cap – Soft cap for the logits. If not None, the magnitude of the logits is bounded by this value.
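To make the `logits_soft_cap` parameter concrete, here is a minimal sketch of one common way to soft-cap logits, using a scaled `tanh`. The function name `soft_cap_logits` and the tanh formulation are assumptions made for illustration; the page above does not state how `TPLMHead` applies the cap internally.

```python
from typing import Optional

import jax.numpy as jnp


def soft_cap_logits(logits: jnp.ndarray, cap: Optional[float]) -> jnp.ndarray:
    """Hypothetical helper: smoothly bound logits to (-cap, cap).

    Assumption: a tanh-based soft cap. When cap is None, the logits
    are returned unchanged.
    """
    if cap is None:
        return logits
    # tanh saturates at +/-1, so the output approaches +/-cap smoothly
    # instead of being cut off at a hard threshold.
    return cap * jnp.tanh(logits / cap)


logits = jnp.array([-100.0, -1.0, 0.0, 1.0, 100.0])
capped = soft_cap_logits(logits, cap=30.0)
```

With this formulation, values far below the cap pass through almost unchanged, while large values are squashed toward the cap. Unlike a hard clip, the gradient never becomes exactly zero.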

parallel: xlstm_jax.models.configs.ParallelConfig
vocab_size: int
kernel_init: flax.linen.initializers.Initializer
norm_fn: collections.abc.Callable[..., flax.linen.Module] | None
lm_head_dtype: jax.numpy.dtype
logits_soft_cap: float | None = None