xlstm_jax.models.xlstm_parallel.checkpointing


Functions

save_checkpoint(state, log_dir)

Module Contents

xlstm_jax.models.xlstm_parallel.checkpointing.save_checkpoint(state, log_dir)
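This page gives only the signature, with no docstring, so the behavior of `save_checkpoint` is not documented here. The following is a hypothetical sketch of what a function with this signature could do, not the xlstm_jax implementation. It assumes `state` is any picklable object (for example a dict or pytree of host arrays) and that the checkpoint goes to a file inside `log_dir`. The names `save_checkpoint_sketch`, `load_checkpoint_sketch`, and the file name `checkpoint.pkl` are illustrative assumptions. Production JAX code more commonly relies on a dedicated checkpointing library such as Orbax.

```python
import os
import pickle
import tempfile
from pathlib import Path
from typing import Any


def save_checkpoint_sketch(state: Any, log_dir: str) -> Path:
    """Hypothetical stand-in for save_checkpoint(state, log_dir).

    Assumes `state` is picklable and writes it to <log_dir>/checkpoint.pkl.
    The file name and on-disk format are assumptions, not the library's.
    """
    out_dir = Path(log_dir)
    out_dir.mkdir(parents=True, exist_ok=True)  # create log_dir if missing
    path = out_dir / "checkpoint.pkl"
    tmp_path = path.with_suffix(".tmp")
    # Write to a temporary file first, then rename it over the target, so a
    # crash mid-write does not leave a truncated checkpoint behind.
    with open(tmp_path, "wb") as f:
        pickle.dump(state, f)
    os.replace(tmp_path, path)
    return path


def load_checkpoint_sketch(log_dir: str) -> Any:
    """Hypothetical counterpart that reads the file written above."""
    with open(Path(log_dir) / "checkpoint.pkl", "rb") as f:
        return pickle.load(f)


if __name__ == "__main__":
    # Round-trip a small nested state through a temporary directory.
    state = {"step": 100, "params": {"w": [0.1, 0.2], "b": [0.0]}}
    with tempfile.TemporaryDirectory() as d:
        save_checkpoint_sketch(state, d)
        assert load_checkpoint_sketch(d) == state
```

The write-then-rename pattern is a common design choice for checkpoints because `os.replace` swaps the file in a single step on the same filesystem, so readers see either the old checkpoint or the complete new one.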