xlstm_jax.models.xlstm_parallel.components.init
Functions
| bias_linspace_init(start, end[, axis_name]) | Linearly spaced bias initialization across dimensions. |
Module Contents
- xlstm_jax.models.xlstm_parallel.components.init.bias_linspace_init(start, end, axis_name=None)
Linearly spaced bias initialization across dimensions.
Only 1D array shapes are supported. The generated values include both start and end. If axis_name is provided, the linspace is sharded over that axis.
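- Parameters:
  - start – First value of the linspace (included in the output).
  - end – Last value of the linspace (included in the output).
  - axis_name – Optional name of the axis over which the linspace is sharded. If None, the linspace is not sharded.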
- Returns:
Initializer function that creates a 1D array with linearly spaced values between start and end.
- Return type:
jax.nn.initializers.Initializer