xlstm_jax.models.xlstm_parallel.benchmark

Functions

init_mesh([data_axis_size, fsdp_axis_size, ...])

create_batch(batch_size, context_length, vocab_size, ...)

benchmark_model(config[, data_axis_size, ...])

Module Contents

xlstm_jax.models.xlstm_parallel.benchmark.init_mesh(data_axis_size=-1, fsdp_axis_size=1, pipeline_axis_size=1, model_axis_size=1, data_axis_name='dp', fsdp_axis_name='fsdp', pipeline_axis_name='pp', model_axis_name='tp')
Parameters:
  • data_axis_size (int)

  • fsdp_axis_size (int)

  • pipeline_axis_size (int)

  • model_axis_size (int)

  • data_axis_name (str)

  • fsdp_axis_name (str)

  • pipeline_axis_name (str)

  • model_axis_name (str)

Return type:

jax.sharding.Mesh
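
The default data_axis_size=-1 in the signature above suggests that one axis can be inferred from the number of available devices, much like a -1 dimension in a NumPy reshape. This page does not confirm that behavior, so the following is only a minimal sketch under that assumption. The helper resolve_axis_sizes is hypothetical and is not part of xlstm_jax; only the default axis names ('dp', 'fsdp', 'pp', 'tp') are taken from the signature.

```python
import math

# Default axis names, copied from the init_mesh signature.
AXIS_NAMES = ("dp", "fsdp", "pp", "tp")


def resolve_axis_sizes(num_devices, data=-1, fsdp=1, pipeline=1, model=1):
    """Hypothetical helper (not part of xlstm_jax).

    Assumes that a size of -1 means "use all devices that the other axes
    leave over". Returns a dict mapping axis name to resolved size.
    """
    sizes = [data, fsdp, pipeline, model]
    if sizes.count(-1) > 1:
        raise ValueError("at most one axis size may be -1")
    if any(s != -1 and s < 1 for s in sizes):
        raise ValueError("axis sizes must be positive or -1")

    # Product of all explicitly given axis sizes.
    known = math.prod(s for s in sizes if s != -1)

    if -1 in sizes:
        if num_devices % known != 0:
            raise ValueError(
                f"{num_devices} devices cannot be split evenly over {known}"
            )
        sizes[sizes.index(-1)] = num_devices // known
    elif known != num_devices:
        raise ValueError(
            f"axis sizes multiply to {known}, but {num_devices} devices exist"
        )
    return dict(zip(AXIS_NAMES, sizes))


# With 8 devices, 2-way FSDP, and 2-way tensor parallelism, the data axis
# resolves to the remaining factor of 2.
layout = resolve_axis_sizes(8, fsdp=2, model=2)
```

Under the same assumption, the resolved sizes could be used to reshape the array of devices returned by jax.devices() and passed, together with the axis names, to jax.sharding.Mesh. Whether init_mesh builds its mesh this way is not stated here.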

xlstm_jax.models.xlstm_parallel.benchmark.create_batch(batch_size, context_length, vocab_size, rng, mesh, config)
Parameters:

xlstm_jax.models.xlstm_parallel.benchmark.benchmark_model(config, data_axis_size=-1, fsdp_axis_size=1, pipeline_axis_size=1, model_axis_size=1, seed=42, gradient_accumulate_steps=1, batch_size_per_device=32, optimizer=None, log_dir=None, log_num_steps=1, log_skip_steps=5, num_steps=100)
Parameters: