Learning Rate Schedulers

DeepSpeed offers implementations of the LRRangeTest, OneCycle, WarmupLR, WarmupDecayLR, and WarmupCosineLR learning rate schedulers. When a DeepSpeed learning rate scheduler is used (specified in the ds_config.json file), DeepSpeed calls the scheduler's step() method at every training step (when model_engine.step() is executed). When a DeepSpeed learning rate scheduler is not used:
  • if the schedule is supposed to execute at every training step, then the user can pass the scheduler to deepspeed.initialize when initializing the DeepSpeed engine and let DeepSpeed manage it for update or save/restore.

  • if the schedule is supposed to execute at any other interval (e.g., training epochs), then the user should NOT pass the scheduler to DeepSpeed during initialization and must manage it explicitly.
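As a rough illustration of a DeepSpeed scheduler specified in the configuration, the sketch below builds a scheduler section in Python and serializes it to JSON text. The "scheduler" / "type" / "params" layout and the chosen values are assumptions for illustration; the parameter names mirror the WarmupLR signature documented in this section.

```python
import json

# Hypothetical config fragment: the section layout and the values are
# assumptions for illustration; parameter names mirror the WarmupLR signature.
ds_config = {
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 0.0,
            "warmup_max_lr": 0.001,
            "warmup_num_steps": 1000,
        },
    },
}

# Text that could be written to a ds_config.json file.
config_text = json.dumps(ds_config, indent=2)
print(config_text)
```

With a configuration like this, the scheduler would be stepped by DeepSpeed itself, so the training loop should not call scheduler.step() a second time.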

LRRangeTest

class deepspeed.runtime.lr_schedules.LRRangeTest(optimizer: Optimizer, lr_range_test_min_lr: float = 0.001, lr_range_test_step_size: int = 2000, lr_range_test_step_rate: float = 1.0, lr_range_test_staircase: bool = False, last_batch_iteration: int = -1)[source]

Sets the learning rate of each parameter group according to the learning rate range test (LRRT) policy. The policy increases the learning rate starting from a base value with a constant frequency, as detailed in the paper A disciplined approach to neural network hyper-parameters: Part 1.

The LRRT policy is used to find the maximum LR that trains a model without divergence, and can be used to configure the LR boundaries for cyclic LR schedules.

LRRT changes the learning rate after every batch. step() should be called after a batch has been used for training.

Parameters
  • optimizer (Optimizer) – Wrapped optimizer.

  • lr_range_test_min_lr (float or list) – Initial learning rate which is the lower boundary in the range test for each parameter group.

  • lr_range_test_step_size (int) – Interval of training steps to increase learning rate. Default: 2000

  • lr_range_test_step_rate (float) – Scaling rate for range test. Default: 1.0

  • lr_range_test_staircase (bool) – Scale in staircase fashion, rather than continuous. Default: False.

  • last_batch_iteration (int) – The index of the last batch. This parameter is used when resuming a training job. Since step() should be invoked after each batch instead of after each epoch, this number represents the total number of batches computed, not the total number of epochs computed. When last_batch_iteration=-1, the schedule is started from the beginning. Default: -1
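To make the interaction of the parameters above concrete, here is a minimal sketch of how a range-test learning rate could be computed. It assumes the rate grows as min_lr * (1 + step_rate * interval), where interval is the number of elapsed step_size intervals; this formula and the function name lr_range_test_lr are illustrative assumptions, not a copy of the DeepSpeed source.

```python
import math

def lr_range_test_lr(step, min_lr=0.001, step_size=2000, step_rate=1.0, staircase=False):
    """Sketch of a range-test LR after `step` batches (assumed formula)."""
    interval = step / step_size
    if staircase:
        # Hold the LR constant within each interval, jumping at the boundaries.
        interval = math.floor(interval)
    return min_lr * (1.0 + step_rate * interval)

# Compare continuous and staircase growth at a few batch counts.
for s in (0, 1000, 2000, 4000):
    print(s, lr_range_test_lr(s), lr_range_test_lr(s, staircase=True))
```

Under these assumptions, the continuous variant increases the rate a little on every batch, while the staircase variant only changes it once per lr_range_test_step_size batches.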

Example

>>> import torch
>>> from deepspeed.runtime.lr_schedules import LRRangeTest
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = LRRangeTest(optimizer)
>>> data_loader = torch.utils.data.DataLoader(...)
>>> for epoch in range(10):
...     for batch in data_loader:
...         train_batch(...)
...         scheduler.step()  # advance the schedule once per batch

A disciplined approach to neural network hyper-parameters: Part 1 – learning rate, batch size, momentum, and weight decay: https://arxiv.org/abs/1803.09820

OneCycle

class deepspeed.runtime.lr_schedules.OneCycle(optimizer, cycle_min_lr, cycle_max_lr, decay_lr_rate=0.0, cycle_first_step_size=2000, cycle_second_step_size=None, cycle_first_stair_count=0, cycle_second_stair_count=None, decay_step_size=0, cycle_momentum=True, cycle_min_mom=0.8, cycle_max_mom=0.9, decay_mom_rate=0.0, last_batch_iteration=-1)[source]

Sets the learning rate of each parameter group according to the 1Cycle learning rate policy (1CLR). 1CLR is a variation of the Cyclical Learning Rate (CLR) policy that involves one cycle followed by decay. The policy simultaneously cycles the learning rate (and momentum) between two boundaries with a constant frequency, as detailed in the paper A disciplined approach to neural network hyper-parameters.

The 1CLR policy changes the learning rate after every batch. step() should be called after a batch has been used for training.

This implementation was adapted from the GitHub repo pytorch/pytorch.

Parameters
  • optimizer (Optimizer) – Wrapped optimizer.

  • cycle_min_lr (float or list) – Initial learning rate which is the lower boundary in the cycle for each parameter group.

  • cycle_max_lr (float or list) – Upper learning rate boundaries in the cycle for each parameter group. Functionally, it defines the cycle amplitude (cycle_max_lr - cycle_min_lr). The lr at any cycle is the sum of cycle_min_lr and some scaling of the amplitude; therefore cycle_max_lr may not actually be reached depending on scaling function.

  • decay_lr_rate (float) – Decay rate for learning rate. Default: 0.

  • cycle_first_step_size (int) – Number of training iterations in the increasing half of a cycle. Default: 2000

  • cycle_second_step_size (int) – Number of training iterations in the decreasing half of a cycle. If cycle_second_step_size is None, it is set to cycle_first_step_size. Default: None

  • cycle_first_stair_count (int) – Number of stairs in first half of cycle phase. This means lr/mom are changed in staircase fashion. Default: 0, means staircase disabled.

  • cycle_second_stair_count (int) – Number of stairs in second half of cycle phase. This means lr/mom are changed in staircase fashion. 0 means staircase disabled. Default: None

  • decay_step_size (int) – Intervals for applying decay in decay phase. Default: 0, means no decay.

  • cycle_momentum (bool) – If True, momentum is cycled inversely to learning rate between ‘cycle_min_mom’ and ‘cycle_max_mom’. Default: True

  • cycle_min_mom (float or list) – Initial momentum which is the lower boundary in the cycle for each parameter group. Default: 0.8

  • cycle_max_mom (float or list) – Upper momentum boundaries in the cycle for each parameter group. Functionally, it defines the cycle amplitude (cycle_max_mom - cycle_min_mom). The momentum at any cycle is the difference of cycle_max_mom and some scaling of the amplitude; therefore cycle_min_mom may not actually be reached depending on scaling function. Default: 0.9

  • decay_mom_rate (float) – Decay rate for momentum. Default: 0.

  • last_batch_iteration (int) – The index of the last batch. This parameter is used when resuming a training job. Since step() should be invoked after each batch instead of after each epoch, this number represents the total number of batches computed, not the total number of epochs computed. When last_batch_iteration=-1, the schedule is started from the beginning. Default: -1
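The cycle phase described by the parameters above can be sketched as a linear rise followed by a linear fall, with momentum moving in the opposite direction. This is a simplified illustration under stated assumptions: it ignores the staircase options and the decay phase, and the helper names cycle_scale, one_cycle_lr, and one_cycle_momentum are hypothetical rather than part of DeepSpeed.

```python
def cycle_scale(step, first_step_size=2000, second_step_size=None):
    """Fraction of the cycle amplitude applied at `step` (0.0 to 1.0)."""
    if second_step_size is None:
        second_step_size = first_step_size
    if step <= first_step_size:
        return step / first_step_size  # increasing half
    if step <= first_step_size + second_step_size:
        return 1.0 - (step - first_step_size) / second_step_size  # decreasing half
    return 0.0  # after the cycle; the decay phase is not modeled here

def one_cycle_lr(step, cycle_min_lr, cycle_max_lr, **sizes):
    # LR is the lower boundary plus a scaled share of the amplitude.
    return cycle_min_lr + (cycle_max_lr - cycle_min_lr) * cycle_scale(step, **sizes)

def one_cycle_momentum(step, cycle_min_mom=0.8, cycle_max_mom=0.9, **sizes):
    # Momentum is cycled inversely: the upper boundary minus a scaled amplitude.
    return cycle_max_mom - (cycle_max_mom - cycle_min_mom) * cycle_scale(step, **sizes)

for s in (0, 1000, 2000, 3000, 4000):
    print(s, one_cycle_lr(s, 0.0001, 0.0010), one_cycle_momentum(s))
```

In this sketch the learning rate peaks at cycle_max_lr exactly when momentum bottoms out at cycle_min_mom, which is the inverse relationship that cycle_momentum=True describes.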

Example

>>> import torch
>>> from deepspeed.runtime.lr_schedules import OneCycle
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = OneCycle(optimizer, 0.0001, 0.0010)
>>> data_loader = torch.utils.data.DataLoader(...)
>>> for epoch in range(10):
...     for batch in data_loader:
...         train_batch(...)
...         scheduler.step()  # advance the schedule once per batch

WarmupLR

class deepspeed.runtime.lr_schedules.WarmupLR(optimizer: Optimizer, warmup_min_lr: float = 0.0, warmup_max_lr: float = 0.001, warmup_num_steps: int = 1000, warmup_type: str = 'log', last_batch_iteration: int = -1)[source]

Increase the learning rate of each parameter group from min lr to max lr over warmup_num_steps steps, and then hold it fixed at max lr.

Parameters
  • optimizer (Optimizer) – Wrapped optimizer.

  • warmup_min_lr (float or list) – minimum learning rate. Default: 0

  • warmup_max_lr (float or list) – maximum learning rate. Default: 0.001

  • warmup_num_steps (int) – number of steps to warm up from min_lr to max_lr. Default: 1000

  • warmup_type (str) – One of {'log', 'linear'}: increasing function from min_lr to max_lr during warmup. Default: log

  • last_batch_iteration (int) – The index of the last batch. Default: -1.
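The difference between the two warmup_type options can be sketched as follows. The log variant here assumes a factor of log(step + 1) / log(warmup_num_steps) and the linear variant a factor of step / warmup_num_steps; both formulas and the function name warmup_lr are assumptions for illustration, not a copy of the DeepSpeed source.

```python
import math

def warmup_lr(step, min_lr=0.0, max_lr=0.001, num_steps=1000, warmup_type="log"):
    """Sketch of a warmup LR after `step` batches (assumed formulas)."""
    if step >= num_steps:
        return max_lr  # warmup finished: stay at the maximum
    if warmup_type == "log":
        # Assumes num_steps >= 2 so that log(num_steps) is non-zero.
        gamma = math.log(step + 1) / math.log(num_steps)
    elif warmup_type == "linear":
        gamma = step / num_steps
    else:
        raise ValueError(f"unknown warmup_type: {warmup_type!r}")
    return min_lr + (max_lr - min_lr) * gamma

for s in (0, 10, 100, 500, 1000):
    print(s, warmup_lr(s), warmup_lr(s, warmup_type="linear"))
```

Under these assumptions the log schedule rises quickly in the first few steps and flattens out, while the linear schedule climbs at a constant rate.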

Example

>>> import torch
>>> from deepspeed.runtime.lr_schedules import WarmupLR
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = WarmupLR(optimizer)
>>> data_loader = torch.utils.data.DataLoader(...)
>>> for epoch in range(10):
...     for batch in data_loader:
...         train_batch(...)
...         scheduler.step()  # advance the schedule once per batch

WarmupDecayLR

class deepspeed.runtime.lr_schedules.WarmupDecayLR(optimizer: Optimizer, total_num_steps: int, warmup_min_lr: float = 0.0, warmup_max_lr: float = 0.001, warmup_num_steps: int = 1000, warmup_type: str = 'log', last_batch_iteration: int = -1)[source]

Increase the learning rate of each parameter group from min lr to max lr over warmup_num_steps steps, and then decay it at a linear rate over the remaining training steps.

Parameters
  • optimizer (Optimizer) – Wrapped optimizer.

  • total_num_steps (int) – total number of training steps

  • warmup_min_lr (float or list) – minimum learning rate. Default: 0

  • warmup_max_lr (float or list) – maximum learning rate. Default: 0.001

  • warmup_num_steps (int) – number of steps to warm up from min_lr to max_lr. Default: 1000

  • warmup_type (str) – One of {'log', 'linear'}: increasing function from min_lr to max_lr during warmup. Default: log

  • last_batch_iteration (int) – The index of the last batch. Default: -1.
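A minimal sketch of the warmup-then-linear-decay shape is given below. For brevity it uses a linear warmup, and it assumes the decay factor is the fraction of post-warmup steps still remaining, clamped at zero; the function name warmup_decay_lr and these formulas are illustrative assumptions rather than the DeepSpeed implementation.

```python
def warmup_decay_lr(step, total_num_steps, min_lr=0.0, max_lr=0.001, warmup_num_steps=1000):
    """Sketch: linear warmup to max_lr, then linear decay toward min_lr."""
    if step < warmup_num_steps:
        gamma = step / warmup_num_steps  # linear warmup, for simplicity
    else:
        remaining = total_num_steps - step
        decay_span = max(1, total_num_steps - warmup_num_steps)
        gamma = max(0.0, remaining / decay_span)  # never go below min_lr
    return min_lr + (max_lr - min_lr) * gamma

for s in (0, 500, 1000, 5500, 10000):
    print(s, warmup_decay_lr(s, total_num_steps=10000))
```

Because the decay is tied to total_num_steps, that value should match the number of batches the training job will actually run.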

Example

>>> import torch
>>> from deepspeed.runtime.lr_schedules import WarmupDecayLR
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = WarmupDecayLR(optimizer, 1000000)
>>> data_loader = torch.utils.data.DataLoader(...)
>>> for epoch in range(10):
...     for batch in data_loader:
...         train_batch(...)
...         scheduler.step()  # advance the schedule once per batch

WarmupCosineLR

class deepspeed.runtime.lr_schedules.WarmupCosineLR(optimizer: Optimizer, total_num_steps: int, warmup_min_ratio: float = 0.0, warmup_num_steps: int = 1000, cos_min_ratio: float = 0.0001, warmup_type: str = 'log', last_batch_iteration: int = -1)[source]

Increase the learning rate of each parameter group from min lr ratio to max lr ratio over warmup_num_steps steps, and then decay it at a cosine rate over the remaining training steps to min cosine ratio.

Parameters
  • optimizer (Optimizer) – Wrapped optimizer.

  • total_num_steps (int) – total number of training steps

  • warmup_min_ratio (float or list) – warmup start learning rate ratio. Default: 0

  • warmup_num_steps (int) – number of steps to warm up from warmup_min_ratio to 1.0. Default: 1000

  • warmup_type (str) – One of {'log', 'linear'}: increasing function from warmup_min_ratio to 1.0 during warmup. Default: log

  • cos_min_ratio (float) – cosine end learning rate ratio. Default: 0.0001

  • last_batch_iteration (int) – The index of the last batch. Default: -1.
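Since this scheduler is parameterized by ratios rather than absolute rates, a sketch of the ratio curve may help. It assumes the effective learning rate is the optimizer's base rate multiplied by the ratio, uses a linear warmup for brevity, and applies a standard half-cosine from 1.0 down to cos_min_ratio; the function name warmup_cosine_ratio and the exact step offsets are assumptions, not the DeepSpeed source.

```python
import math

def warmup_cosine_ratio(step, total_num_steps, warmup_min_ratio=0.0,
                        warmup_num_steps=1000, cos_min_ratio=0.0001):
    """Sketch of the LR ratio: warmup to 1.0, then cosine decay to cos_min_ratio."""
    if step < warmup_num_steps:
        progress = step / warmup_num_steps  # linear warmup, for simplicity
        return warmup_min_ratio + (1.0 - warmup_min_ratio) * progress
    decay_steps = max(1, total_num_steps - warmup_num_steps)
    progress = min(1.0, (step - warmup_num_steps) / decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes from 1.0 to 0.0
    return cos_min_ratio + (1.0 - cos_min_ratio) * cosine

base_lr = 0.1  # assumed base rate taken from the optimizer
for s in (0, 1000, 5500, 10000):
    print(s, base_lr * warmup_cosine_ratio(s, total_num_steps=10000))
```

Under these assumptions the rate never drops below base_lr * cos_min_ratio, even if training runs past total_num_steps.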

Example

>>> import torch
>>> from deepspeed.runtime.lr_schedules import WarmupCosineLR
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = WarmupCosineLR(optimizer, 1000000)
>>> data_loader = torch.utils.data.DataLoader(...)
>>> for epoch in range(10):
...     for batch in data_loader:
...         train_batch(...)
...         scheduler.step()  # advance the schedule once per batch