xlstm_jax.models.xlstm_parallel.benchmark
Functions
- init_mesh
- create_batch
- benchmark_model
Module Contents
- xlstm_jax.models.xlstm_parallel.benchmark.init_mesh(data_axis_size=-1, fsdp_axis_size=1, pipeline_axis_size=1, model_axis_size=1, data_axis_name='dp', fsdp_axis_name='fsdp', pipeline_axis_name='pp', model_axis_name='tp')
- xlstm_jax.models.xlstm_parallel.benchmark.create_batch(batch_size, context_length, vocab_size, rng, mesh, config)
  - Parameters:
    - batch_size (int)
    - context_length (int)
    - vocab_size (int)
    - rng (jax.Array)
    - mesh (jax.sharding.Mesh)
- xlstm_jax.models.xlstm_parallel.benchmark.benchmark_model(config, data_axis_size=-1, fsdp_axis_size=1, pipeline_axis_size=1, model_axis_size=1, seed=42, gradient_accumulate_steps=1, batch_size_per_device=32, optimizer=None, log_dir=None, log_num_steps=1, log_skip_steps=5, num_steps=100)
  - Parameters:
    - config (xlstm_jax.models.xlstm_parallel.xlstm_lm_model.xLSTMLMModelConfig)
    - data_axis_size (int)
    - fsdp_axis_size (int)
    - pipeline_axis_size (int)
    - model_axis_size (int)
    - seed (int)
    - gradient_accumulate_steps (int)
    - batch_size_per_device (int)
    - optimizer (Any | None)
    - log_dir (str | None)
    - log_num_steps (int)
    - log_skip_steps (int)
    - num_steps (int)
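The init_mesh signature names four mesh axes ('dp', 'fsdp', 'pp', 'tp') with a size for each, and data_axis_size defaults to -1. The following is a minimal sketch of how such a mesh could be built with jax.sharding.Mesh, assuming that -1 means "assign all remaining devices to the data axis". The function sketch_init_mesh is hypothetical and is not the library's implementation.

```python
import math

import jax
import numpy as np
from jax.sharding import Mesh


def sketch_init_mesh(
    data_axis_size: int = -1,
    fsdp_axis_size: int = 1,
    pipeline_axis_size: int = 1,
    model_axis_size: int = 1,
    data_axis_name: str = "dp",
    fsdp_axis_name: str = "fsdp",
    pipeline_axis_name: str = "pp",
    model_axis_name: str = "tp",
) -> Mesh:
    """Hypothetical re-implementation of init_mesh, for illustration only."""
    devices = jax.devices()
    other = fsdp_axis_size * pipeline_axis_size * model_axis_size
    # Assumption: -1 means "give the data axis every device that is left over".
    if data_axis_size == -1:
        if len(devices) % other != 0:
            raise ValueError(f"{len(devices)} devices not divisible by {other}")
        data_axis_size = len(devices) // other
    shape = (data_axis_size, fsdp_axis_size, pipeline_axis_size, model_axis_size)
    if math.prod(shape) != len(devices):
        raise ValueError(f"mesh shape {shape} does not match {len(devices)} devices")
    # Arrange the flat device list into a 4-D grid, one dimension per named axis.
    device_grid = np.array(devices).reshape(shape)
    names = (data_axis_name, fsdp_axis_name, pipeline_axis_name, model_axis_name)
    return Mesh(device_grid, names)


mesh = sketch_init_mesh()
print(mesh.shape)  # axis name -> size; on one CPU every axis has size 1
```

Under this assumption, only the data axis absorbs the leftover devices, so requesting, for example, model_axis_size=2 on 8 devices would yield a 4 x 1 x 1 x 2 grid.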
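create_batch takes a batch size, context length, vocabulary size, PRNG key and mesh, which suggests it produces random token data placed on the mesh. The sketch below shows one way to do that, under assumptions that the source does not confirm: the batch is a pair of next-token inputs and targets, and the batch dimension is sharded over the 'dp' and 'fsdp' axes. The config argument is omitted because its role is not documented here. sketch_create_batch is a hypothetical name.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def sketch_create_batch(batch_size, context_length, vocab_size, rng, mesh):
    """Hypothetical stand-in for create_batch; not the library's implementation."""
    # Draw context_length + 1 tokens in [0, vocab_size) so targets are inputs shifted by one.
    tokens = jax.random.randint(rng, (batch_size, context_length + 1), 0, vocab_size)
    inputs, targets = tokens[:, :-1], tokens[:, 1:]
    # Assumption: split the batch dimension over the data-parallel and FSDP axes,
    # and replicate the sequence dimension.
    sharding = NamedSharding(mesh, P(("dp", "fsdp"), None))
    return jax.device_put(inputs, sharding), jax.device_put(targets, sharding)


# Put every local device on the "dp" axis; the other three axes have size 1.
devices = np.array(jax.devices()).reshape(-1, 1, 1, 1)
mesh = Mesh(devices, ("dp", "fsdp", "pp", "tp"))
batch_size = 8 * len(jax.devices())  # divisible by the dp * fsdp axis size
inputs, targets = sketch_create_batch(batch_size, 16, 100, jax.random.PRNGKey(0), mesh)
print(inputs.shape, targets.shape)
```

Sharding over the combined ('dp', 'fsdp') axes requires batch_size to be a multiple of the product of those two axis sizes, which is why the example scales the batch by the device count.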
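benchmark_model exposes batch_size_per_device, num_steps and log_skip_steps (default 5). A plausible reading, which is an assumption rather than documented behavior, is that the global batch is the per-device batch times the device count, and that the first log_skip_steps steps are excluded from timing because they include JIT compilation. The sketch below illustrates that timing pattern with a toy jitted step; global_batch_size, time_steps and toy_step are hypothetical helpers, and gradient_accumulate_steps, log_num_steps and the optimizer are not modeled.

```python
import time

import jax
import jax.numpy as jnp


def global_batch_size(batch_size_per_device: int, num_devices: int) -> int:
    # Assumption: the per-device batch is replicated across every device in the mesh.
    return batch_size_per_device * num_devices


def time_steps(step_fn, state, batch, num_steps: int = 100, log_skip_steps: int = 5):
    """Run num_steps steps and record wall-clock time for all but the first log_skip_steps."""
    durations = []
    for step in range(num_steps):
        start = time.perf_counter()
        state = step_fn(state, batch)
        # JAX dispatches work asynchronously; wait for the result before reading the clock.
        jax.block_until_ready(state)
        if step >= log_skip_steps:
            durations.append(time.perf_counter() - start)
    return state, durations


@jax.jit
def toy_step(params, batch):
    # Hypothetical stand-in for a training step: one SGD update on a squared loss.
    grads = jax.grad(lambda p: jnp.mean((batch @ p) ** 2))(params)
    return params - 0.01 * grads


params = jnp.ones((4,))
batch = jnp.ones((8, 4))
params, durations = time_steps(toy_step, params, batch, num_steps=20, log_skip_steps=5)
print(global_batch_size(32, jax.device_count()), sum(durations) / len(durations))
```

Skipping the warm-up steps keeps the one-off compilation cost of the first call out of the average step time.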