# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from cpuinfo import get_cpu_info
from deepspeed.utils import logger
from deepspeed.utils.logging import should_log_le
from deepspeed.ops.op_builder import CPUAdamBuilder
class DeepSpeedCPUAdam(torch.optim.Optimizer):
    # Class-wide counter handing each instance a unique id for its C++ optimizer object.
    optimizer_id = 0

    def __init__(self,
                 model_params,
                 lr=1e-3,
                 bias_correction=True,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 weight_decay=0,
                 amsgrad=False,
                 adamw_mode=True,
                 fp32_optimizer_states=True):
        """Fast vectorized implementation of two variations of Adam optimizer on CPU:

        * Adam: A Method for Stochastic Optimization: (https://arxiv.org/abs/1412.6980);
        * AdamW: Fixing Weight Decay Regularization in Adam (https://arxiv.org/abs/1711.05101)

        DeepSpeed CPU Adam(W) provides between 5x and 7x speedup over torch.optim.adam(W).
        In order to apply this optimizer, the model's master parameters (in FP32) must
        reside in CPU memory.

        To train on a heterogeneous system, such as coordinating CPU and GPU, DeepSpeed offers
        the ZeRO-Offload technology which efficiently offloads the optimizer states into CPU memory,
        with minimal impact on training throughput. DeepSpeedCPUAdam plays an important role in
        minimizing the overhead of the optimizer's latency on CPU. Please refer to the ZeRO-Offload
        tutorial (https://www.deepspeed.ai/tutorials/zero-offload/) for more information on how to
        enable this technology.

        .. note::
            We recommend using our `config
            <https://www.deepspeed.ai/docs/config-json/#optimizer-parameters>`_
            to allow :meth:`deepspeed.initialize` to build this optimizer
            for you.

        Arguments:
            model_params (iterable): iterable of parameters to optimize or dicts defining
                parameter groups.
            lr (float, optional): learning rate. (default: 1e-3)
            bias_correction (bool, optional): whether the kernel applies Adam's bias correction;
                stored per parameter group and forwarded on every step. (default: True)
            betas (Tuple[float, float], optional): coefficients used for computing
                running averages of gradient and its square. (default: (0.9, 0.999))
            eps (float, optional): term added to the denominator to improve
                numerical stability. (default: 1e-8)
            weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
            amsgrad (boolean, optional): whether to use the AMSGrad variant of this
                algorithm from the paper `On the Convergence of Adam and Beyond`_
                (default: False) NOT SUPPORTED in DeepSpeed CPUAdam! Passing True only
                logs a warning; the flag is never forwarded to the kernel.
            adamw_mode: select between Adam and AdamW implementations (default: AdamW)
            fp32_optimizer_states: creates momentum and variance in full precision regardless of
                the precision of the parameters. Set to False to keep optimizer states
                in the parameter dtype (e.g. bf16), which reduces the optimizer-state
                memory footprint at the cost of lower state precision. (default: True)

        .. _On the Convergence of Adam and Beyond:
            https://openreview.net/forum?id=ryQu7f-RZ
        """
        default_args = dict(lr=lr,
                            betas=betas,
                            eps=eps,
                            weight_decay=weight_decay,
                            bias_correction=bias_correction,
                            amsgrad=amsgrad)
        super().__init__(model_params, default_args)

        if amsgrad:
            # ``amsgrad`` is never passed to create_adam/adam_update, so it has no effect; say so
            # instead of silently running plain Adam(W).
            logger.warning("amsgrad is not supported by DeepSpeedCPUAdam and will be ignored")

        cpu_info = get_cpu_info()
        self.cpu_vendor = cpu_info.get("vendor_id_raw", "unknown").lower()
        # Warn (once) if any fp16 parameter is present on an AMD host.
        if "amd" in self.cpu_vendor and any(p.dtype == torch.half
                                            for group in self.param_groups
                                            for p in group['params']):
            logger.warning("FP16 params for CPUAdam may not work on AMD CPUs")

        self.opt_id = DeepSpeedCPUAdam.optimizer_id
        DeepSpeedCPUAdam.optimizer_id = DeepSpeedCPUAdam.optimizer_id + 1
        self.adam_w_mode = adamw_mode
        self.fp32_optimizer_states = fp32_optimizer_states
        self.ds_opt_adam = CPUAdamBuilder().load()

        self.ds_opt_adam.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode,
                                     should_log_le("info"))

    def __del__(self):
        # Destroy the C++ object explicitly to avoid a memory leak when deepspeed.initialize
        # is used multiple times in the same process (notebook or pytest worker).
        # ``__init__`` may have raised before the extension was loaded (e.g. an empty parameter
        # list or a failed build); in that case there is nothing to destroy.
        ds_opt_adam = getattr(self, 'ds_opt_adam', None)
        if ds_opt_adam is not None:
            ds_opt_adam.destroy_adam(self.opt_id)

    def __setstate__(self, state):
        super().__setstate__(state)
        # Older checkpoints may predate the ``amsgrad`` group key.
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @staticmethod
    def _evaluate_closure(closure):
        """Return ``closure()`` evaluated with autograd enabled, or ``None`` when no closure is given."""
        if closure is None:
            return None
        # The public step methods run under ``torch.no_grad()``; the closure needs autograd.
        with torch.enable_grad():
            return closure()

    @staticmethod
    def _assert_on_device(p, device):
        """Assert that parameter ``p`` lives on ``device`` (the CPU)."""
        assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \
            "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config."

    def _init_state(self, state, p, device):
        """Populate an empty ``state`` dict with a zero step count and zeroed moment buffers shaped like ``p``."""
        state['step'] = 0
        # Full precision by default unless fp32_optimizer_states is off.
        state_dtype = torch.float if self.fp32_optimizer_states else p.dtype
        state['exp_avg'] = torch.zeros_like(p.data, dtype=state_dtype, device=device)  # gradient momentums
        state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=state_dtype, device=device)  # gradient variances

    @torch.no_grad()
    def step(self, closure=None):
        """Update the model parameters.

        .. note::
            This method will be called internally by ZeRO-Offload. DeepSpeed
            users should still use ``engine.step()`` as shown in the
            `Getting Started
            <https://www.deepspeed.ai/getting-started/#training>`_ guide.

        Args:
            closure (callable, optional): closure to compute the loss.
                Defaults to ``None``.

        Returns:
            loss: if ``closure`` is provided. Otherwise ``None``.
        """
        loss = self._evaluate_closure(closure)

        # intended device for step
        device = torch.device('cpu')

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                self._assert_on_device(p, device)

                # Per-parameter state, lazily created on the first step that sees a gradient.
                state = self.state[p]
                if not state:
                    self._init_state(state, p, device)

                state['step'] += 1
                beta1, beta2 = group['betas']
                self.ds_opt_adam.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
                                             group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
                                             state['exp_avg'], state['exp_avg_sq'])
        return loss

    @torch.no_grad()
    def step_subgroup(self, subgroup_id: int, closure=None):
        """Update the model parameters in a single subgroup (by index).

        Args:
            subgroup_id (int): key under which this subgroup's state is stored in ``self.state``.
            closure (callable, optional): closure to compute the loss.

        Returns:
            loss: if ``closure`` is provided. Otherwise ``None``.
        """
        loss = self._evaluate_closure(closure)

        # Intended device for step
        device = torch.device('cpu')

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                self._assert_on_device(p, device)

                # NOTE(review): state is keyed by ``subgroup_id``, not by parameter, so every param with
                # a gradient shares one step counter and one exp_avg/exp_avg_sq pair. This presumably
                # relies on the caller exposing exactly one (flat) parameter per subgroup — confirm.
                state = self.state[subgroup_id]
                if not state:
                    self._init_state(state, p, device)

                state['step'] += 1
                beta1, beta2 = group['betas']
                self.ds_opt_adam.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
                                             group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
                                             state['exp_avg'], state['exp_avg_sq'])
        return loss

    @torch.no_grad()
    def rollback_subgroup(self, sub_group_id: int, closure=None):
        """Roll back the optimizer state of one subgroup by one step.

        For every parameter with a gradient, calls the kernel's ``adam_rollback`` with the
        subgroup's current step count and then decrements that count, mirroring the per-parameter
        increment done by ``step_subgroup``.

        Args:
            sub_group_id (int): key of the subgroup state in ``self.state``.
            closure (callable, optional): evaluated with autograd enabled before the rollback.

        Returns:
            loss: if ``closure`` is provided. Otherwise ``None``.

        Raises:
            RuntimeError: if the subgroup state was never initialized or its step count is not positive.
        """
        loss = self._evaluate_closure(closure)

        # Intended device for step
        device = torch.device('cpu')

        # Validate subgroup state exists and is initialized
        if sub_group_id not in self.state or len(self.state[sub_group_id]) == 0:
            raise RuntimeError(f"Cannot rollback optimizer state for sub_group_id {sub_group_id} "
                               f"as it has not been initialized.")

        subgroup_state = self.state[sub_group_id]

        # Check if we can rollback (step count must be > 0)
        if subgroup_state.get('step', 0) <= 0:
            raise RuntimeError(f"Cannot rollback sub_group_id {sub_group_id}: "
                               f"step count is {subgroup_state.get('step', 0)}")

        for group in self.param_groups:
            for param in group['params']:
                if param.grad is None:
                    continue
                self._assert_on_device(param, device)

                beta1, beta2 = group['betas']
                self.ds_opt_adam.adam_rollback(self.opt_id, subgroup_state['step'], group['lr'], beta1, beta2,
                                               group['eps'], group['weight_decay'], group['bias_correction'],
                                               param.data, param.grad.data, subgroup_state['exp_avg'],
                                               subgroup_state['exp_avg_sq'])
                subgroup_state['step'] -= 1

        return loss