xlstm_jax.kernels.kernel_utils

Attributes

Functions

jax2triton_dtype(dtype)

Converts a JAX dtype to a Triton dtype.

Module Contents

xlstm_jax.kernels.kernel_utils._jax_to_triton_dtype

xlstm_jax.kernels.kernel_utils.jax2triton_dtype(dtype)

Converts a JAX dtype to a Triton dtype.

Parameters:

dtype (jax.numpy.dtype | str) – JAX dtype to convert, given either as a dtype object or as its string name.

Returns:

The corresponding Triton dtype.

Return type:

triton.language.dtype
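Example:

The following is a minimal sketch of how a conversion like this could work. It is not the module's actual implementation. The names `_SKETCH_DTYPE_NAMES`, `sketch_triton_dtype_name`, and `sketch_jax2triton_dtype` are hypothetical, and the set of supported dtypes is an assumption. The sketch normalizes the input to a dtype name, then looks up the attribute of the same meaning on `triton.language`. The Triton module is passed in as an argument so the lookup logic can be read and tried without Triton installed.

```python
import numpy as np

# Hypothetical lookup table (an assumption, not the module's real table):
# NumPy/JAX dtype name -> attribute name on triton.language.
_SKETCH_DTYPE_NAMES = {
    "bool": "int1",  # Triton represents booleans as 1-bit integers
    "int8": "int8",
    "int16": "int16",
    "int32": "int32",
    "int64": "int64",
    "uint8": "uint8",
    "uint16": "uint16",
    "uint32": "uint32",
    "uint64": "uint64",
    "float16": "float16",
    "bfloat16": "bfloat16",
    "float32": "float32",
    "float64": "float64",
}


def sketch_triton_dtype_name(dtype):
    """Return the triton.language attribute name for a dtype or dtype name."""
    # Strings are used as-is; anything else is normalized through np.dtype,
    # which also accepts JAX scalar types such as jax.numpy.float32.
    name = dtype if isinstance(dtype, str) else np.dtype(dtype).name
    if name not in _SKETCH_DTYPE_NAMES:
        raise ValueError(f"Unsupported dtype: {dtype!r}")
    return _SKETCH_DTYPE_NAMES[name]


def sketch_jax2triton_dtype(dtype, tl_module):
    """Resolve the dtype on the given module (normally triton.language)."""
    return getattr(tl_module, sketch_triton_dtype_name(dtype))


# Usage, assuming Triton and JAX are installed:
#   import jax.numpy as jnp
#   import triton.language as tl
#   sketch_jax2triton_dtype(jnp.float32, tl)  # resolves to tl.float32
#   sketch_jax2triton_dtype("bfloat16", tl)   # resolves to tl.bfloat16
```

Raising on unknown names, rather than falling back to a default, keeps an unsupported dtype from silently reaching a kernel with the wrong element type.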