xlstm_jax.kernels.stride_utils


Functions

get_strides(array)

Returns the strides of a JAX array.

get_stride(array, axis)

Returns the stride of a JAX array at a given axis.

Module Contents

xlstm_jax.kernels.stride_utils.get_strides(array)

Returns the strides of a JAX array.

Parameters:

array (jax.Array | jax.ShapeDtypeStruct) – JAX array or shape-dtype struct.

Returns:

The strides of the array. Length is equal to the number of dimensions.

Return type:

list[int]
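
The docstring does not say how the strides are computed or in which unit. As a rough illustration only, the sketch below assumes a contiguous row-major (C-order) layout and strides counted in elements rather than bytes. `row_major_strides` is a hypothetical helper, not part of xlstm_jax, and it takes a plain shape tuple such as `array.shape`.

```python
def row_major_strides(shape: tuple[int, ...]) -> list[int]:
    """Element strides of a contiguous row-major array with the given shape.

    Hypothetical illustration; not the xlstm_jax implementation.
    """
    strides = []
    stride = 1
    # Walk from the last axis to the first: the last axis is contiguous
    # (stride 1), and each earlier axis steps over everything after it.
    for dim in reversed(shape):
        strides.append(stride)
        stride *= dim
    strides.reverse()
    return strides


print(row_major_strides((2, 3, 4)))  # [12, 4, 1]
```

Under these assumptions the result has one entry per dimension, matching the documented return value.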

xlstm_jax.kernels.stride_utils.get_stride(array, axis)

Returns the stride of a JAX array at a given axis.

To calculate all strides, use get_strides.

Parameters:

array – JAX array whose stride is returned.

axis – Axis at which the stride is returned.

Returns:

The stride of the array at the given axis.

Return type:

int
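
As a hedged illustration of what a per-axis stride means, the sketch below assumes a contiguous row-major (C-order) layout with strides counted in elements, in which case the stride at an axis is the product of the sizes of all later axes. `row_major_stride` is a hypothetical stand-in, not the xlstm_jax function. The documentation does not state whether negative axes are accepted, so this sketch simply rejects any axis outside `0..ndim-1`.

```python
import math


def row_major_stride(shape: tuple[int, ...], axis: int) -> int:
    """Element stride at `axis` for a contiguous row-major array.

    Hypothetical illustration; not the xlstm_jax implementation.
    """
    if not 0 <= axis < len(shape):
        raise ValueError(f"axis {axis} is out of range for {len(shape)} dimensions")
    # Moving one step along `axis` skips every element of the later axes.
    # math.prod of an empty slice is 1, so the last axis has stride 1.
    return math.prod(shape[axis + 1:])


print(row_major_stride((2, 3, 4), 1))  # 4
```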