xlstm_jax.distributed.xla_utils


Functions

simulate_CPU_devices([device_count])

Simulate a given number of devices on the host CPU.

set_XLA_flags()

Set XLA flags for better performance.

Module Contents

xlstm_jax.distributed.xla_utils.simulate_CPU_devices(device_count=8)

Simulate a given number of devices on the host CPU.

Parameters:

device_count (int) – The number of CPU devices to simulate. Defaults to 8.
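A minimal sketch of the usual mechanism behind this kind of helper, assuming it works by setting the XLA flag `--xla_force_host_platform_device_count` in the `XLA_FLAGS` environment variable before JAX initializes its backends. The function name `simulate_cpu_devices_sketch` is hypothetical and is not part of xlstm_jax.

```python
import os


def simulate_cpu_devices_sketch(device_count: int = 8) -> str:
    """Hypothetical helper: ask XLA to expose `device_count` CPU devices.

    Assumption: this must run before JAX initializes its backends (in
    practice, before the first JAX computation), otherwise it has no effect.
    """
    if device_count < 1:
        raise ValueError("device_count must be a positive integer")
    flag_name = "--xla_force_host_platform_device_count"
    # Keep any flags already set, but drop an earlier device-count flag
    # so that the new value is the only one present.
    kept = [
        f for f in os.environ.get("XLA_FLAGS", "").split()
        if not f.startswith(flag_name + "=")
    ]
    kept.append(f"{flag_name}={device_count}")
    os.environ["XLA_FLAGS"] = " ".join(kept)
    return os.environ["XLA_FLAGS"]
```

In a fresh process, calling the sketch first and only then importing JAX should make `jax.devices("cpu")` report `device_count` CPU devices, which is convenient for testing sharding code on a single machine.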

xlstm_jax.distributed.xla_utils.set_XLA_flags()

Set XLA flags for better performance.

For details on these performance flags, see https://jax.readthedocs.io/en/latest/gpu_performance_tips.html and the NVIDIA/JAX-Toolbox repository.
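A sketch of how such flags are typically applied, assuming they are passed through the `XLA_FLAGS` environment variable. The two example flags appear in the JAX GPU performance tips, but the exact set that `set_XLA_flags` uses is not documented on this page, and flag availability varies across XLA versions. The helper `append_xla_flags_sketch` and the constant `EXAMPLE_GPU_FLAGS` are hypothetical.

```python
import os
from typing import Iterable

# Example GPU flags of the kind discussed in the JAX GPU performance tips.
# Assumption: these are not necessarily the flags that xlstm_jax sets.
EXAMPLE_GPU_FLAGS = (
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_triton_gemm_any=True",
)


def append_xla_flags_sketch(flags: Iterable[str]) -> str:
    """Hypothetical helper: merge `flags` into XLA_FLAGS.

    A flag whose name is already present is overridden by the new value.
    """
    merged = {}
    for flag in os.environ.get("XLA_FLAGS", "").split() + list(flags):
        name = flag.split("=", 1)[0]
        # Dicts keep insertion order; a repeated name keeps its first
        # position but takes the later value.
        merged[name] = flag
    os.environ["XLA_FLAGS"] = " ".join(merged.values())
    return os.environ["XLA_FLAGS"]
```

Merging rather than overwriting keeps flags set elsewhere (for example a simulated CPU device count) intact. As with any `XLA_FLAGS` change, the call should happen before JAX initializes its GPU backend.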