ZeRO-3 Offload

The Zero Redundancy Optimizer (ZeRO) removes the memory redundancies across data-parallel processes by partitioning the three model states (optimizer states, gradients, and parameters) across data-parallel processes instead of replicating them. By doing this, it boosts memory efficiency compared to classic data-parallelism while retaining its computational granularity and communication efficiency.

ZeRO-Offload further increases memory efficiency by offloading the optimizer’s states and computations to the CPU. The model parameters can also be offloaded for even more memory savings!

For more information on our algorithms, please see our papers on ZeRO and ZeRO-Offload.

Getting Started

If you are new to DeepSpeed, check out our Getting Started page.

Once you are training with DeepSpeed, enabling ZeRO-3 Offload is as simple as turning it on in your DeepSpeed configuration! Below are a few examples of ZeRO-3 configurations. Please see our config guide for a complete list of options for configuration and performance tuning.

Note

ZeRO-3 Offload works best with our heavily optimized deepspeed.ops.adam.DeepSpeedCPUAdam optimizer. We recommend using our optimizer config to instruct deepspeed.initialize() to build the optimizer for you.
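
For reference, here is a minimal sketch of that recommendation; MyLargeModel and args are placeholders for your own model class and parsed command-line arguments, with args.deepspeed_config pointing at a JSON file like the examples below:

import deepspeed

# Sketch only: MyLargeModel and args are assumed to be defined by your
# training script, and the config referenced by args.deepspeed_config is
# assumed to contain an "optimizer" section as in the first example below.
model = MyLargeModel()
model_engine, optimizer, _, _ = deepspeed.initialize(
    args=args,
    model=model,
    model_parameters=model.parameters()
)
# DeepSpeed builds the optimizer from the config; with CPU offloading enabled
# it can use the optimized DeepSpeedCPUAdam implementation.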

Example ZeRO-3 Offload Configurations

  1. Use ZeRO to partition the optimizer states (stage 1), gradients (stage 2), and parameters (stage 3).

    {
        "zero_optimization": {
            "stage": 3,
            "overlap_comm": true
        },
        "fp16": {
            "enabled": true
        },
        "optimizer": {
            "type": "AdamW",
            "params": {
            "lr": 0.001,
            "betas": [
                0.8,
                0.999
            ],
            "eps": 1e-8,
            "weight_decay": 3e-7
            }
        },
        ...
    }
    
  2. Additionally offload the optimizer states and computations to the CPU.

    {
        "zero_optimization": {
            "stage": 3,
            "cpu_offload": true,
            "overlap_comm": true
        },
        ...
    }
    
  3. Save even more memory by offloading parameters to the CPU memory.

    {
        "zero_optimization": {
            "stage": 3,
            "cpu_offload": true,
            "cpu_offload_params": true,
            "overlap_comm": true
        },
        ...
    }
    

Assumptions

DeepSpeed automatically coordinates the collection (i.e., all-gather), partitioning (i.e., scatter), and offloading of parameters at the granularity of (sub)module forward() methods. The backward pass is handled similarly. This strategy has two underlying assumptions:

  1. The forward and backward passes of submodules must individually fit in device memory.
  2. A module’s parameters are only accessed within its own __init__ and forward() methods. Otherwise, DeepSpeed must be instructed to collect and re-partition the parameter. See Manual Parameter Coordination for manually coordinating parameters.

Constructing Massive Models

ZeRO-3 enables massive models whose parameters exceed the size of individual nodes in a system. For the typical case of training without model parallelism, you can simply allocate your model in our context:

with deepspeed.zero.Init():
    model = MyLargeModel()

class deepspeed.zero.Init(module=None, data_parallel_group=None, mem_efficient_linear=True, remote_device=None, pin_memory=False, enabled=True)

A context to enable massive model construction for training with ZeRO-3. Models are automatically partitioned (or, sharded) across the system and converted to half precision.

Parameters:
  • module (torch.nn.Module, optional) – If provided, partition the model as if it was constructed in the context.
  • data_parallel_group (torch.distributed process group, optional) – The group of processes to partition among. Defaults to all processes.
  • mem_efficient_linear (bool, optional) – Replace torch.nn.functional.linear with an implementation that allows DeepSpeed to partition parameters. Defaults to True.
  • remote_device (string, optional) – The device to store model weights. Passing "cpu" will create the model in CPU memory. The model may still be moved to GPU if cpu_offload_params is False in the config provided to deepspeed.initialize(). Defaults to the local GPU.
  • pin_memory (bool, optional) – Potentially increase performance by using pinned memory for model weights. remote_device must be "cpu". Defaults to False.
  • enabled (bool, optional) – If False, this context has no effect. Defaults to True.

This context accelerates model initialization and enables models that are too large to allocate in their entirety in CPU memory. It has the following effects:

  1. allocates tensors to either GPU or CPU memory
  2. converts floating point tensors to half precision
  3. immediately partitions tensors among the group of data-parallel devices
  4. (optional) replaces torch.nn.functional.linear with a more memory-efficient implementation

These modifications allow for models that exceed the size of local CPU/GPU memory, but fit within the total system memory (i.e., aggregate CPU or GPU memory) across all nodes. Consider initializing a model with one trillion parameters, whose weights occupy two terabytes (TB) in half precision. The initial CPU allocation in full precision requires 4TB of memory per process, and so a system with 8 GPUs per node would need 32TB of CPU memory due to data-parallel redundancies. Instead, by immediately partitioning tensors we remove the redundancies. The result is that regardless of the number of GPUs, we still only require the original 4TB. This allows for a linear increase in model size with the aggregate system memory. For example, if a node has 1TB of memory and 8 GPUs, we could fit a trillion parameter model with 4 nodes and 32 GPUs.
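
As a rough sanity check of these numbers, the arithmetic works out as follows (illustrative only; nothing here calls DeepSpeed):

# Illustrative arithmetic for the trillion-parameter example above.
params = 10**12                            # one trillion parameters
fp16_tb = 2 * params / 10**12              # 2 TB of half-precision weights
fp32_tb = 4 * params / 10**12              # 4 TB when first allocated in full precision

gpus_per_node = 8
replicated_tb = fp32_tb * gpus_per_node    # 32 TB per node with one copy per process

# With deepspeed.zero.Init(), tensors are partitioned immediately, so the
# aggregate requirement stays at 4 TB regardless of the number of processes.
cpu_memory_per_node_tb = 1
nodes_needed = fp32_tb / cpu_memory_per_node_tb   # 4 nodes (32 GPUs) suffice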

Important: If the fp16 weights of the model cannot fit into the memory of a single GPU, this feature must be used.

Note

Initializes torch.distributed if it has not already been initialized. See deepspeed.init_distributed() for more information.

Note

Can also be used as a decorator:

@deepspeed.zero.Init()
def get_model():
    return MyLargeModel()

Note

Only applicable to training with ZeRO-3.

Examples

  1. Allocate a model and partition it among all processes:

    with deepspeed.zero.Init():
        model = MyLargeModel()
    
  2. Allocate a model in pinned CPU memory and partition it among a subgroup of processes:

    with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
                             remote_device="cpu",
                             pin_memory=True):
        model = MyLargeModel()
    
  3. Partition an already-allocated model in CPU memory:

    model = deepspeed.zero.Init(module=model)
    

Manual Parameter Coordination

Most models require no modification to be trained with ZeRO-3. However, in some cases one may need to access model weights outside of the training loop, or to share weights across submodules during training. DeepSpeed has several mechanisms to coordinate partitioned weights for ZeRO-3.

Gathering Parameters

DeepSpeed provides mechanisms for collecting (or gathering) a partitioned parameter.

Some models partitioned with deepspeed.zero.Init may need to access a module’s weights outside of the class constructor or its forward() method. We refer to these weights as external parameters, since these parameters are accessed outside of the module that created them. To do so, use deepspeed.zero.GatheredParameters or deepspeed.zero.register_external_parameter().
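
For example, here is a minimal sketch of reading a partitioned weight outside of the training loop; model and layer1 are hypothetical names for a module constructed under deepspeed.zero.Init():

# Sketch only: model.layer1.weight is assumed to be a ZeRO-3 partitioned
# parameter. The full tensor is gathered inside the context and re-partitioned
# on exit; modifier_rank is not needed for read-only access.
with deepspeed.zero.GatheredParameters(model.layer1.weight):
    print(model.layer1.weight.norm())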

class deepspeed.zero.GatheredParameters(params, modifier_rank=None, fwd_module=None, enabled=True)

A context that collects parameters that were partitioned via a deepspeed.zero.Init context. The parameters are partitioned again upon exit.

Parameters:
  • params (torch.nn.Parameter) – A single parameter or a list of parameters to collect. It’s assumed that all parameters are zero params.
  • modifier_rank (int, optional) – If specified, this rank’s parameter will be broadcast on exit from the context. This argument is required if params are modified, so that all processes have a consistent view of the data. Defaults to None.
  • fwd_module (torch.nn.Module, optional) – If specified, params will be registered as external parameters of fwd_module. See deepspeed.zero.register_external_parameter().
  • enabled (bool, optional) – If False, this context is a no-op. Defaults to True.

Examples

  1. Allocate a partitioned module, initialize its weight on rank 0, and update all processes.

    with deepspeed.zero.Init():
        linear = torch.nn.Linear(1000,1000)
    
    with deepspeed.zero.GatheredParameters(linear.weight,
                                           modifier_rank=0):
        if torch.distributed.get_rank() == 0:
            linear.weight.zero_()
    
  2. Collect a partitioned weight to pass to another module during training. The parameter will be registered as an external parameter and made available during the backward pass.

    def forward(self, input):
        x = self.layer1(input)
    
        # self.layer1.weight is required by self.layer2.forward
        with deepspeed.zero.GatheredParameters(self.layer1.weight,
                                               fwd_module=self):
            y = self.layer2(x, self.layer1.weight)
        return y
    
  3. Pretrained model loading

    with deepspeed.zero.Init():
        model = MyModel()
    
    state_dict = torch.load(model_path, map_location="cpu")
    
    def load(module: nn.Module, prefix=""):
        # because zero3 puts placeholders in model params, this context
        # manager gathers (unpartitions) the params of the current layer, then loads from
        # the state dict and then re-partitions them again
        with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
            if torch.distributed.get_rank() == 0:
                module._load_from_state_dict(state_dict, prefix)
    
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")
    
    load(model, prefix="")
    

If this approach is not used, the full model will first be copied to each GPU. For models larger than the memory of a single GPU, this method is required.

Registering External Parameters

Consider the following pattern common in language models such as GPT:

class LanguageModel(torch.nn.Module):
    ...
    def forward(self, inputs):
        embeds = self.embeddings(inputs)
        ...
        logits = compute_logits(output, self.embeddings.weight)
        ...

The tensor embeddings.weight is used in both embeddings.forward() and compute_logits(). We call embeddings.weight an external parameter because it is used in the training loop outside of its owning module’s forward pass. DeepSpeed will coordinate external parameters if they are registered prior to the first forward pass.
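
For the GPT-style pattern above, such a registration could be sketched as follows; the constructor arguments and submodules are illustrative placeholders:

class LanguageModel(torch.nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.embeddings = torch.nn.Embedding(vocab_size, hidden_size)
        # ... remaining submodules ...
        # embeddings.weight is also read by compute_logits() in this module's
        # forward pass, so instruct DeepSpeed to coordinate it here.
        deepspeed.zero.register_external_parameter(self, self.embeddings.weight)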

deepspeed.zero.register_external_parameter(module, parameter)

Instruct DeepSpeed to coordinate parameter’s collection and partitioning in the forward and backward passes of module.

This is used when a parameter is accessed outside of its owning module’s forward(). DeepSpeed must know when to collect it from its partitioned state and when to release the memory.

Note

This is only applicable to training with ZeRO stage 3.

Parameters:
  • module (torch.nn.Module) – The module that requires parameter in its forward pass.
  • parameter (torch.nn.Parameter) – The parameter to register.
Raises:
  • RuntimeError – If parameter is not of type torch.nn.Parameter.

Examples

  1. Register a weight that is used in another module’s forward pass. Parameter layer1.weight is registered in the constructor of ModuleZ3 and is then used by layer2 in forward().

    class ModuleZ3(torch.nn.Module):
        def __init__(self, *args):
            super().__init__()
            self.layer1 = SomeLayer()
            self.layer2 = OtherLayer()
            deepspeed.zero.register_external_parameter(self, self.layer1.weight)
    
        def forward(self, input):
            x = self.layer1(input)
            # self.layer1.weight is required by self.layer2.forward
            y = self.layer2(x, self.layer1.weight)
            return y