xlstm_jax.distributed.xla_utils
Functions
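| Function | Description |
| --- | --- |
| simulate_CPU_devices(device_count=8) | Simulate a CPU with a given number of devices. |
| set_XLA_flags() | Set XLA flags for better performance. |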
Module Contents
- xlstm_jax.distributed.xla_utils.simulate_CPU_devices(device_count=8)
Simulate a CPU with a given number of devices.
- Parameters:
device_count (int) – The number of devices to simulate.
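The source does not show how the simulation is done. One common way to get several devices on a CPU-only host is to set the XLA flag --xla_force_host_platform_device_count before JAX initializes its backends. The sketch below illustrates that technique under this assumption. The function name simulate_cpu_devices_sketch is hypothetical and is not this module's implementation.

```python
import os

_DEVICE_COUNT_FLAG = "--xla_force_host_platform_device_count"


def simulate_cpu_devices_sketch(device_count: int = 8) -> None:
    """Hypothetical sketch: expose `device_count` virtual CPU devices to JAX.

    Assumption: this must run before JAX initializes its backends (before the
    first JAX computation), because XLA_FLAGS is read at backend start-up.
    """
    # Keep any other XLA flags, but drop an earlier device-count setting.
    existing = [
        flag
        for flag in os.environ.get("XLA_FLAGS", "").split()
        if not flag.startswith(_DEVICE_COUNT_FLAG + "=")
    ]
    os.environ["XLA_FLAGS"] = " ".join(
        existing + [f"{_DEVICE_COUNT_FLAG}={device_count}"]
    )
    # Force the CPU backend even if a GPU or TPU is present.
    os.environ["JAX_PLATFORMS"] = "cpu"
```

If you call this first and import jax afterwards, jax.device_count() should report the requested number of CPU devices. This makes it possible to test sharding and multi-device code on a laptop.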
- xlstm_jax.distributed.xla_utils.set_XLA_flags()
Set XLA flags for better performance.
For details on the performance flags, see https://jax.readthedocs.io/en/latest/gpu_performance_tips.html and the NVIDIA/JAX-Toolbox repository.
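The source does not list which flags set_XLA_flags() sets. The sketch below shows one way to merge performance flags into the XLA_FLAGS environment variable without duplicating flag names. The two example flags appear on the JAX GPU performance tips page linked above, but they are an assumption here: they are not necessarily the ones this function uses, and which flags are available depends on your JAX/XLA version. The function name set_xla_flags_sketch is hypothetical.

```python
import os
from typing import Dict, Optional

# Example GPU flags (assumption, not this module's list); availability
# depends on the installed JAX/XLA version.
EXAMPLE_GPU_FLAGS = {
    "--xla_gpu_enable_latency_hiding_scheduler": "true",
    "--xla_gpu_enable_triton_gemm": "false",
}


def set_xla_flags_sketch(flags: Dict[str, str]) -> str:
    """Hypothetical sketch: merge `flags` into the XLA_FLAGS environment variable.

    A flag that is already set is overridden in place; new flags are appended.
    Like any XLA_FLAGS change, this only takes effect if it runs before JAX
    initializes its backends.
    """
    merged: Dict[str, Optional[str]] = {}
    for item in os.environ.get("XLA_FLAGS", "").split():
        name, sep, value = item.partition("=")
        merged[name] = value if sep else None  # None marks a bare "--flag"
    merged.update(flags)
    os.environ["XLA_FLAGS"] = " ".join(
        name if value is None else f"{name}={value}"
        for name, value in merged.items()
    )
    return os.environ["XLA_FLAGS"]
```

Merging by flag name, instead of appending strings, keeps repeated calls from producing conflicting duplicate flags in XLA_FLAGS.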