xlstm_jax.models.llama.feedforward
Classes
FeedForwardConfig | Configuration for the feedforward network.
FeedForward |
Module Contents
- class xlstm_jax.models.llama.feedforward.FeedForwardConfig
Bases:
xlstm_jax.models.configs.SubModelConfig
Configuration for the feedforward network.
- multiple_of: int = 64
The hidden dimension of the feedforward network will be increased to a multiple of this value. This is useful for ensuring an efficient use of the hardware, e.g. for tensor cores.
- ffn_dim_multiplier: float = 1.0
Multiplier for the hidden dimension of the feedforward network. By default, the hidden dimension is roughly 8/3 of the input dimension. The multiplier is applied to this default size and can be used to increase or decrease the hidden dimension. This is in line with the original PyTorch Llama implementation.
- num_layers: int = 12
Number of layers in the whole Llama Transformer model. Used for initialization.
- parallel: xlstm_jax.models.configs.ParallelConfig
Parallel configuration.
- property _dtype: jax.numpy.dtype
Returns the real dtype instead of the str from configs.
- Returns:
The jnp dtype corresponding to the string value.
- Return type:
jax.numpy.dtype
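How multiple_of and ffn_dim_multiplier interact is easiest to see with numbers. The following is a minimal sketch of the hidden-dimension calculation as done in the original PyTorch Llama implementation, which the description of ffn_dim_multiplier refers to. The function name ffn_hidden_dim and the argument name dim are hypothetical and not part of xlstm_jax, and the exact rounding used inside FeedForward is an assumption here.

```python
def ffn_hidden_dim(dim: int, multiple_of: int = 64, ffn_dim_multiplier: float = 1.0) -> int:
    """Hypothetical helper: hidden size following the original PyTorch Llama scheme."""
    # Default size: 2/3 of 4 * dim, i.e. roughly 8/3 of the input dimension.
    hidden = int(2 * (4 * dim) / 3)
    # Scale the default size up or down.
    hidden = int(ffn_dim_multiplier * hidden)
    # Round up to the next multiple of `multiple_of`.
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)


# With the config defaults (multiple_of=64, ffn_dim_multiplier=1.0):
# 8/3 * 512 is about 1365, which rounds up to 22 * 64.
print(ffn_hidden_dim(512))  # 1408
```

Under these assumptions, an input dimension of 4096 with multiple_of=256 gives 11008, and the same input with ffn_dim_multiplier=1.3 and multiple_of=1024 gives 14336.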
- class xlstm_jax.models.llama.feedforward.FeedForward
Bases:
flax.linen.Module
- config: FeedForwardConfig