xlstm_jax.models.llama.feedforward#

Classes#

FeedForwardConfig

Configuration for the feedforward network.

FeedForward

Module Contents#

class xlstm_jax.models.llama.feedforward.FeedForwardConfig#

Bases: xlstm_jax.models.configs.SubModelConfig

Configuration for the feedforward network.

multiple_of: int = 64#

The hidden dimension of the feedforward network is rounded up to a multiple of this value. This helps use the hardware efficiently, e.g. on tensor cores.

ffn_dim_multiplier: float = 1.0#

Multiplier for the hidden dimension of the feedforward network. By default, the hidden dimension is about 8/3 of the input dimension, before it is rounded up to a multiple of multiple_of. This multiplier scales that default size and can be used to increase or decrease the hidden dimension, in line with the original PyTorch Llama implementation.
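To make the interaction between multiple_of and ffn_dim_multiplier concrete, here is a minimal sketch of the hidden-size arithmetic as done in the original PyTorch Llama feedforward layer. The function name llama_ffn_hidden_dim and the argument dim (the input dimension) are hypothetical and used only for illustration; this page does not show how xlstm_jax computes the value internally, so treat this as an assumption based on the "in line with the original PyTorch Llama implementation" remark.

```python
def llama_ffn_hidden_dim(
    dim: int, multiple_of: int = 64, ffn_dim_multiplier: float = 1.0
) -> int:
    """Hypothetical helper mirroring the original Llama hidden-size rule."""
    # Start from the classic 4x expansion, then take 2/3 of it (about 8/3 * dim).
    hidden_dim = 4 * dim
    hidden_dim = int(2 * hidden_dim / 3)
    # Scale the default size up or down.
    hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    # Round up to the next multiple of `multiple_of`.
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)


print(llama_ffn_hidden_dim(512))                     # 1408
print(llama_ffn_hidden_dim(4096, multiple_of=256))   # 11008
```

With the defaults documented here (multiple_of=64, ffn_dim_multiplier=1.0), an input dimension of 512 gives 1365 before rounding and 1408 after rounding up to a multiple of 64.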

use_bias: bool = False#

Whether to use bias in the feedforward network.

dropout_rate: float = 0.0#

Dropout rate for the feedforward network.

num_layers: int = 12#

Number of layers in the whole Llama Transformer model. Used for initialization.

dtype: str = 'float32'#

Data type of the activations in the network.

parallel: xlstm_jax.models.configs.ParallelConfig#

Parallel configuration.

property _dtype: jax.numpy.dtype#

Returns the actual dtype object rather than the string stored in the config.

Returns:

The jnp dtype corresponding to the string value.

Return type:

jax.numpy.dtype

to_dict()#

Converts the config to a dictionary.

Helpful for saving to disk or logging.

Return type:

dict
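As an illustration of why a dictionary form is convenient for saving and logging, the following sketch uses a stand-in dataclass rather than the real FeedForwardConfig. The class name ToyFeedForwardConfig is hypothetical, the parallel field is omitted, and implementing to_dict via dataclasses.asdict is an assumption; the actual behavior is defined by xlstm_jax.models.configs.SubModelConfig and may differ.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class ToyFeedForwardConfig:
    """Hypothetical stand-in with the scalar fields documented on this page."""

    multiple_of: int = 64
    ffn_dim_multiplier: float = 1.0
    use_bias: bool = False
    dropout_rate: float = 0.0
    num_layers: int = 12
    dtype: str = "float32"

    def to_dict(self) -> dict:
        # Assumption: a plain field-name -> value mapping.
        return asdict(self)


# Convert to a dict, serialize to JSON (e.g. for a log or a file), and restore.
config_dict = ToyFeedForwardConfig(dropout_rate=0.1).to_dict()
serialized = json.dumps(config_dict, sort_keys=True)
restored = ToyFeedForwardConfig(**json.loads(serialized))
```

Keeping dtype as a string in the config is what makes this round trip straightforward, since strings serialize to JSON without custom encoders.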

class xlstm_jax.models.llama.feedforward.FeedForward#

Bases: flax.linen.Module

config: FeedForwardConfig#