xlstm_jax.models.shared.init#
Attributes#
Functions#
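| small_init(dim, distribution='normal') | Create initializer of Nguyen et al. (2019). |
| wang_init(dim, num_blocks, distribution='normal') | Create Wang initializer. |
| create_common_init_fn(fn_name, dim, num_blocks, distribution='normal') | Create common initializer function. |
| _dist_from_stddev(stddev, distribution) | Create initializer with specified standard deviation and distribution. |
| uniform_init(min_val, max_val) | Create uniform initializer. |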
Module Contents#
- xlstm_jax.models.shared.init.InitDistribution#
- xlstm_jax.models.shared.init.InitFnName#
- xlstm_jax.models.shared.init.small_init(dim, distribution='normal')#
Create initializer of Nguyen et al. (2019).
Adapted from EleutherAI/gpt-neox. The initializer follows the method described in “Transformers without Tears: Improving the Normalization of Self-Attention”, Nguyen, T. & Salazar, J. (2019). Array values are sampled with a standard deviation of sqrt(2 / (5 * dim)).
- Parameters:
dim (int) – Feature dimensionality to use in the initializer.
distribution (InitDistribution) – The distribution to sample from. Supported are normal, truncated normal, and uniform.
- Returns:
Initializer function that follows the method described above.
- Return type:
jax.nn.initializers.Initializer
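The standard deviation rule above can be illustrated with a minimal sketch. It is not the library's implementation: `small_init_sketch` is a hypothetical name, and only the normal distribution is covered, using the stock `jax.nn.initializers.normal`.

```python
import math

import jax
import jax.numpy as jnp


def small_init_sketch(dim: int):
    """Hypothetical stand-in for small_init (normal distribution only)."""
    # Standard deviation from the docstring: sqrt(2 / (5 * dim)).
    stddev = math.sqrt(2.0 / (5.0 * dim))
    return jax.nn.initializers.normal(stddev=stddev)


dim = 512
init_fn = small_init_sketch(dim)

# Initializers are called as init_fn(key, shape, dtype).
kernel = init_fn(jax.random.PRNGKey(0), (dim, dim), jnp.float32)

target_std = math.sqrt(2.0 / (5.0 * dim))
empirical_std = float(jnp.std(kernel))
print(f"target std: {target_std:.5f}, empirical std: {empirical_std:.5f}")
```

With a few hundred thousand samples, the empirical standard deviation should land close to the target value.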
- xlstm_jax.models.shared.init.wang_init(dim, num_blocks, distribution='normal')#
Create Wang initializer.
Adapted from EleutherAI/gpt-neox. Commonly used for the output layers of residual blocks. Array values are sampled with a standard deviation of 2 / num_blocks / sqrt(dim).
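- Parameters:
dim (int) – Feature dimensionality to use in the initializer.
num_blocks (int) – Number of layers / blocks in the model.
distribution (InitDistribution) – The distribution to sample from. Supported are normal, truncated normal, and uniform.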
- Returns:
Initializer function implementing the Wang initialization.
- Return type:
jax.nn.initializers.Initializer
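To make the scaling concrete, the following sketch evaluates the stated formula 2 / num_blocks / sqrt(dim) for a few depths. `wang_stddev` is a hypothetical helper written for this illustration, not part of the module.

```python
import math


def wang_stddev(dim: int, num_blocks: int) -> float:
    """Standard deviation from the docstring: 2 / num_blocks / sqrt(dim)."""
    return 2.0 / num_blocks / math.sqrt(dim)


dim = 1024
# Deeper models get proportionally smaller output-layer weights.
stddevs = {n: wang_stddev(dim, n) for n in (6, 12, 24, 48)}

for num_blocks, stddev in stddevs.items():
    print(f"num_blocks={num_blocks:3d} -> stddev={stddev:.6f}")
```

Doubling the number of blocks halves the standard deviation, which keeps the summed contribution of the residual branches from growing with depth.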
- xlstm_jax.models.shared.init.create_common_init_fn(fn_name, dim, num_blocks, distribution='normal')#
Create common initializer function.
Creates different types of initializers through a single function call.
- Parameters:
fn_name (InitFnName) – Name of the initializer function to create. Supported are “small” (small_init()), “wang” (wang_init()), “wang2” (wang_init() with 2x block num), and “zeros” (zero initializer).
dim (int) – Feature dimensionality to use in the initializer.
num_blocks (int) – Number of layers / blocks in the model.
distribution (InitDistribution) – The distribution to sample from. Supported are normal, truncated normal, and uniform.
- Returns:
Initializer function of the specified type.
- Return type:
jax.nn.initializers.Initializer
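The name-based dispatch described above could look roughly like the sketch below. This is an assumption-laden illustration rather than the module's code: `create_init_sketch` is a hypothetical name, the `distribution` argument is omitted, and every non-zero case uses `jax.nn.initializers.normal`.

```python
import math

import jax
import jax.numpy as jnp


def create_init_sketch(fn_name: str, dim: int, num_blocks: int):
    """Hypothetical dispatcher mirroring the documented fn_name options."""
    if fn_name == "small":
        stddev = math.sqrt(2.0 / (5.0 * dim))
    elif fn_name == "wang":
        stddev = 2.0 / num_blocks / math.sqrt(dim)
    elif fn_name == "wang2":
        # "wang2" is documented as the Wang init with 2x the block count.
        stddev = 2.0 / (2 * num_blocks) / math.sqrt(dim)
    elif fn_name == "zeros":
        return jax.nn.initializers.zeros
    else:
        raise ValueError(f"Unknown initializer name: {fn_name!r}")
    return jax.nn.initializers.normal(stddev=stddev)


key = jax.random.PRNGKey(0)
shape = (256, 256)

zeros_kernel = create_init_sketch("zeros", 256, 12)(key, shape, jnp.float32)
wang_kernel = create_init_sketch("wang", 256, 12)(key, shape, jnp.float32)
wang2_kernel = create_init_sketch("wang2", 256, 12)(key, shape, jnp.float32)

# Ratio of empirical standard deviations between "wang" and "wang2".
ratio = float(jnp.std(wang_kernel) / jnp.std(wang2_kernel))
print(f"std(wang) / std(wang2) = {ratio:.3f}")
```

Keeping the choice behind a string name makes it easy to select the initializer from a configuration file.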
- xlstm_jax.models.shared.init._dist_from_stddev(stddev, distribution)#
Create initializer with specified standard deviation and distribution.
The distribution has a zero mean and specified standard deviation.
- Parameters:
stddev (float) – The standard deviation of the distribution.
distribution (InitDistribution) – The distribution to sample from. Supported are normal, truncated normal, and uniform.
- Returns:
Initializer function that samples array values from the specified distribution with the given standard deviation.
- Return type:
jax.nn.initializers.Initializer
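One way to obtain zero-mean samples with a given standard deviation for each of the three distributions is sketched below. The details are assumptions, not the module's implementation: `dist_from_stddev_sketch` and the string `"truncated_normal"` are hypothetical, the truncated normal is rescaled so that the truncated samples keep the requested standard deviation, and the uniform case uses the interval [-sqrt(3) * stddev, sqrt(3) * stddev].

```python
import math

import jax
import jax.numpy as jnp

# Std of a standard normal truncated to [-2, 2]; the same constant is used
# by JAX's variance-scaling initializers to correct for truncation.
_TRUNC_STD = 0.87962566103423978


def dist_from_stddev_sketch(stddev: float, distribution: str):
    """Hypothetical zero-mean initializer with the requested stddev."""
    if distribution == "normal":
        return jax.nn.initializers.normal(stddev=stddev)
    if distribution == "truncated_normal":
        # Assumption: rescale so the truncated samples have std == stddev.
        return jax.nn.initializers.truncated_normal(stddev=stddev / _TRUNC_STD)
    if distribution == "uniform":
        # A uniform on [-a, a] has std a / sqrt(3), so a = sqrt(3) * stddev.
        bound = math.sqrt(3.0) * stddev

        def init(key, shape, dtype=jnp.float32):
            return jax.random.uniform(key, shape, dtype, minval=-bound, maxval=bound)

        return init
    raise ValueError(f"Unknown distribution: {distribution!r}")


key = jax.random.PRNGKey(0)
target = 0.02
samples = {
    name: dist_from_stddev_sketch(target, name)(key, (512, 512), jnp.float32)
    for name in ("normal", "truncated_normal", "uniform")
}
for name, arr in samples.items():
    print(f"{name}: mean={float(jnp.mean(arr)):+.5f}, std={float(jnp.std(arr)):.5f}")
```

Note that without the rescaling step, JAX's `truncated_normal` treats `stddev` as the standard deviation before truncation, so the samples would come out somewhat narrower than requested.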
- xlstm_jax.models.shared.init.uniform_init(min_val, max_val)#
Create uniform initializer.