xlstm_jax.kernels.kernel_utils
Attributes
| _jax_to_triton_dtype | |
Functions
| jax2triton_dtype(dtype) | Converts a JAX dtype to a Triton dtype. |
Module Contents
- xlstm_jax.kernels.kernel_utils._jax_to_triton_dtype
- xlstm_jax.kernels.kernel_utils.jax2triton_dtype(dtype)
Converts a JAX dtype to a Triton dtype.
- Parameters:
dtype (jax.numpy.dtype | str) – JAX dtype to convert, given either as a dtype object or as its string name.
- Returns:
Triton dtype.
- Return type:
triton.language.dtype
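This page does not show how the conversion is implemented. The attribute name _jax_to_triton_dtype suggests a lookup table, so what follows is a minimal sketch under that assumption. To stay runnable without Triton installed, it returns Triton dtype names (such as "fp32" and "bf16") instead of triton.language.dtype objects. The table contents and the helper name jax_dtype_to_triton_name are hypothetical and are not the module's actual code.

```python
import jax.numpy as jnp

# Hypothetical lookup table (an assumption, NOT the module's real
# _jax_to_triton_dtype): maps a canonical JAX dtype to the short name
# Triton uses for the corresponding triton.language dtype.
_JAX_TO_TRITON_NAME = {
    jnp.dtype(jnp.float32): "fp32",
    jnp.dtype(jnp.float16): "fp16",
    jnp.dtype(jnp.bfloat16): "bf16",
    jnp.dtype(jnp.int8): "int8",
    jnp.dtype(jnp.int16): "int16",
    jnp.dtype(jnp.int32): "int32",
    jnp.dtype(jnp.uint8): "uint8",
    jnp.dtype(jnp.bool_): "int1",  # Triton represents booleans as int1
}


def jax_dtype_to_triton_name(dtype) -> str:
    """Map a JAX dtype (dtype object, scalar type, or string name) to a Triton dtype name."""
    # Normalise every accepted input form (jnp.float32, "float32",
    # jnp.dtype("float32"), ...) to one canonical dtype used as the key.
    key = jnp.dtype(dtype)
    try:
        return _JAX_TO_TRITON_NAME[key]
    except KeyError:
        raise ValueError(f"No Triton equivalent registered for dtype {key}") from None
```

Normalising through jnp.dtype first is what lets a single table accept both dtype objects and string names, matching the jax.numpy.dtype | str parameter type above. Where Triton is available, a name from this table could be wrapped as triton.language.dtype(name); treat that constructor call as an assumption to check against your Triton version.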