xlstm_jax.models.shared.init

Attributes

Functions

small_init(dim[, distribution])

Create initializer of Nguyen & Salazar (2019).

wang_init(dim, num_blocks[, distribution])

Create Wang initializer.

create_common_init_fn(fn_name, dim, num_blocks[, ...])

Create common initializer function.

_dist_from_stddev(stddev, distribution)

Create initializer with specified standard deviation and distribution.

uniform_init(min_val, max_val)

Create uniform initializer.

Module Contents

xlstm_jax.models.shared.init.InitDistribution
xlstm_jax.models.shared.init.InitFnName
xlstm_jax.models.shared.init.small_init(dim, distribution='normal')

Create initializer of Nguyen & Salazar (2019).

Adapted from EleutherAI/gpt-neox. The initializer samples array values following the method described in “Transformers without Tears: Improving the Normalization of Self-Attention”, Nguyen, T. & Salazar, J. (2019). The values are sampled with a standard deviation of sqrt(2 / (5 * dim)).

Parameters:
  • dim (int) – Feature dimensionality to use in the initializer.

  • distribution (InitDistribution) – The distribution to sample from. Supported distributions are normal, truncated normal, and uniform.

Returns:

Initializer function implementing the method described above.

Return type:

jax.nn.initializers.Initializer
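
The standard deviation above can be illustrated with a minimal sketch. The helper name small_init_sketch is hypothetical and covers only the normal distribution. It is not the library's implementation, just the documented formula wired into a stock JAX initializer.

```python
import math

import jax
import jax.numpy as jnp


def small_init_sketch(dim):
    """Hypothetical stand-in for small_init with a normal distribution."""
    # Standard deviation stated in the docstring: sqrt(2 / (5 * dim)).
    stddev = math.sqrt(2.0 / (5.0 * dim))
    return jax.nn.initializers.normal(stddev=stddev)


key = jax.random.PRNGKey(0)
init_fn = small_init_sketch(dim=512)
# Initializers are called as init_fn(key, shape, dtype).
kernel = init_fn(key, (512, 512), jnp.float32)
small_std = float(kernel.std())
print(small_std, math.sqrt(2.0 / (5.0 * 512)))
```

The empirical standard deviation of the sampled array should be close to the analytic value printed next to it.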

xlstm_jax.models.shared.init.wang_init(dim, num_blocks, distribution='normal')

Create Wang initializer.

Adapted from EleutherAI/gpt-neox. Commonly used for the output layers of residual blocks. The array values are sampled with a standard deviation of 2 / num_blocks / sqrt(dim), so the scale shrinks as the model gets deeper or wider.

Parameters:
  • dim (int) – Feature dimensionality to use in the initializer.

  • num_blocks (int) – Number of layers / blocks in the model.

  • distribution (InitDistribution) – The distribution to sample from. Supported distributions are normal, truncated normal, and uniform.

Returns:

Initializer function implementing the Wang initialization.

Return type:

jax.nn.initializers.Initializer
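
As a worked example of the formula 2 / num_blocks / sqrt(dim), the following sketch computes the standard deviation for a few depths. The helper wang_stddev is a hypothetical name used only for illustration and is not part of the module.

```python
import math


def wang_stddev(dim, num_blocks):
    """Standard deviation from the docstring: 2 / num_blocks / sqrt(dim)."""
    return 2.0 / num_blocks / math.sqrt(dim)


# Doubling the number of blocks halves the standard deviation.
for num_blocks in (6, 12, 24):
    print(num_blocks, wang_stddev(dim=1024, num_blocks=num_blocks))
```

For dim=1024 and num_blocks=12 this gives 2 / 12 / 32, that is 1 / 192.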

xlstm_jax.models.shared.init.create_common_init_fn(fn_name, dim, num_blocks, distribution='normal')

Create common initializer function.

Creates any of the supported initializer types through a single function call.

Parameters:
  • fn_name (InitFnName) – Name of the initializer function to create. Supported values are “small” (small_init()), “wang” (wang_init()), “wang2” (wang_init() with twice the number of blocks), and “zeros” (zero initializer).

  • dim (int) – Feature dimensionality to use in the initializer.

  • num_blocks (int) – Number of layers / blocks in the model.

  • distribution (InitDistribution) – The distribution to sample from. Supported distributions are normal, truncated normal, and uniform.

Returns:

Initializer function of the specified type.

Return type:

jax.nn.initializers.Initializer
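
The name-based dispatch can be pictured with the sketch below, which maps each fn_name to the standard deviation the resulting initializer would use. The function common_init_stddev is hypothetical, and reading “wang2” as wang_init() called with 2 * num_blocks is an assumption based on the parameter description above. The real function returns an initializer rather than a number.

```python
import math


def common_init_stddev(fn_name, dim, num_blocks):
    """Hypothetical helper: standard deviation implied by each fn_name."""
    if fn_name == "small":
        return math.sqrt(2.0 / (5.0 * dim))
    if fn_name == "wang":
        return 2.0 / num_blocks / math.sqrt(dim)
    if fn_name == "wang2":
        # Assumed: same as "wang" but with the block count doubled.
        return 2.0 / (2 * num_blocks) / math.sqrt(dim)
    if fn_name == "zeros":
        return 0.0
    raise ValueError(f"Unknown initializer name: {fn_name!r}")


for name in ("small", "wang", "wang2", "zeros"):
    print(name, common_init_stddev(name, dim=256, num_blocks=4))
```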

xlstm_jax.models.shared.init._dist_from_stddev(stddev, distribution)

Create initializer with specified standard deviation and distribution.

The distribution has a zero mean and specified standard deviation.

Parameters:
  • stddev (float) – The standard deviation of the distribution.

  • distribution (InitDistribution) – The distribution to sample from. Supported distributions are normal, truncated normal, and uniform.

Returns:

Initializer function that samples array values from the specified distribution with the given standard deviation.

Return type:

jax.nn.initializers.Initializer
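
One way to obtain a zero-mean distribution with a prescribed standard deviation is sketched below for the normal and uniform cases. This is an assumption about how such a helper could work, not the module's code: the name dist_from_stddev_sketch and the string values "normal" and "uniform" are hypothetical, and the truncated normal case is left out. For the uniform case it uses the fact that a uniform distribution on [-a, a] has standard deviation a / sqrt(3).

```python
import math

import jax
import jax.numpy as jnp


def dist_from_stddev_sketch(stddev, distribution):
    """Hypothetical zero-mean initializer with the given standard deviation."""
    if distribution == "normal":
        return jax.nn.initializers.normal(stddev=stddev)
    if distribution == "uniform":
        # Uniform on [-a, a] has std a / sqrt(3), so a = sqrt(3) * stddev.
        bound = math.sqrt(3.0) * stddev

        def init(key, shape, dtype=jnp.float32):
            return jax.random.uniform(
                key, shape, dtype, minval=-bound, maxval=bound
            )

        return init
    raise ValueError(f"Unsupported distribution in this sketch: {distribution!r}")


key = jax.random.PRNGKey(0)
samples = {
    name: dist_from_stddev_sketch(0.02, name)(key, (256, 256), jnp.float32)
    for name in ("normal", "uniform")
}
for name, values in samples.items():
    print(name, float(values.mean()), float(values.std()))
```

Both arrays should show a mean near zero and a standard deviation near 0.02, even though the two distributions have different shapes.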

xlstm_jax.models.shared.init.uniform_init(min_val, max_val)

Create uniform initializer.

Parameters:
  • min_val (float) – Minimum value of the uniform distribution.

  • max_val (float) – Maximum value of the uniform distribution.

Returns:

An initializer function which samples values randomly between min_val and max_val.

Return type:

jax.nn.initializers.Initializer
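
A minimal sketch of an initializer with this behavior, built on jax.random.uniform, is shown below. The name uniform_init_sketch is hypothetical and the code is an illustration under that assumption, not the module's implementation.

```python
import jax
import jax.numpy as jnp


def uniform_init_sketch(min_val, max_val):
    """Hypothetical initializer sampling uniformly between min_val and max_val."""

    def init(key, shape, dtype=jnp.float32):
        return jax.random.uniform(
            key, shape, dtype, minval=min_val, maxval=max_val
        )

    return init


key = jax.random.PRNGKey(0)
bias = uniform_init_sketch(-0.1, 0.1)(key, (8,))
print(bias)
```

The returned closure has the usual initializer call signature (key, shape, dtype), so it could be passed wherever a JAX initializer is expected.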