ZeRO¶
The Zero Redundancy Optimizer (ZeRO) removes the memory redundancies across data-parallel processes by partitioning the three model states (optimizer states, gradients, and parameters) across data-parallel processes instead of replicating them. By doing this, it boosts memory efficiency compared to classic data-parallelism while retaining its computational granularity and communication efficiency.
ZeRO Stage 1: The optimizer states (e.g., for the Adam optimizer, the 32-bit weights and the first and second moment estimates) are partitioned across the processes, so that each process updates only its partition.
ZeRO Stage 2: The reduced 32-bit gradients for updating the model weights are also partitioned such that each process retains only the gradients corresponding to its portion of the optimizer states.
ZeRO Stage 3: The 16-bit model parameters are partitioned across the processes. ZeRO-3 will automatically collect and partition them during the forward and backward passes.
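The sketch below makes the savings of each stage concrete. It follows the memory accounting from the ZeRO paper: with mixed-precision Adam, each parameter costs 2 bytes of fp16 weights, 2 bytes of fp16 gradients, and K = 12 bytes of fp32 optimizer state (weights, momentum, variance). It counts only model states for a model with `num_params` parameters on `dp_degree` data-parallel GPUs. It ignores activations, temporary buffers, and fragmentation, so treat its output as a lower bound and not as a prediction for a real run. The function name `model_state_bytes_per_gpu` is hypothetical.

```python
def model_state_bytes_per_gpu(num_params: float, dp_degree: int, stage: int, k: int = 12) -> float:
    """Per-GPU bytes for model states under mixed-precision Adam (ZeRO paper accounting).

    2 bytes/param fp16 weights + 2 bytes/param fp16 gradients + k bytes/param
    optimizer state. Activations and buffers are not included.
    """
    psi, n = num_params, dp_degree
    if stage == 0:  # plain data parallelism: everything replicated
        return (2 + 2 + k) * psi
    if stage == 1:  # optimizer states partitioned
        return 2 * psi + 2 * psi + k * psi / n
    if stage == 2:  # optimizer states + gradients partitioned
        return 2 * psi + (2 + k) * psi / n
    if stage == 3:  # optimizer states + gradients + parameters partitioned
        return (2 + 2 + k) * psi / n
    raise ValueError(f"unknown ZeRO stage: {stage}")


if __name__ == "__main__":
    # 7.5B-parameter model on 64 data-parallel GPUs
    for stage in range(4):
        gb = model_state_bytes_per_gpu(7.5e9, 64, stage) / 1e9
        print(f"stage {stage}: {gb:.1f} GB")
```

For a 7.5B-parameter model on 64 GPUs this gives roughly 120 GB, 31.4 GB, 16.6 GB, and 1.9 GB of model states per GPU for stages 0 through 3.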
In addition, ZeRO-3 includes the infinity offload engine to form ZeRO-Infinity ([paper](https://arxiv.org/abs/2104.07857)), which can offload all model states to both CPU and NVMe memory for huge memory savings.
For a deep dive into our algorithms, please see our papers on ZeRO, ZeRO-Offload, and ZeRO-Infinity.
Note
DeepSpeed first included offloading capabilities with ZeRO-Offload, a system for offloading optimizer and gradient states to CPU memory within ZeRO-2. ZeRO-Infinity is the next generation of offloading capabilities, accessible to ZeRO-3. ZeRO-Infinity has all of the savings of ZeRO-Offload, plus it can offload the model weights and achieves more effective bandwidth utilization and overlapping of computation and communication.
Getting Started¶
If you are new to DeepSpeed, check out our Getting Started page.
Once you are training with DeepSpeed, enabling ZeRO-3 offload is as simple as enabling it in your DeepSpeed configuration! Below are a few examples of ZeRO-3 configurations. Please see our config guide for a complete list of options for configuration and performance tuning.
Note
ZeRO-Infinity and ZeRO-Offload work best with our heavily optimized deepspeed.ops.adam.DeepSpeedCPUAdam optimizer. We recommend using our optimizer config to instruct deepspeed.initialize() to build the optimizer for you.
ZeRO Configurations¶
All the settings for DeepSpeed ZeRO are set with the DeepSpeedZeroConfig. The dictionary provided under the zero_optimization entry of the main DeepSpeed configuration dict will be parsed and validated with this class. Sub-configurations for parameter offload and optimizer offload settings are parsed by DeepSpeedZeroOffloadParamConfig and DeepSpeedZeroOffloadOptimizerConfig.
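To illustrate what "parsed and validated" involves, here is a minimal sketch of a few of the documented rules applied to a plain config dict: the `stage3_`-prefixed aliases map onto their plain field names, size and threshold fields must be non-negative, offload devices are `none`, `cpu`, or `nvme`, `offload_param` is valid only with stage 3, and `offload_optimizer` is valid with stages 1, 2, and 3. The function `check_zero_section` and its error messages are hypothetical. It is not the real DeepSpeedZeroConfig (which may, for example, ignore rather than reject some combinations), and it covers only a subset of the fields listed below.

```python
import json
from typing import Any, Dict

# Size/threshold fields documented with "minimum = 0" (a subset).
NON_NEGATIVE_FIELDS = (
    "reduce_bucket_size", "allgather_bucket_size", "sub_group_size",
    "prefetch_bucket_size", "param_persistence_threshold",
    "model_persistence_threshold", "max_live_parameters", "max_reuse_distance",
)
# Documented aliases: the "stage3_"-prefixed names map onto the plain names.
ALIASES = {
    "stage3_prefetch_bucket_size": "prefetch_bucket_size",
    "stage3_param_persistence_threshold": "param_persistence_threshold",
    "stage3_model_persistence_threshold": "model_persistence_threshold",
    "stage3_max_live_parameters": "max_live_parameters",
    "stage3_max_reuse_distance": "max_reuse_distance",
    "stage3_gather_16bit_weights_on_model_save": "gather_16bit_weights_on_model_save",
}
OFFLOAD_DEVICES = {"none", "cpu", "nvme"}
# Stages each offload sub-config is documented as valid with.
VALID_STAGES = {"offload_param": {3}, "offload_optimizer": {1, 2, 3}}


def check_zero_section(ds_config: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical checker: return a normalized zero_optimization section or raise ValueError."""
    section = ds_config.get("zero_optimization", {})
    zero = {ALIASES.get(key, key): value for key, value in section.items()}
    stage = zero.get("stage", 0)
    if stage not in (0, 1, 2, 3):
        raise ValueError(f"stage must be 0, 1, 2, or 3, got {stage!r}")
    for name in NON_NEGATIVE_FIELDS:
        if name in zero and zero[name] < 0:
            raise ValueError(f"{name} must be >= 0, got {zero[name]}")
    for key, stages in VALID_STAGES.items():
        device = (zero.get(key) or {}).get("device", "none")
        if device not in OFFLOAD_DEVICES:
            raise ValueError(f"{key}.device must be one of {sorted(OFFLOAD_DEVICES)}")
        if device != "none" and stage not in stages:
            raise ValueError(f"{key} is not valid with stage {stage}")
    return zero


if __name__ == "__main__":
    raw = '{"zero_optimization": {"stage": 3, "stage3_max_live_parameters": 1000000000, "offload_param": {"device": "cpu"}}}'
    print(check_zero_section(json.loads(raw)))
```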
- class deepspeed.runtime.zero.config.DeepSpeedZeroConfig[source]¶
Sets parameters for ZeRO optimizations.
- stage: ZeroStageEnum = 0¶
Chooses different stages of ZeRO Optimizer. Stages 0, 1, 2, and 3 refer to disabled, optimizer state partitioning, optimizer+gradient state partitioning, and optimizer+gradient+parameter partitioning, respectively.
- contiguous_gradients: bool = True¶
Copies the gradients to a contiguous buffer as they are produced. Avoids memory fragmentation during backward pass.
- reduce_scatter: bool = True¶
Uses reduce or reduce-scatter instead of allreduce to average gradients.
- reduce_bucket_size: int = 500,000,000¶
Number of elements reduced/allreduced at a time. Limits the memory required for gradient reduction for large model sizes.
- Constraints
minimum = 0
- allgather_partitions: bool = True¶
Chooses between allgather collective or a series of broadcast collectives to gather updated parameters from all the GPUs at the end of each step
- allgather_bucket_size: int = 500,000,000¶
Number of elements allgathered at a time. Limits the memory required for the allgather for large model sizes
- Constraints
minimum = 0
- overlap_comm: bool = None¶
Attempts to overlap the reduction of the gradients with backward computation.
- load_from_fp32_weights: bool = True¶
Boolean indicating whether to initialize fp32 master weights from fp32 copies in checkpoint (no precision loss) or from model’s fp16 copies (with precision loss). This can be used to initialize optimizer state even when checkpoint is missing optimizer state.
- elastic_checkpoint: bool = False¶
Enable loading checkpoint that was saved by job with different GPU count. No longer supported.
- offload_param: Optional[DeepSpeedZeroOffloadParamConfig] = None¶
Enable offloading of model parameters to CPU or NVMe. This frees up GPU memory for larger models or batch sizes. Valid only with stage 3. Expects a dictionary containing values for DeepSpeedZeroOffloadParamConfig.
- offload_optimizer: Optional[DeepSpeedZeroOffloadOptimizerConfig] = None¶
Enable offloading of optimizer state to CPU or NVMe, and optimizer computation to CPU. This frees up GPU memory for larger models or batch sizes. Valid for ZeRO stage 1, 2, 3. Expects a dictionary containing values for DeepSpeedZeroOffloadOptimizerConfig.
- sub_group_size: int = 1,000,000,000¶
Tile size for parameter processing to fit massive models (with trillions of parameters). Used by ZeRO3-Offload and ZeRO-Infinity
- Constraints
minimum = 0
- cpu_offload_param: bool = None¶
Deprecated, please use offload_param
- cpu_offload_use_pin_memory: bool = None¶
Deprecated, please use offload_param or offload_optimizer
- cpu_offload: bool = None¶
Deprecated, please use offload_optimizer
- prefetch_bucket_size: int = 50,000,000 (alias 'stage3_prefetch_bucket_size')¶
Maximum number of parameter elements to fetch ahead of use. Used by ZeRO3, ZeRO3-Offload, ZeRO-Infinity, and ZeRO-Inference.
- Constraints
minimum = 0
- param_persistence_threshold: int = 100,000 (alias 'stage3_param_persistence_threshold')¶
Do not partition parameters smaller than this threshold. Smaller values use less memory, but can greatly increase communication (especially latency-bound messages).
- Constraints
minimum = 0
- model_persistence_threshold: int = sys.maxsize (alias 'stage3_model_persistence_threshold')¶
Maximum number of parameter elements that can be persisted in GPU and not partitioned. This imposes an upper bound on the number of unpartitioned parameters resulting from param_persistence_threshold setting. Used by ZeRO3-Offload, ZeRO-Infinity and ZeRO-Inference.
- Constraints
minimum = 0
- max_live_parameters: int = 1,000,000,000 (alias 'stage3_max_live_parameters')¶
The maximum number of parameters resident per GPU before releasing. Smaller values use less memory, but perform more communication.
- Constraints
minimum = 0
- max_reuse_distance: int = 1,000,000,000 (alias 'stage3_max_reuse_distance')¶
Do not release a parameter if it will be reused within this threshold of parameters. Smaller values use less memory, but perform more communication.
- Constraints
minimum = 0
- gather_16bit_weights_on_model_save: bool = False (alias 'stage3_gather_16bit_weights_on_model_save')¶
Consolidate the weights before saving the model by save_16bit_model(). Since the weights are partitioned across GPUs, they aren’t part of state_dict, so this function automatically gathers the weights when this option is enabled and then saves the fp16 model weights.
- stage3_gather_fp16_weights_on_model_save: bool = False¶
Deprecated, please use gather_16bit_weights_on_model_save
- ignore_unused_parameters: bool = True¶
Unused parameters in modules may be unexpected in static networks, but could be normal in dynamic networks. This controls whether or not training should terminate with an error message when unused parameters are detected. This is set to True by default, which means unused parameters are ignored and training continues. Currently used only in stage 2.
- legacy_stage1: bool = False¶
For backward compatibility, enable the old ZeRO stage 1 implementation. Use at your own risk; it will be deprecated soon.
- round_robin_gradients: bool = False¶
Stage 1 and 2 optimization for CPU offloading that parallelizes gradient copying to CPU memory among ranks by fine-grained gradient partitioning. Performance benefit grows with gradient accumulation steps (more copying between optimizer steps) or GPU count (increased parallelism).
- zero_hpz_partition_size: int = 1¶
Number of ranks in the secondary ZeRO parameter-partitioning group.
- Constraints
minimum = 0
- zero_quantized_weights: bool = False¶
Boolean indicating whether to quantize ZeRO parameters (weights) for more efficient all_gather communication.
- zero_quantized_nontrainable_weights: bool = False¶
Boolean indicating whether to quantize non-trainable ZeRO parameters (weights) for efficient memory usage and communication. Unlike zero_quantized_weights, which stores the weights in their original precision and only quantizes them during communication, this flag stores the weights in quantized precision. This is useful for LoRA training.
- zero_quantized_gradients: bool = False¶
Boolean indicating whether to use quantized ZeRO gradients for more efficient all_2_all_reduce communication.
- mics_shard_size: int = -1¶
- mics_hierarchical_params_gather: bool = False¶
- memory_efficient_linear: bool = True¶
Use memory efficient linear implementation, for Stage 3.
- pipeline_loading_checkpoint: bool = False¶
- override_module_apply: bool = True¶
Override nn.Module apply function, for Stage 3.
- class deepspeed.runtime.zero.config.DeepSpeedZeroOffloadParamConfig[source]¶
Set options for parameter offload. Valid only with stage 3.
- device: OffloadDeviceEnum = 'none'¶
Device memory to offload model parameters. Supported options are cpu and nvme.
- nvme_path: Path = None¶
Filesystem path for NVMe device for parameter offloading.
- buffer_count: int = 5¶
Number of buffers in buffer pool for parameter offloading to NVMe.
- Constraints
minimum = 0
- buffer_size: int = 100,000,000¶
Size of buffers in buffer pool for parameter offloading to NVMe.
- Constraints
minimum = 0
- max_in_cpu: int = 1,000,000,000¶
Number of parameter elements to maintain in CPU memory when offloading to NVMe is enabled.
- Constraints
minimum = 0
- pin_memory: bool = False¶
Offload to page-locked CPU memory. This could boost throughput at the cost of extra memory overhead.
- class deepspeed.runtime.zero.config.DeepSpeedZeroOffloadOptimizerConfig[source]¶
Set options for optimizer offload. Valid with stage 1, 2, and 3.
- device: OffloadDeviceEnum = 'none'¶
Device memory to offload optimizer state. Supported options are cpu and nvme. Optimizer computation is offloaded to CPU regardless of the device option.
- nvme_path: Path = None¶
Filesystem path for NVMe device for optimizer state offloading.
- buffer_count: int = 4¶
Number of buffers in buffer pool for optimizer state offloading to NVMe. This should be at least the number of states maintained per parameter by the optimizer. For example, Adam optimizer has 4 states (parameter, gradient, momentum, and variance).
- Constraints
minimum = 0
- pin_memory: bool = False¶
Offload to page-locked CPU memory. This could boost throughput at the cost of extra memory overhead.
- pipeline_read: bool = False¶
For tile-based optimizer step processing, overlap read of next tile with computation of current tile. Used in ZeRO-Infinity.
- pipeline_write: bool = False¶
For tile-based optimizer step processing, overlap write of previous tile with computation of current tile.
- fast_init: bool = False¶
Enable fast optimizer initialization when offloading to NVMe.
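The interaction between param_persistence_threshold and model_persistence_threshold can be sketched as a greedy selection: small parameters stay unpartitioned on every GPU, but only until their combined size reaches the model-wide cap. This is a simplified sketch, assuming parameters are visited in registration order and the per-parameter comparison is inclusive; DeepSpeed's actual selection may differ in such details. The function `select_persistent_params` and the parameter names in the usage example are hypothetical.

```python
import sys
from typing import Dict, List


def select_persistent_params(
    param_numels: Dict[str, int],
    param_persistence_threshold: int = 100_000,
    model_persistence_threshold: int = sys.maxsize,
) -> List[str]:
    """Return names of parameters that would stay unpartitioned on every GPU.

    Each parameter with at most param_persistence_threshold elements is kept,
    unless adding it would push the running total of persistent elements
    past model_persistence_threshold.
    """
    persistent: List[str] = []
    total = 0
    for name, numel in param_numels.items():
        if total + numel > model_persistence_threshold:
            continue  # would exceed the model-wide budget of unpartitioned elements
        if numel <= param_persistence_threshold:
            persistent.append(name)
            total += numel
    return persistent


if __name__ == "__main__":
    params = {"ln.weight": 1024, "ln.bias": 1024, "attn.qkv.weight": 3 * 1024 * 1024, "head.bias": 50_000}
    print(select_persistent_params(params))
    print(select_persistent_params(params, model_persistence_threshold=2048))
```

Lowering param_persistence_threshold shrinks the persistent set and saves memory, at the cost of gathering more small parameters during each pass, which matches the trade-off described for that field.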
Example ZeRO-3 Configurations¶
Use ZeRO to partition the optimizer states (stage 1), gradients (stage 2), and parameters (stage 3).
{
  "zero_optimization": {
    "stage": 3
  },
  "fp16": {
    "enabled": true
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": 0.001,
      "betas": [0.8, 0.999],
      "eps": 1e-8,
      "weight_decay": 3e-7
    }
  },
  ...
}
Additionally offload the optimizer states and computations to the CPU with ZeRO-Infinity.
{
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu"
    }
  },
  ...
}
Save even more memory by offloading parameters to the CPU memory.
{
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu"
    },
    "offload_param": {
      "device": "cpu"
    }
  },
  ...
}
Save even MORE memory by offloading to NVMe (if available on your system):
{
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "nvme",
      "nvme_path": "/nvme_data"
    },
    "offload_param": {
      "device": "nvme",
      "nvme_path": "/nvme_data"
    }
  },
  ...
}
MiCS Configurations¶
All MiCS configurations are set with DeepSpeedZeroConfig. MiCS assumes ZeRO stage 3 optimization is enabled. For now, there are two MiCS configuration fields: mics_shard_size and mics_hierarchical_params_gather. mics_shard_size controls how many devices are used for partitioning the model states. mics_hierarchical_params_gather controls whether parameters are gathered in a two-stage hierarchical way during the forward computation. mics_hierarchical_params_gather is useful when model states are partitioned across multiple nodes and the cross-node bandwidth is slow. By default it is turned off.
Example MiCS Configurations¶
Use MiCS to partition the model states (including optimizer states, gradients, and parameters). The following config example partitions the model states to eight devices, and assumes the eight devices are located within a single node (mics_hierarchical_params_gather is False).
{
  "zero_optimization": {
    "stage": 3,
    "mics_shard_size": 8,
    "mics_hierarchical_params_gather": false
  },
  ...
}
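To see how mics_shard_size divides a job, consider 16 GPUs across two nodes with mics_shard_size set to 8. The sketch below computes a MiCS-style layout: shard groups, within which the model states are partitioned, and replicate groups, which hold the same partition in different shard groups. It assumes, for illustration only, that shard groups are blocks of consecutive ranks. The function `mics_groups` is hypothetical and does not model the documented default of -1.

```python
from typing import List, Tuple


def mics_groups(world_size: int, shard_size: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Return (shard_groups, replicate_groups) for an illustrative MiCS-style layout.

    Assumed layout: each shard group is a block of `shard_size` consecutive ranks
    that partitions the model states among themselves; each replicate group
    collects the ranks holding the same partition in every shard group.
    """
    if shard_size <= 0 or world_size % shard_size != 0:
        raise ValueError("this sketch needs a positive shard_size that divides world_size")
    num_groups = world_size // shard_size
    shard_groups = [list(range(g * shard_size, (g + 1) * shard_size)) for g in range(num_groups)]
    replicate_groups = [[g * shard_size + i for g in range(num_groups)] for i in range(shard_size)]
    return shard_groups, replicate_groups


if __name__ == "__main__":
    shards, replicas = mics_groups(world_size=16, shard_size=8)
    print(shards)    # one shard group per node
    print(replicas)  # rank i on node 0 pairs with rank i + 8 on node 1
```

Under this layout, parameter gathers stay inside a shard group (inside one node here), and only the replicate groups cross the node boundary.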
Assumptions¶
DeepSpeed automatically coordinates the collection (i.e., all-gather), partitioning (i.e., scatter), and offloading of parameters at the granularity of (sub)module forward() methods. The backward pass is handled similarly. This strategy has two underlying assumptions:
- The forward and backward passes of submodules must individually fit in device memory. If this is not the case, deepspeed.zero.TiledLinear implements memory-centric tiling and works with ZeRO-3 to break linear layers into a sequence of smaller submodules that can fit in memory.
- A module’s parameters are only accessed within its own __init__ and forward() methods. Otherwise, DeepSpeed must be instructed to collect and re-partition the parameter. See Manual Parameter Coordination for manually coordinating parameters.
Constructing Massive Models¶
ZeRO-3 enables massive models whose parameters exceed the size of individual nodes in a system. For the typical case of training without model parallelism, you can simply allocate your model in our context:
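import deepspeed

# Parameters are partitioned across data-parallel ranks as each submodule is constructed
with deepspeed.zero.Init():
    model = MyLargeModel()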
- class deepspeed.zero.Init(module=None, data_parallel_group=None, mem_efficient_linear=True, remote_device=None, pin_memory=False, config_dict_or_path=None, config=None, enabled=True, dtype=None, mpu=None, zero_param_parallel_group=None, zero_quantized_weights=False, zero_quantized_nontrainable_weights=False, sequence_data_parallel_group=None, param_swapper=None)¶
- get_partition_rank()¶
Subclasses can override this to specify a different relative rank in the parameter partition group.
- get_dp_process_group()¶
Return the communication group with all data-parallel ranks
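Conceptually, once a parameter is created inside deepspeed.zero.Init, each data-parallel rank keeps only an equal-sized slice of it, and an all-gather rebuilds the full tensor when it is needed. The plain-Python sketch below shows that picture under a stated assumption: the flattened parameter is padded up to a multiple of the world size so every rank owns the same number of elements. It is an illustration, not DeepSpeed's internal code, whose exact padding and alignment may differ. The helpers `partition_ranges` and `all_gather` are hypothetical.

```python
from typing import List, Tuple


def partition_ranges(numel: int, world_size: int) -> Tuple[int, List[Tuple[int, int]]]:
    """Split a flattened tensor of `numel` elements into equal per-rank slices.

    Returns (padding, [(start, end), ...]); the last ranges may extend into
    the padding appended so that every rank owns the same number of elements.
    """
    partition_size = -(-numel // world_size)  # ceil division
    padding = partition_size * world_size - numel
    ranges = [(r * partition_size, (r + 1) * partition_size) for r in range(world_size)]
    return padding, ranges


def all_gather(shards: List[List[float]], numel: int) -> List[float]:
    """Concatenate every rank's shard and drop the padding to rebuild the full tensor."""
    full = [x for shard in shards for x in shard]
    return full[:numel]


if __name__ == "__main__":
    weights = [float(i) for i in range(10)]
    padding, ranges = partition_ranges(len(weights), world_size=4)
    padded = weights + [0.0] * padding
    shards = [padded[start:end] for start, end in ranges]  # what each rank would hold
    print(ranges, all_gather(shards, len(weights)) == weights)
```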
Manual Parameter Coordination¶
Most models require no modification to be trained with ZeRO-3. However, in some cases one may need to access model weights outside of the training loop, or to share weights across submodules during training. DeepSpeed has several mechanisms to coordinate partitioned weights for ZeRO-3.
Gathering Parameters¶
DeepSpeed provides mechanisms for collecting (or gathering) a partitioned parameter.
Some models partitioned with deepspeed.zero.Init may need to access a module’s weights outside of the class constructor or its forward() method. We refer to these weights as external parameters, since these parameters are accessed outside of the module that created them. To do so, use deepspeed.zero.GatheredParameters or deepspeed.zero.register_external_parameter().
- class deepspeed.zero.GatheredParameters(params, modifier_rank=None, fwd_module=None, enabled=True)¶
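The contract of GatheredParameters can be simulated in a single process: inside the context every rank sees the full parameter, and on exit only the copy held by modifier_rank is kept and re-partitioned. The sketch below is a toy model of those semantics under that reading. `ToyPartitionedParam`, `toy_gathered`, and `_split` are hypothetical, the "ranks" are just list entries, and no communication happens. In this toy, edits made without a modifier_rank are simply dropped, which is why a modifier_rank should be given whenever the parameters are modified.

```python
from contextlib import contextmanager
from typing import Iterator, List, Optional


def _split(values: List[float], world_size: int) -> List[List[float]]:
    """Pad to a multiple of world_size and cut into equal per-rank shards."""
    size = -(-len(values) // world_size)  # ceil division
    padded = values + [0.0] * (size * world_size - len(values))
    return [padded[r * size:(r + 1) * size] for r in range(world_size)]


class ToyPartitionedParam:
    """Hypothetical stand-in for a ZeRO-3 parameter split across simulated ranks."""

    def __init__(self, values: List[float], world_size: int) -> None:
        self.numel = len(values)
        self.shards = _split(values, world_size)

    def all_gather(self) -> List[float]:
        """Rebuild the full parameter from every rank's shard."""
        return [x for shard in self.shards for x in shard][: self.numel]


@contextmanager
def toy_gathered(param: ToyPartitionedParam, modifier_rank: Optional[int] = None) -> Iterator[List[List[float]]]:
    """Give each simulated rank a full copy; on exit, keep only modifier_rank's edits."""
    full = param.all_gather()
    views = [list(full) for _ in param.shards]  # one full copy per rank
    yield views
    if modifier_rank is not None:
        # Stand-in for broadcasting the modifier's copy and re-partitioning it.
        param.shards = _split(views[modifier_rank], len(param.shards))
    # Without a modifier_rank, the edited copies are discarded and the shards stay as they were.


if __name__ == "__main__":
    p = ToyPartitionedParam([1.0, 2.0, 3.0, 4.0, 5.0], world_size=2)
    with toy_gathered(p, modifier_rank=0) as views:
        views[0][4] = 50.0  # rank 0's edit is kept
        views[1][0] = -1.0  # rank 1 is not the modifier, so this edit is lost
    print(p.all_gather())
```

With the real class, the same pattern is a `with deepspeed.zero.GatheredParameters(...)` block that passes modifier_rank and performs the update only on that rank.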
Registering External Parameters¶
ZeRO-3 will automatically collect and partition the model parameters as they are needed during the forward and backward passes. However, in some cases a parameter may be used outside of its module’s forward pass. We call these external parameters. ZeRO-3 can coordinate these parameters if they are registered either automatically or manually.
Note
DeepSpeed version 0.3.15 includes automatic external parameter discovery and registration to support the most common cases. Parameters can still be manually registered if they cannot be automatically detected.
DeepSpeed can automatically detect the following external parameter scenarios:
Parameter access: consider the following pattern common in language models such as GPT:

class LanguageModel(torch.nn.Module):
    ...
    def forward(self, inputs):
        embeds = self.embeddings(inputs)
        ...
        logits = compute_logits(output, self.embeddings.weight)
        ...

The tensor embeddings.weight is used in both embeddings.forward() and compute_logits(). We call embeddings.weight an external parameter because it is used in the training loop outside of its owning module’s forward pass.
Returning a parameter:

class CustomLinear(torch.nn.Linear):
    def forward(self, *input):
        output = super().forward(*input)
        return output, self.bias

CustomLinear returns both an output and its own bias parameter. DeepSpeed will detect the external bias parameter and register it with submodules that use CustomLinear.
- deepspeed.zero.register_external_parameter(module, parameter)¶
Instruct DeepSpeed to coordinate parameter’s collection and partitioning in the forward and backward passes of module.
This is used when a parameter is accessed outside of its owning module’s forward(). DeepSpeed must know to collect it from its partitioned state and when to release the memory.
Note
This is only applicable to training with ZeRO stage 3.
- Parameters
module (torch.nn.Module) – The module that requires parameter in its forward pass.
parameter (torch.nn.Parameter) – The parameter to register.
- Raises
RuntimeError – If parameter is not of type torch.nn.Parameter.
Examples
Register a weight that is used in another module’s forward pass (line 6). Parameter layer1.weight is used by layer2 (line 11).
 1 class ModuleZ3(torch.nn.Module):
 2     def __init__(self, *args):
 3         super().__init__()
 4         self.layer1 = SomeLayer()
 5         self.layer2 = OtherLayer()
 6         deepspeed.zero.register_external_parameter(self, self.layer1.weight)
 7
 8     def forward(self, input):
 9         x = self.layer1(input)
10         # self.layer1.weight is required by self.layer2.forward
11         y = self.layer2(x, self.layer1.weight)
12         return y
Overriding Module.apply¶
A convenient mechanism for customizing model initialization is Module.apply. With ZeRO stage 3, Module.apply implementations must account for parameter partitioning by zero.Init during model initialization. The default behavior of ZeRO stage 3 is to automatically handle this issue by overriding Module.apply to ensure that parameters are gathered before access by Module.apply. The benefit of this approach is development convenience, since users are saved the burden of manual parameter coordination in Module.apply. However, the downside is slow model initialization, since all the model parameters (e.g., billions) are gathered even though the common usage of Module.apply is to customize a few parameters. Developers can disable this default behavior by setting the override_module_apply configuration knob to False, for faster model initialization at the cost of manually handling partitioned parameters in their Module.apply implementations.
Memory-Centric Tiling¶
To reduce the working memory requirements of DL training for large models, ZeRO-Infinity includes a technique called memory-centric tiling that exploits the data fetch and release pattern of ZeRO-3 to reduce the working memory requirements by breaking down a large operator into smaller tiles that can be executed sequentially. When combined with ZeRO-3, the parameters and gradients of each tile can be fetched and released one at a time, reducing the working memory in proportion to the number of tiles. Therefore, ZeRO-Infinity can support operators of arbitrary sizes, without refactoring for model parallelism to fit them in limited GPU memory.
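The core idea of memory-centric tiling can be shown without DeepSpeed: a linear layer's output can be computed one tile of output features at a time and concatenated, so only one tile's weights must be resident at once. The plain-Python sketch below is an illustration under that assumption, with hypothetical helpers `linear` and `tiled_linear`. It splits only the output dimension (analogous to out_splits) and is not DeepSpeed's implementation of TiledLinear.

```python
from typing import List

Matrix = List[List[float]]


def linear(x: List[float], weight: Matrix) -> List[float]:
    """y = W x for a weight of shape [out_features][in_features] (no bias)."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def tiled_linear(x: List[float], weight: Matrix, out_splits: int) -> List[float]:
    """Compute W x one output tile at a time and concatenate the results.

    Each tile needs only its own rows of W, so under ZeRO-3 only one tile's
    weights would have to be gathered at a time.
    """
    out_features = len(weight)
    tile = -(-out_features // out_splits)  # ceil division
    y: List[float] = []
    for start in range(0, out_features, tile):
        tile_weight = weight[start:start + tile]  # "fetch" this tile's rows
        y.extend(linear(x, tile_weight))          # compute, then the tile can be released
    return y


if __name__ == "__main__":
    W = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
    x = [1.0, -1.0]
    print(tiled_linear(x, W, out_splits=2), linear(x, W))
```

The tiled result matches the untiled one exactly; what changes is the peak amount of weight data needed at any moment, which falls roughly in proportion to the number of tiles.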
- class deepspeed.zero.TiledLinear(in_features, out_features, bias=True, in_splits=1, out_splits=1, input_is_already_split=False, combine_out_splits=True, linear_cls=<class 'torch.nn.modules.linear.Linear'>, init_linear=None, **kwargs)¶
- forward(input_)¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- copy_params_from(other)¶
Copy the weight and bias data from other.
This is especially useful for reproducible initialization and testing.
Equivalent to:
with torch.no_grad():
    self.weight.copy_(other.weight)
    if self.bias is not None:
        self.bias.copy_(other.bias)
Note
If ZeRO-3 is enabled, this is a collective operation and the updated parameters of data-parallel rank 0 will be visible on all ranks. See deepspeed.zero.GatheredParameters for more information.
- Parameters
other (torch.nn.Linear) – the linear layer to copy from.
Debugging¶
Debugging ZeRO training is complicated by the partitioning of parameters, gradients, and optimizer states. Because of this partitioning, none of these three groups of tensors (model states) can be accessed normally. To overcome this, DeepSpeed provides the following routines for accessing individual model states in their unpartitioned form.
Important: Please note that these utilities must be called by all processes participating in the training, even if you decide to do something with the result only in the main process. If not all processes participate, these utilities will hang waiting for all processes to send their contribution.
Additionally, you must be aware that these routines return correct data only in specific phases of the training. For example, the gradients are valid after backward and before step. The optimizer states are updated after step, and the same goes for the fp32 master weights.
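Why a non-participating process causes a hang can be seen with a toy analogy: treat each rank as a thread and the collective as a barrier that returns only when every rank has arrived. The sketch below uses Python's threading.Barrier with a timeout so the "hang" is reported instead of blocking forever. The function `run_collective` is hypothetical, and threads stand in for processes; a real collective typically just blocks until every rank joins or a configured timeout expires.

```python
import threading
from typing import Dict, List


def run_collective(world_size: int, participating: List[int], timeout: float = 0.5) -> Dict[int, str]:
    """Simulate ranks entering a collective modeled as a barrier.

    Returns {rank: "ok"} when every rank arrived, or {rank: "hang"} for the
    ranks that waited in vain because some rank never joined.
    """
    barrier = threading.Barrier(world_size)
    results: Dict[int, str] = {}
    lock = threading.Lock()

    def rank_main(rank: int) -> None:
        try:
            barrier.wait(timeout=timeout)  # returns only once all world_size ranks arrive
            status = "ok"
        except threading.BrokenBarrierError:
            status = "hang"  # timed out: at least one rank never called the collective
        with lock:
            results[rank] = status

    threads = [threading.Thread(target=rank_main, args=(rank,)) for rank in participating]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return results


if __name__ == "__main__":
    print(run_collective(2, [0, 1]))  # every rank participates
    print(run_collective(2, [0]))     # rank 1 skips the call, e.g. an "if rank == 0:" guard
```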
- deepspeed.utils.safe_get_full_fp32_param(param)[source]¶
Assemble and return the fp32 parameter of a low-precision (e.g., fp16) parameter.
- Parameters
param (torch.nn.Parameter) – A model parameter
- deepspeed.utils.safe_get_full_grad(param)[source]¶
Assemble and return the fp32 gradient of a low-precision (e.g., fp16) parameter.
- Parameters
param (torch.nn.Parameter) – A model parameter
- deepspeed.utils.safe_get_full_optimizer_state(param, optim_state_key)[source]¶
Assemble and return the fp32 optimizer state of a low-precision (e.g., fp16) parameter.
- Parameters
param (torch.nn.Parameter) – A model parameter
optim_state_key (string) – Key value of optimizer state (e.g., exp_avg in Adam optimizer)
These routines can be used in a training loop as shown in the following snippet.
from deepspeed.utils import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state

# `model` is the engine returned by deepspeed.initialize()
[...]
model.backward(loss)
[...]
for n, lp in model.named_parameters():
    # 1. gradient lookup
    # For zero1 and zero2, gradient lookup must be called after `backward` and before `step`
    # For zero3, gradient lookup must be called after `backward`
    hp_grad = safe_get_full_grad(lp)

    # 2. fp32 and optim states can generally be queried anywhere in the training loop, but are updated only after `step`
    hp = safe_get_full_fp32_param(lp)
    exp_avg = safe_get_full_optimizer_state(lp, "exp_avg")
    exp_avg_sq = safe_get_full_optimizer_state(lp, "exp_avg_sq")

[...]
model.step()
Modifying Partitioned States¶
Sometimes, a user may want to modify parameters or optimizer states outside of the regular training loop. This is currently difficult in ZeRO training because of partitioning. To overcome that, DeepSpeed provides the following two routines for modifying the fp32 master parameters and the fp32 optimizer states.
- deepspeed.utils.safe_set_full_fp32_param(param, value)[source]¶
Update the partitioned fp32 parameter of a low-precision (e.g., fp16) parameter.
- Parameters
param (torch.nn.Parameter) – A model parameter
value (torch.Tensor) – New value
- deepspeed.utils.safe_set_full_optimizer_state(param, value, optim_state_key)[source]¶
Update the partitioned fp32 optimizer state of a low-precision (e.g., fp16) parameter.
- Parameters
param (torch.nn.Parameter) – A model parameter
value (torch.Tensor) – New value
optim_state_key (string) – Key value of optimizer state (e.g., exp_avg in Adam optimizer)
These routines can be used at any point after initialization of the DeepSpeed engine (i.e., deepspeed.initialize()) as shown in the following snippet.
import torch
from deepspeed.utils import safe_set_full_fp32_param, safe_set_full_optimizer_state

[...]
# Here is an example to zero all the fp32 parameters and optimizer states.
for n, lp in model.named_parameters():
    # Assume zero stage 1 or 2, since stage 3 requires a gather to assemble lp
    zero_tensor = torch.zeros_like(lp)

    safe_set_full_fp32_param(lp, zero_tensor)
    safe_set_full_optimizer_state(lp, zero_tensor, "exp_avg")
    safe_set_full_optimizer_state(lp, zero_tensor, "exp_avg_sq")

[...]
GPU Memory Management¶
By default, at the end of training with ZeRO stage 3, some parameters may remain unpartitioned and occupy GPU memory. This is intentional: it is an optimization in case you resume training. If you’d like to clear out the cached parameters that use up GPU memory, you can call the empty_partition_cache method of a DeepSpeed engine.
The following code snippet illustrates this functionality.
import deepspeed

with deepspeed.zero.Init():
    model = MyLargeModel()

ds_engine, _, _, _ = deepspeed.initialize(model=model, ...)
for batch in ...:
    loss = ds_engine(batch)
    ds_engine.backward(loss)
    ds_engine.step()

# Free GPU memory consumed by model parameters
ds_engine.empty_partition_cache()