GradientAccumulationScheduler
- class lightning.pytorch.callbacks.GradientAccumulationScheduler(scheduling, mode='epoch')[source]

  Bases: Callback

  Change gradient accumulation factor according to scheduling.
- Parameters:
  - scheduling (dict[int, int]) – Scheduling in format {threshold: accumulation_factor}. When mode="epoch", keys are zero-indexed epoch numbers. When mode="step", keys are global step numbers.
  - mode (Literal['epoch', 'step']) – Whether to schedule by "epoch" or "step". Defaults to "epoch" for backward compatibility.
Note

The argument scheduling is a dictionary. When mode="epoch", each key represents an epoch and its associated accumulation factor value (epochs are zero-indexed). When mode="step", each key represents a global training step. For example, to change the accumulation factor after 4 epochs, use scheduling={4: factor} with mode="epoch"; for step-based scheduling, use e.g. scheduling={0: 8, 1000: 4, 5000: 1} with mode="step".

- Raises:
  - TypeError – If scheduling is an empty dict, or if not all keys and values of scheduling are integers.
  - MisconfigurationException – If mode is not "epoch" or "step", or if keys/values are invalid.
  - IndexError – If the minimal threshold is less than 0.
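To make the threshold semantics concrete, the sketch below (not Lightning's internal implementation; the helper name current_factor is hypothetical) resolves the active accumulation factor as the value of the largest threshold key that is less than or equal to the current epoch or step, defaulting to 1 before the first threshold:

```python
def current_factor(scheduling: dict[int, int], position: int) -> int:
    """Return the accumulation factor active at `position` (epoch or step).

    The factor is the value of the largest threshold key <= position;
    before the first threshold, the default of 1 applies.
    """
    factor = 1  # default accumulation before any threshold is reached
    for threshold in sorted(scheduling):
        if position >= threshold:
            factor = scheduling[threshold]
    return factor

# With scheduling={4: 2} and mode="epoch": epochs 0-3 use factor 1,
# epochs 4 and later use factor 2.
print(current_factor({4: 2}, 3))
print(current_factor({4: 2}, 7))

# With scheduling={0: 8, 1000: 4, 5000: 1} and mode="step":
print(current_factor({0: 8, 1000: 4, 5000: 1}, 1500))
```

This lookup is why the example "change the factor after 4 epochs" uses the zero-indexed key 4: that threshold first matches at the fifth epoch of training.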
Example:

>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import GradientAccumulationScheduler

# Epoch-based: from epoch 5, accumulate every 2 batches (use 4 for zero-indexed).
>>> accumulator = GradientAccumulationScheduler(scheduling={4: 2})
>>> trainer = Trainer(callbacks=[accumulator])

# Step-based: for single-epoch pretraining, schedule by global step.
>>> accumulator = GradientAccumulationScheduler(
...     scheduling={0: 8, 1000: 4, 5000: 1},
...     mode="step",
... )
>>> trainer = Trainer(callbacks=[accumulator])