xlstm_jax.dataset.grain_batch_rampup

Batch rampup schedule for grain IterDatasets.

This module provides a BatchRampUpIterDataset that batches elements according to a batch rampup schedule. The schedule is a function that takes the current batch step and returns the batch size for that step, so it can be used to increase the batch size gradually over time.
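Because a schedule is just a callable from batch step to batch size (Callable[[int], int], the return type of create_batch_rampup_schedule), other rampup shapes can be expressed as plain functions. The doubling_schedule below is a hypothetical example, not part of this module; it only illustrates the expected signature.

```python
from collections.abc import Callable


def doubling_schedule(initial: int, every: int, maximum: int) -> Callable[[int], int]:
    """Hypothetical schedule: double the batch size every `every` batch steps, capped at `maximum`."""

    def schedule(batch_step: int) -> int:
        # Integer division counts how many doublings have happened so far.
        return min(initial * 2 ** (batch_step // every), maximum)

    return schedule


schedule = doubling_schedule(initial=8, every=100, maximum=32)
sizes = [schedule(step) for step in (0, 99, 100, 250, 1000)]  # [8, 8, 16, 32, 32]
```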

The implementation is based on the standard BatchIterDataset from grain, with the addition of a batch rampup schedule. See grain._src.python.dataset.transformations.batch for more details.

NOTE: If the grain API for batching changes, this module may need to be updated.

Attributes

T

S

LOGGER

Classes

_BatchRampUpDatasetIterator

Iterator that batches elements with a batch rampup schedule.

BatchRampUpIterDataset

Batch transformation with ramp up for IterDatasets.

Functions

_make_batch(values)

Returns a batch of values with a new batch dimension at the front.

batch_dataset_with_rampup(parent, batch_size[, ...])

Creates a BatchRampUpIterDataset from an IterDataset.

create_batch_rampup_schedule(batch_size, schedule_type[, ...])

Creates a batch rampup schedule.

constant_rampup_schedule(batch_size)

Returns a constant batch rampup schedule.

stepwise_rampup_schedule(batch_size, boundaries_and_scales)

Returns a stepwise batch rampup schedule.

Module Contents

xlstm_jax.dataset.grain_batch_rampup.T
xlstm_jax.dataset.grain_batch_rampup.S
xlstm_jax.dataset.grain_batch_rampup.LOGGER
xlstm_jax.dataset.grain_batch_rampup._make_batch(values)

Returns a batch of values with a new batch dimension at the front.

Parameters:

values (collections.abc.Sequence[T])

Return type:

T
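The grain helper that _make_batch mirrors stacks the leaves of nested structures. The following NumPy sketch imitates that behavior under simplifying assumptions: make_batch_sketch is a hypothetical name, and it handles only dicts, tuples, and array-like leaves, which is not necessarily what the module's helper supports.

```python
from collections.abc import Sequence
from typing import Any

import numpy as np


def make_batch_sketch(values: Sequence[Any]) -> Any:
    """Hypothetical helper: stack identically structured elements along a new leading axis."""
    if not values:
        raise ValueError("Cannot build a batch from an empty sequence.")
    first = values[0]
    if isinstance(first, dict):
        # Batch each key separately so nested dicts are handled leaf by leaf.
        return {key: make_batch_sketch([v[key] for v in values]) for key in first}
    if isinstance(first, tuple):
        return tuple(make_batch_sketch([v[i] for v in values]) for i in range(len(first)))
    # Leaves: np.stack inserts the new batch dimension at axis 0.
    return np.stack([np.asarray(v) for v in values])
```

For two elements {"x": array of shape (3,), "y": scalar}, the result has "x" of shape (2, 3) and "y" of shape (2,).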

class xlstm_jax.dataset.grain_batch_rampup._BatchRampUpDatasetIterator(parent, batch_rampup_schedule, drop_remainder, batch_fn, stats)

Bases: grain._src.python.dataset.dataset.DatasetIterator[T]

Iterator that batches elements with a batch rampup schedule.

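Parameters:
  • parent – The parent iterator whose elements are batched.

  • batch_rampup_schedule – A function that takes the current batch step and returns the batch size.

  • drop_remainder – Whether to drop the last batch if it is smaller than the batch size scheduled for that step.

  • batch_fn – A function that combines a sequence of elements into a batch.

  • stats – Statistics object forwarded to the base DatasetIterator.
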
_parent
_batch_rampup_schedule
_drop_remainder
_batch_fn
_batch_step = 0
get_state()

Return the state of the iterator.

Return type:

dict[str, Any]

set_state(state)

Set the state of the iterator.

Parameters:

state (dict[str, Any])
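The iterator keeps a _batch_step counter next to its parent, so its checkpoint presumably has to capture both. The pure-Python sketch below illustrates why; ListIterator and RampUpIteratorSketch are hypothetical classes, and the state keys "pos", "parent", and "batch_step" are assumptions, not the module's actual state layout. Restoring only the parent position would restart the schedule at step 0.

```python
from collections.abc import Callable, Sequence
from typing import Any


class ListIterator:
    """Hypothetical parent iterator over a list with a checkpointable position."""

    def __init__(self, data: Sequence[Any]):
        self._data = data
        self._pos = 0

    def __iter__(self) -> "ListIterator":
        return self

    def __next__(self) -> Any:
        if self._pos >= len(self._data):
            raise StopIteration
        value = self._data[self._pos]
        self._pos += 1
        return value

    def get_state(self) -> dict[str, Any]:
        return {"pos": self._pos}  # assumed state layout

    def set_state(self, state: dict[str, Any]) -> None:
        self._pos = state["pos"]


class RampUpIteratorSketch:
    """Hypothetical batching iterator whose state is the parent state plus the batch step."""

    def __init__(self, parent: ListIterator, schedule: Callable[[int], int]):
        self._parent = parent
        self._schedule = schedule
        self._batch_step = 0

    def __iter__(self) -> "RampUpIteratorSketch":
        return self

    def __next__(self) -> list[Any]:
        size = self._schedule(self._batch_step)  # batch size for the current step
        batch = []
        for _ in range(size):
            try:
                batch.append(next(self._parent))
            except StopIteration:
                break
        if not batch:
            raise StopIteration
        self._batch_step += 1
        return batch

    def get_state(self) -> dict[str, Any]:
        # Assumed keys; without "batch_step" a restored iterator would restart the schedule.
        return {"parent": self._parent.get_state(), "batch_step": self._batch_step}

    def set_state(self, state: dict[str, Any]) -> None:
        self._parent.set_state(state["parent"])
        self._batch_step = state["batch_step"]
```

Saving the state after the first batch and restoring it later replays the second batch with the same scheduled size.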

class xlstm_jax.dataset.grain_batch_rampup.BatchRampUpIterDataset(parent, batch_rampup_schedule, drop_remainder=False, batch_fn=None)

Bases: grain._src.python.dataset.dataset.IterDataset[T]

Batch transformation with ramp up for IterDatasets.

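Parameters:
  • parent (grain._src.python.dataset.dataset.IterDataset[S]) – The parent IterDataset whose elements are batched.

  • batch_rampup_schedule (collections.abc.Callable[[int], int]) – A function that takes the current batch step and returns the batch size.

  • drop_remainder (bool) – Whether to drop the last batch if it is smaller than the batch size scheduled for that step.

  • batch_fn (collections.abc.Callable[[collections.abc.Sequence[S]], T] | None) – A function that takes a list of elements and returns a batch. Defaults to stacking the elements along a new batch dimension.
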
_batch_rampup_schedule
_drop_remainder = False
_batch_fn
xlstm_jax.dataset.grain_batch_rampup.batch_dataset_with_rampup(parent, batch_size, drop_remainder=False, batch_fn=None, schedule_type='stepwise', boundaries_and_scales=None)

Creates a BatchRampUpIterDataset from an IterDataset.

Parameters:
  • parent (grain._src.python.dataset.dataset.IterDataset[S]) – The parent IterDataset whose elements are batched.

  • batch_size (int) – The initial batch size.

  • drop_remainder (bool) – Whether to drop the last batch if it is smaller than the batch size scheduled for that step.

  • batch_fn (collections.abc.Callable[[collections.abc.Sequence[S]], T] | None) – A function that takes a list of elements and returns a batch. Defaults to stacking the elements along a new batch dimension.

  • schedule_type (str) – The type of the batch rampup schedule. Supported types are “constant” and “stepwise”.

  • boundaries_and_scales (dict[int, float] | None) – Used only for the “stepwise” schedule type. A dictionary mapping the boundaries b_i to non-negative scaling factors f_i. For any step count s, the schedule returns batch_size scaled by the product of factor f_i for the largest b_i such that b_i < s.

Returns:

A BatchRampUpIterDataset that batches the elements of parent. If no schedule is provided, the function falls back to grain's standard BatchIterDataset.

Return type:

BatchRampUpIterDataset[T] | grain._src.python.dataset.dataset.IterDataset[T]
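The interplay of the schedule, drop_remainder, and batch_fn can be illustrated with a plain-Python generator over any iterable. rampup_batches is a hypothetical helper, not the module's code: it defaults batch_fn to list instead of stacking, and it assumes the final partial batch is compared against the size scheduled for that step.

```python
import itertools
from collections.abc import Callable, Iterable, Iterator, Sequence
from typing import Any


def rampup_batches(
    elements: Iterable[Any],
    schedule: Callable[[int], int],
    drop_remainder: bool = False,
    batch_fn: Callable[[Sequence[Any]], Any] = list,
) -> Iterator[Any]:
    """Hypothetical generator: yield batches whose sizes follow schedule(batch_step)."""
    it = iter(elements)
    batch_step = 0
    while True:
        size = schedule(batch_step)  # batch size requested for this batch step
        values = list(itertools.islice(it, size))  # pull up to `size` elements
        if not values:
            return  # parent exhausted
        if drop_remainder and len(values) < size:
            return  # final batch is smaller than scheduled, so drop it
        yield batch_fn(values)
        batch_step += 1
```

With a schedule that returns 2 for the first two steps and 4 afterwards, ten elements produce batches of sizes 2, 2, 4, and 2; with drop_remainder=True the final batch of 2 is dropped.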

xlstm_jax.dataset.grain_batch_rampup.create_batch_rampup_schedule(batch_size, schedule_type, boundaries_and_scales=None)

Creates a batch rampup schedule.

Parameters:
  • batch_size (int) – The initial batch size.

  • schedule_type (str) – The type of the batch rampup schedule. Supported types are “constant” and “stepwise”.

  • boundaries_and_scales (dict[int, float] | None) – A dictionary mapping the boundaries b_i to non-negative scaling factors f_i. For any step count s, the schedule returns batch_size scaled by the product of factor f_i for the largest b_i such that b_i < s. Only required for the “stepwise” schedule.

Returns:

A function that takes the current batch step and returns the batch size.

Return type:

collections.abc.Callable[[int], int]

xlstm_jax.dataset.grain_batch_rampup.constant_rampup_schedule(batch_size)

Returns a constant batch rampup schedule.

Parameters:

batch_size (int) – The constant batch size.

Returns:

A function that takes the current batch step and returns the batch size.

Return type:

collections.abc.Callable[[int], int]

xlstm_jax.dataset.grain_batch_rampup.stepwise_rampup_schedule(batch_size, boundaries_and_scales)

Returns a stepwise batch rampup schedule.

Parameters:
  • batch_size (int) – The initial batch size on which the factors are applied.

  • boundaries_and_scales (dict[int, float]) – A dictionary mapping the boundaries b_i to non-negative scaling factors f_i. For any step count s, the schedule returns batch_size scaled by the product of factor f_i for the largest b_i such that b_i < s.

Returns:

A function that takes the current batch step and returns the batch size.

Return type:

collections.abc.Callable[[int], int]
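A stepwise schedule can find the active factor with a binary search over the sorted boundaries. The sketch below is hypothetical. It assumes that the factor of the largest boundary b_i < s replaces earlier factors rather than multiplying them, that steps before the first boundary use the unscaled batch_size, and that scaled sizes are truncated with int(). If the factors are meant to be cumulative, as in optax's piecewise_constant_schedule, the lookup would instead multiply all factors with b_i < s.

```python
import bisect
from collections.abc import Callable


def stepwise_sketch(batch_size: int, boundaries_and_scales: dict[int, float]) -> Callable[[int], int]:
    """Hypothetical stepwise schedule: apply the factor of the largest boundary below the step."""
    if any(f < 0 for f in boundaries_and_scales.values()):
        raise ValueError("Scaling factors must be non-negative.")
    # Sort once so each lookup is a binary search.
    boundaries = sorted(boundaries_and_scales)
    factors = [boundaries_and_scales[b] for b in boundaries]

    def schedule(batch_step: int) -> int:
        # bisect_left returns how many boundaries are strictly smaller than batch_step.
        idx = bisect.bisect_left(boundaries, batch_step)
        if idx == 0:
            return batch_size  # no boundary crossed yet: unscaled batch size
        return int(batch_size * factors[idx - 1])

    return schedule
```

With batch_size=16 and {100: 2.0, 200: 4.0}, steps 0 through 100 use 16, steps 101 through 200 use 32, and later steps use 64. Because of the strict inequality b_i < s, each new size takes effect one step after its boundary.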