GradientAccumulationScheduler

class lightning.pytorch.callbacks.GradientAccumulationScheduler(scheduling, mode='epoch')[source]

Bases: Callback

Change gradient accumulation factor according to scheduling.

Parameters:
  • scheduling (dict[int, int]) – Scheduling in format {threshold: accumulation_factor}. When mode="epoch", keys are zero-indexed epoch numbers. When mode="step", keys are global step numbers.

  • mode (Literal['epoch', 'step']) – Whether to schedule by "epoch" or "step". Defaults to "epoch" for backward compatibility.

Note

The scheduling argument is a dictionary. With mode="epoch", each key is a zero-indexed epoch number and each value is the accumulation factor that takes effect from that epoch onward. With mode="step", each key is a global training step. For example, to change the accumulation factor after 4 epochs, use scheduling={4: factor} with mode="epoch"; for step-based scheduling, use e.g. scheduling={0: 8, 1000: 4, 5000: 1} with mode="step".

Raises:
  • TypeError – If scheduling is an empty dict, or not all keys and values of scheduling are integers.

  • MisconfigurationException – If mode is not "epoch" or "step", or if keys/values are invalid.

  • IndexError – If the minimal scheduling threshold (the smallest key) is less than 0.
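
The checks implied by these errors can be sketched as a small standalone function (an illustrative approximation, not the library's actual validation code; the error messages here are invented for the sketch):

```python
def validate_scheduling(scheduling: dict) -> None:
    """Approximate the scheduling checks described above (sketch only)."""
    # Empty dict: there is no threshold to apply -> TypeError.
    if not isinstance(scheduling, dict) or not scheduling:
        raise TypeError("scheduling must be a non-empty dict")
    # Non-integer keys or values -> TypeError.
    if not all(isinstance(k, int) and isinstance(v, int)
               for k, v in scheduling.items()):
        raise TypeError("all thresholds and accumulation factors must be integers")
    # Negative thresholds cannot be interpreted -> IndexError.
    if min(scheduling) < 0:
        raise IndexError("minimal threshold must be >= 0")
```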

Example:

>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import GradientAccumulationScheduler

# Epoch-based: from the fifth epoch onward, accumulate every 2 batches
# (the key is 4 because epochs are zero-indexed).
>>> accumulator = GradientAccumulationScheduler(scheduling={4: 2})
>>> trainer = Trainer(callbacks=[accumulator])

# Step-based: for single-epoch pretraining, schedule by global step.
>>> accumulator = GradientAccumulationScheduler(
...     scheduling={0: 8, 1000: 4, 5000: 1},
...     mode="step",
... )
>>> trainer = Trainer(callbacks=[accumulator])
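
At any point in training, the factor in effect is the value attached to the largest threshold not exceeding the current epoch (or global step). A minimal sketch of that lookup, assuming a default factor of 1 before the first threshold (illustrative only, not the library's implementation):

```python
def current_accumulation_factor(scheduling: dict[int, int], progress: int) -> int:
    """Return the accumulation factor in effect at `progress`.

    `progress` is the current epoch (mode="epoch") or global step
    (mode="step"). The applicable factor belongs to the largest
    threshold <= progress; before the first threshold, assume 1.
    """
    applicable = [t for t in scheduling if t <= progress]
    return scheduling[max(applicable)] if applicable else 1

# With scheduling={0: 8, 1000: 4, 5000: 1} in step mode:
# steps 0-999 accumulate 8 batches, steps 1000-4999 accumulate 4,
# and from step 5000 onward no accumulation (factor 1).
```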
on_train_batch_start(trainer, pl_module, batch, batch_idx)[source]

Called when the train batch begins.

Return type:

None

on_train_epoch_start(trainer, *_)[source]

Called when the train epoch begins.

Return type:

None

on_train_start(trainer, pl_module)[source]

Performs configuration validation before training starts and raises errors for incompatible settings.

Return type:

None