xlstm_jax.models.shared.lm_head
Classes
TPLMHead | Language model head with Tensor Parallelism.
Module Contents
- class xlstm_jax.models.shared.lm_head.TPLMHead
Bases: flax.linen.Module
Language model head with Tensor Parallelism.
- Parameters:
parallel – Configuration for parallelism.
vocab_size – Size of the vocabulary.
kernel_init – Initializer for the output layer.
norm_fn – Normalization function to apply before the output layer. If None, no normalization is applied.
lm_head_dtype – Data type for the output layer.
logits_soft_cap – Soft cap for the logits. If not None, the logits will be clipped to this value.
- parallel: xlstm_jax.models.configs.ParallelConfig
- kernel_init: flax.linen.initializers.Initializer
- norm_fn: collections.abc.Callable[Ellipsis, flax.linen.Module] | None
- lm_head_dtype: jax.numpy.dtype
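
The parameters above suggest a pipeline of optional normalization, a projection to vocab_size, and an optional cap on the logits. The following is a minimal single-device sketch of that pipeline in plain JAX, not the TPLMHead implementation. The function names soft_cap and lm_head_sketch are hypothetical, tensor parallelism is left out entirely, and the tanh-based cap is an assumption: it is one common way to soft-cap logits, while this page only states that the logits are clipped to the given value.

```python
from typing import Callable, Optional

import jax
import jax.numpy as jnp


def soft_cap(logits: jax.Array, cap: Optional[float]) -> jax.Array:
    """Bound logits to [-cap, cap] with tanh (assumed formulation); no-op if cap is None."""
    if cap is None:
        return logits
    return cap * jnp.tanh(logits / cap)


def lm_head_sketch(
    hidden: jax.Array,
    kernel: jax.Array,
    norm_fn: Optional[Callable[[jax.Array], jax.Array]] = None,
    lm_head_dtype=jnp.float32,
    logits_soft_cap: Optional[float] = None,
) -> jax.Array:
    """Hypothetical head: optional norm, projection to the vocabulary, optional soft cap."""
    if norm_fn is not None:
        hidden = norm_fn(hidden)
    # Project features to vocabulary logits in the requested dtype.
    logits = jnp.dot(hidden.astype(lm_head_dtype), kernel.astype(lm_head_dtype))
    return soft_cap(logits, logits_soft_cap)


key_h, key_k = jax.random.split(jax.random.PRNGKey(0))
hidden = jax.random.normal(key_h, (2, 4, 8))           # (batch, sequence, features)
kernel = 10.0 * jax.random.normal(key_k, (8, 16))      # (features, vocab_size)
logits = lm_head_sketch(hidden, kernel, logits_soft_cap=30.0)
uncapped = lm_head_sketch(hidden, kernel)              # logits_soft_cap=None leaves logits unchanged
```

With this formulation, small logits pass through almost unchanged while large ones saturate smoothly toward the cap, which keeps gradients nonzero where a hard clip would zero them.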