xlstm_jax.models.xlstm_parallel.components.init


Functions

bias_linspace_init(start, end[, axis_name])

Linearly spaced bias initializer for 1D parameters.

Module Contents

xlstm_jax.models.xlstm_parallel.components.init.bias_linspace_init(start, end, axis_name=None)

Linearly spaced bias initializer for 1D parameters.

Only 1D array shapes are supported. The generated values include both start and end. If axis_name is provided, the linspace is sharded over that named axis.
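For the unsharded case (axis_name=None), the behaviour can be sketched as a plain JAX initializer that ignores its PRNG key and returns an endpoint-inclusive jnp.linspace over the single dimension. This is a minimal sketch, not the xlstm_jax source: the name linspace_bias_init_sketch is hypothetical, and the (key, shape, dtype) call signature is assumed from the jax.nn.initializers.Initializer return type.

```python
import jax
import jax.numpy as jnp


def linspace_bias_init_sketch(start: float, end: float):
    """Hypothetical stand-in for the unsharded path (axis_name=None)."""

    def init(key, shape, dtype=jnp.float32):
        del key  # Deterministic initializer: the PRNG key is not used.
        if len(shape) != 1:
            raise ValueError(f"Only 1D shapes are supported, got {shape}.")
        # endpoint=True (the jnp.linspace default) puts both start and end in the output.
        return jnp.linspace(start, end, shape[0], dtype=dtype)

    return init


init = linspace_bias_init_sketch(3.0, 6.0)
bias = init(jax.random.PRNGKey(0), (4,))
print(bias)  # [3. 4. 5. 6.]
```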

Parameters:
  • start (float) – Start value for the linspace.

  • end (float) – End value for the linspace.

  • axis_name (str | None) – Optional name of the axis to shard the linspace over. If None, no sharding is applied.

Returns:

Initializer function that creates a 1D array with linearly spaced values between start and end.

Return type:

jax.nn.initializers.Initializer
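The docstring does not spell out what sharding over axis_name means. One plausible reading, sketched here purely as an assumption, is that shape is the per-shard (local) shape and each shard receives its contiguous slice of a single global linspace, located with jax.lax.axis_index and jax.lax.psum(1, axis_name). The function sharded_linspace_bias_init_sketch is hypothetical, and jax.vmap with axis_name only emulates a sharded context (such as shard_map or pmap) on a single host.

```python
import jax
import jax.numpy as jnp


def sharded_linspace_bias_init_sketch(start: float, end: float, axis_name: str):
    """Hypothetical sketch: each shard gets its slice of one global linspace.

    Assumes `shape` is the per-shard (local) shape and that the initializer runs
    inside a mapped context (vmap, pmap, shard_map) that binds `axis_name`.
    """

    def init(key, shape, dtype=jnp.float32):
        del key  # Deterministic initializer: the PRNG key is not used.
        if len(shape) != 1:
            raise ValueError(f"Only 1D shapes are supported, got {shape}.")
        local_size = shape[0]
        num_shards = jax.lax.psum(1, axis_name)    # Number of shards along the axis.
        shard_idx = jax.lax.axis_index(axis_name)  # This shard's position on the axis.
        global_size = local_size * num_shards
        # Global element positions covered by this shard.
        positions = shard_idx * local_size + jnp.arange(local_size)
        # Endpoint-inclusive spacing; the maximum guards a single-element linspace.
        step = (end - start) / jnp.maximum(global_size - 1, 1)
        return (start + positions * step).astype(dtype)

    return init


init = sharded_linspace_bias_init_sketch(0.0, 7.0, axis_name="shards")
# Emulate 4 shards with vmap; the mapped argument is only a dummy.
shards = jax.vmap(lambda _: init(jax.random.PRNGKey(0), (2,)), axis_name="shards")(
    jnp.arange(4)
)
print(shards.reshape(-1))  # [0. 1. 2. 3. 4. 5. 6. 7.]
```

Because every shard takes a slice of one global linspace, concatenating the shards reproduces the unsharded values. Whether xlstm_jax computes the shard offsets exactly this way is not stated on this page.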