xlstm_jax.kernels.stride_utils
Functions
| get_strides(array) | Returns the strides of a JAX array. |
| get_stride(array, axis) | Returns the stride of a JAX array at a given axis. |
Module Contents
- xlstm_jax.kernels.stride_utils.get_strides(array)
Returns the strides of a JAX array.
- Parameters:
array (jax.Array | jax.ShapeDtypeStruct) – JAX array or shape-dtype struct.
- Returns:
The strides of the array. Length is equal to the number of dimensions.
- Return type:
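JAX arrays do not expose a `strides` attribute, so strides have to be derived from the shape. The sketch below is not the library's implementation. It assumes a contiguous row-major (C-order) layout with strides counted in elements rather than bytes, because the description above states neither the unit nor the layout. `row_major_strides` is a hypothetical name. Under these assumptions, the stride of axis `i` is the product of all dimension sizes after `i`. The NumPy line cross-checks the result against NumPy's byte strides.

```python
from __future__ import annotations

import math

import jax
import jax.numpy as jnp
import numpy as np


def row_major_strides(array: jax.Array | jax.ShapeDtypeStruct) -> list[int]:
    """Hypothetical helper: C-order strides in elements, one per dimension.

    Assumes a contiguous row-major layout; only `array.shape` is read, so a
    concrete array and an abstract ShapeDtypeStruct behave the same.
    """
    # Stride of axis i = product of the sizes of all later axes (1 for the last).
    return [math.prod(array.shape[axis + 1:]) for axis in range(len(array.shape))]


x = jnp.zeros((2, 3, 4), dtype=jnp.float32)
spec = jax.ShapeDtypeStruct((2, 3, 4), jnp.float32)
print(row_major_strides(x))     # [12, 4, 1]
print(row_major_strides(spec))  # [12, 4, 1]

# NumPy reports strides in bytes; dividing by the item size gives elements.
ref = np.zeros((2, 3, 4), dtype=np.float32)
print([s // ref.itemsize for s in ref.strides])  # [12, 4, 1]
```

Because only the shape is read, the helper also works during tracing on `jax.ShapeDtypeStruct` placeholders, where no data exists.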
- xlstm_jax.kernels.stride_utils.get_stride(array, axis)
Returns the stride of a JAX array at a given axis.
To calculate all strides, use get_strides.
- Parameters:
array (jax.Array | jax.ShapeDtypeStruct) – JAX array or shape-dtype struct.
axis (int) – The axis at which to calculate the stride.
- Returns:
The stride of the array at the given axis.
- Return type:
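This sketch computes a single axis's stride without building the full list of strides. It rests on the same assumptions as before: a contiguous row-major layout, strides counted in elements, and a hypothetical name (`row_major_stride`) rather than the library function. Accepting negative axes and raising `ValueError` on out-of-range axes are choices made for this sketch only. The source does not say how `get_stride` handles either case.

```python
from __future__ import annotations

import math

import jax
import jax.numpy as jnp


def row_major_stride(array: jax.Array | jax.ShapeDtypeStruct, axis: int) -> int:
    """Hypothetical helper: C-order stride of one axis, in elements.

    Equals entry `axis` of the full stride list under a contiguous
    row-major layout.
    """
    ndim = len(array.shape)
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} out of range for a {ndim}-d array")
    axis %= ndim  # map negative axes to their positive equivalents, as in NumPy
    # Elements skipped in memory when the index along `axis` increases by one.
    return math.prod(array.shape[axis + 1:])


spec = jax.ShapeDtypeStruct((8, 16, 64), jnp.bfloat16)
print(row_major_stride(spec, 0))   # 1024
print(row_major_stride(spec, 1))   # 64
print(row_major_stride(spec, -1))  # 1
```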