Optimizers

DeepSpeed offers high-performance implementations of the Adam optimizer on CPU, and of the FusedAdam, FusedLamb, OnebitAdam, ZeroOneAdam, and OnebitLamb optimizers on GPU.

Adam (CPU)

class deepspeed.ops.adam.DeepSpeedCPUAdam(model_params, lr=0.001, bias_correction=True, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False, adamw_mode=True, fp32_optimizer_states=True)[source]
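
DeepSpeedCPUAdam is a standard torch.optim.Optimizer subclass, so it can be dropped into an ordinary training loop. The sketch below is a minimal, illustrative example (the toy model, tensors, and hyperparameter values are placeholders, not recommendations); it assumes DeepSpeed can build or JIT-compile its CPU Adam extension on the machine:

import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

model = torch.nn.Linear(10, 2)                      # toy model; parameters stay on CPU
optimizer = DeepSpeedCPUAdam(model.parameters(), lr=1e-3,
                             weight_decay=0.01,
                             adamw_mode=True)       # decoupled (AdamW-style) weight decay

x, y = torch.randn(4, 10), torch.randn(4, 2)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()                                    # fused CPU Adam/AdamW update
optimizer.zero_grad()

In practice, DeepSpeedCPUAdam is most commonly used together with ZeRO-Offload, where DeepSpeed constructs it from the training config.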

FusedAdam (GPU)

class deepspeed.ops.adam.FusedAdam(params, lr=0.001, bias_correction=True, betas=(0.9, 0.999), eps=1e-08, adam_w_mode=True, weight_decay=0.0, amsgrad=False, set_grad_none=True)[source]

Implements the Adam algorithm.

Currently GPU-only. Requires Apex to be installed via pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./.

This version of fused Adam implements two fusions:

  • Fusion of the Adam update’s elementwise operations

  • A multi-tensor apply launch that batches the elementwise updates applied to all the model’s parameters into one or a few kernel launches.

apex.optimizers.FusedAdam may be used as a drop-in replacement for torch.optim.AdamW, or torch.optim.Adam with adam_w_mode=False:

opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....)
...
opt.step()

apex.optimizers.FusedAdam may be used with or without Amp. If you wish to use FusedAdam with Amp, you may choose any opt_level:

opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....)
model, opt = amp.initialize(model, opt, opt_level="O0" or "O1" or "O2")
...
opt.step()

In general, opt_level="O1" is recommended.

Warning

A previous version of FusedAdam allowed a number of additional arguments to step. These additional arguments are now deprecated and unnecessary.

Adam was proposed in Adam: A Method for Stochastic Optimization.

Parameters
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups.

  • lr (float, optional) – learning rate. (default: 1e-3)

  • betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999))

  • eps (float, optional) – term added to the denominator to improve numerical stability. (default: 1e-8)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • amsgrad (boolean, optional) – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond (default: False) NOT SUPPORTED in FusedAdam!

  • adam_w_mode (boolean, optional) – Apply L2 regularization or weight decay; True for decoupled weight decay (also known as AdamW). (default: True)

  • set_grad_none (bool, optional) – whether to set gradients to None when the zero_grad() method is called. (default: True)
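
As a DeepSpeed-native alternative to the Apex examples above, deepspeed.ops.adam.FusedAdam can also be constructed directly, again as a regular torch optimizer. The following is a minimal sketch with illustrative values; it assumes a CUDA device and that DeepSpeed can build its fused Adam op on the machine:

import torch
from deepspeed.ops.adam import FusedAdam

model = torch.nn.Linear(10, 2).cuda()               # FusedAdam is GPU-only
optimizer = FusedAdam(model.parameters(), lr=1e-3,
                      adam_w_mode=True,             # decoupled (AdamW-style) weight decay
                      weight_decay=0.01)

x = torch.randn(4, 10, device="cuda")
y = torch.randn(4, 2, device="cuda")
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()                                    # batched multi-tensor fused update
optimizer.zero_grad()                               # grads set to None (set_grad_none=True)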

FusedLamb (GPU)

class deepspeed.ops.lamb.FusedLamb(params, lr=0.001, bias_correction=True, betas=(0.9, 0.999), eps=1e-08, eps_inside_sqrt=False, weight_decay=0.0, max_grad_norm=0.0, max_coeff=10.0, min_coeff=0.01, amsgrad=False)[source]

Implements the LAMB algorithm. Currently GPU-only.

LAMB was proposed in Large Batch Optimization for Deep Learning: Training BERT in 76 minutes (https://arxiv.org/abs/1904.00962).

Parameters
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups.

  • lr (float, optional) – learning rate. (default: 1e-3)

  • bias_correction (bool, optional) – bias correction (default: True)

  • betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999))

  • eps (float, optional) – term added to the denominator to improve numerical stability. (default: 1e-8)

  • eps_inside_sqrt (boolean, optional) – in the ‘update parameters’ step, adds eps to the bias-corrected second moment estimate before evaluating square root instead of adding it to the square root of second moment estimate as in the original paper. (default: False)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • max_grad_norm (float, optional) – value used to clip global grad norm (default: 0.0)

  • max_coeff (float, optional) – maximum value of the lamb coefficient (default: 10.0)

  • min_coeff (float, optional) – minimum value of the lamb coefficient (default: 0.01)

  • amsgrad (boolean, optional) – NOT SUPPORTED in FusedLamb!
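
FusedLamb is typically selected through the DeepSpeed config rather than constructed by hand. The sketch below is illustrative only; it assumes that "Lamb" in the optimizer section resolves to this fused GPU implementation, and the toy model and parameter values are placeholders:

import torch
import deepspeed

model = torch.nn.Linear(10, 2)

# Illustrative config; values are examples, not recommendations.
ds_config = {
    "train_batch_size": 8,
    "optimizer": {
        "type": "Lamb",
        "params": {
            "lr": 1e-3,
            "weight_decay": 0.01,
            "max_coeff": 10.0,
            "min_coeff": 0.01
        }
    }
}

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model, model_parameters=model.parameters(), config=ds_config)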

OneBitAdam (GPU)

class deepspeed.runtime.fp16.onebit.adam.OnebitAdam(params, deepspeed=None, lr=0.001, freeze_step=100000, bias_correction=True, betas=(0.9, 0.999), eps=1e-08, eps_inside_sqrt=False, weight_decay=0.0, max_grad_norm=0.0, amsgrad=False, cuda_aware=False, comm_backend_name='nccl')[source]

Implements the 1-bit Adam algorithm. Currently GPU-only. For a usage example, please see https://www.deepspeed.ai/tutorials/onebit-adam/. For technical details, please read https://arxiv.org/abs/2102.02888.

Parameters
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups.

  • lr (float, optional) – learning rate. (default: 1e-3)

  • freeze_step (int, optional) – Number of steps for warmup (uncompressed) stage before we start using compressed communication. (default: 100000)

  • betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999))

  • eps (float, optional) – term added to the denominator to improve numerical stability. (default: 1e-8)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • amsgrad (boolean, optional) – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond (default: False) NOT SUPPORTED in 1-bit Adam!

  • eps_inside_sqrt (boolean, optional) – in the ‘update parameters’ step, adds eps to the bias-corrected second moment estimate before evaluating square root instead of adding it to the square root of second moment estimate as in the original paper. (default: False)

  • cuda_aware (boolean, optional) – Set to True if the underlying MPI implementation supports CUDA-Aware communication. (default: False)

  • comm_backend_name (string, optional) – communication backend to use; set to ‘mpi’ to use an MPI-based backend instead of NCCL. (default: ‘nccl’)
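
Because 1-bit Adam is coupled to the DeepSpeed engine (note the deepspeed argument in the constructor), it is normally selected through the training config rather than instantiated directly. The fragment below is an illustrative sketch in the style of the 1-bit Adam tutorial; the values are examples only, and fp16 training is assumed as in the tutorial:

# Illustrative DeepSpeed config fragment; pass it to deepspeed.initialize(..., config=ds_config).
ds_config = {
    "train_batch_size": 8,
    "fp16": {"enabled": True},
    "optimizer": {
        "type": "OneBitAdam",
        "params": {
            "lr": 1e-3,
            "freeze_step": 23000,          # warmup steps with uncompressed communication
            "cuda_aware": False,           # True only with a CUDA-aware MPI build
            "comm_backend_name": "nccl"
        }
    }
}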

ZeroOneAdam (GPU)

class deepspeed.runtime.fp16.onebit.zoadam.ZeroOneAdam(params, deepspeed=None, lr=0.001, bias_correction=True, betas=(0.9, 0.999), eps=1e-08, eps_inside_sqrt=False, weight_decay=0.0, max_grad_norm=0.0, var_freeze_step=100000, var_update_scaler=16, local_step_scaler=32678, local_step_clipper=16, amsgrad=False, cuda_aware=False, comm_backend_name='nccl')[source]

Implements the 0/1 Adam algorithm. Currently GPU-only. For a usage example, please see https://www.deepspeed.ai/tutorials/zero-one-adam/. For technical details, please read https://arxiv.org/abs/2202.06009.

Parameters
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups.

  • lr (float, optional) – learning rate. (default: 1e-3)

  • betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999))

  • eps (float, optional) – term added to the denominator to improve numerical stability. (default: 1e-8)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • var_freeze_step (int, optional) – The latest step at which to update the variance; using the notation from https://arxiv.org/abs/2202.06009, it denotes max{i | i in T_v}. Note that this is different from the freeze_step in 1-bit Adam. var_freeze_step is usually the end of the learning rate warmup and thus does not require tuning. (default: 100000)

  • var_update_scaler (int, optional) – The interval to update the variance. Note that the update policy for variance follows an exponential rule, where var_update_scaler denotes the kappa in the 0/1 Adam paper. (default: 16)

  • local_step_scaler (int, optional) – The interval at which the local-step interval is scaled according to the learning rate policy. (default: 32678)

  • local_step_clipper (int, optional) – The largest interval for local steps with learning rate policy. This corresponds to the variable H in the 0/1 Adam paper. (default: 16)

  • amsgrad (boolean, optional) – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond (default: False) NOT SUPPORTED in 0/1 Adam!

  • eps_inside_sqrt (boolean, optional) – in the ‘update parameters’ step, adds eps to the bias-corrected second moment estimate before evaluating square root instead of adding it to the square root of second moment estimate as in the original paper. (default: False)

  • cuda_aware (boolean, optional) – Set to True if the underlying MPI implementation supports CUDA-Aware communication. (default: False)

  • comm_backend_name (string, optional) – communication backend to use; set to ‘mpi’ to use an MPI-based backend instead of NCCL. (default: ‘nccl’)
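
Like 1-bit Adam, 0/1 Adam is selected through the DeepSpeed config. The fragment below is an illustrative sketch with placeholder values in the spirit of the 0/1 Adam tutorial (fp16 training assumed); it is not a tuned configuration:

ds_config = {
    "train_batch_size": 8,
    "fp16": {"enabled": True},
    "optimizer": {
        "type": "ZeroOneAdam",
        "params": {
            "lr": 1e-3,
            "var_freeze_step": 1000,       # typically the end of LR warmup
            "var_update_scaler": 16,
            "local_step_scaler": 32678,
            "local_step_clipper": 16,
            "cuda_aware": False,
            "comm_backend_name": "nccl"
        }
    }
}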

OnebitLamb (GPU)

class deepspeed.runtime.fp16.onebit.lamb.OnebitLamb(params, deepspeed=None, lr=0.001, freeze_step=100000, bias_correction=True, betas=(0.9, 0.999), eps=1e-08, eps_inside_sqrt=False, weight_decay=0.0, max_grad_norm=0.0, max_coeff=10.0, min_coeff=0.01, amsgrad=False, cuda_aware=False, comm_backend_name='nccl', coeff_beta=0.9, factor_max=4.0, factor_min=0.5, factor_threshold=0.1)[source]

Implements the 1-bit Lamb algorithm. Currently GPU-only. For a usage example, please see https://www.deepspeed.ai/tutorials/onebit-lamb/. For technical details, please see our paper https://arxiv.org/abs/2104.06069.

Parameters
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups.

  • lr (float, optional) – learning rate. (default: 1e-3)

  • freeze_step (int, optional) – Number of steps for warmup (uncompressed) stage before we start using compressed communication. (default: 100000)

  • betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999))

  • eps (float, optional) – term added to the denominator to improve numerical stability. (default: 1e-8)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • max_coeff (float, optional) – maximum value of the lamb coefficient (default: 10.0)

  • min_coeff (float, optional) – minimum value of the lamb coefficient (default: 0.01)

  • amsgrad (boolean, optional) – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond (default: False) NOT SUPPORTED in 1-bit Lamb!

  • eps_inside_sqrt (boolean, optional) – in the ‘update parameters’ step, adds eps to the bias-corrected second moment estimate before evaluating square root instead of adding it to the square root of second moment estimate as in the original paper. (default: False)

  • cuda_aware (boolean, optional) – Set to True if the underlying MPI implementation supports CUDA-Aware communication. (default: False)

  • comm_backend_name (string, optional) – communication backend to use; set to ‘mpi’ to use an MPI-based backend instead of NCCL. (default: ‘nccl’)

  • coeff_beta (float, optional) – coefficient used for computing running averages of the lamb coefficient. (default: 0.9) Note that you may want to increase or decrease this beta depending on the freeze_step you choose, since 1/(1 - coeff_beta) should be smaller than or equal to freeze_step.

  • factor_max (float, optional) – maximum value of the scaling factor applied to the frozen lamb coefficient during the compression stage. (default: 4.0)

  • factor_min (float, optional) – minimum value of the scaling factor applied to the frozen lamb coefficient during the compression stage. (default: 0.5)

  • factor_threshold (float, optional) – threshold on how much the scaling factor can fluctuate between steps. (default: 0.1)
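
As with the other 1-bit optimizers, 1-bit LAMB is selected through the DeepSpeed config. The fragment below is an illustrative sketch with placeholder values along the lines of the 1-bit LAMB tutorial (fp16 training assumed); it is not a tuned configuration:

ds_config = {
    "train_batch_size": 8,
    "fp16": {"enabled": True},
    "optimizer": {
        "type": "OneBitLamb",
        "params": {
            "lr": 1e-3,
            "freeze_step": 1000,           # warmup steps with uncompressed communication
            "max_coeff": 10.0,
            "min_coeff": 0.01,
            "coeff_beta": 0.9,             # 1/(1 - coeff_beta) should not exceed freeze_step
            "factor_max": 4.0,
            "factor_min": 0.5,
            "factor_threshold": 0.1,
            "cuda_aware": False,
            "comm_backend_name": "nccl"
        }
    }
}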