# Source code for deepspeed.utils.tensor_fragment

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from dataclasses import dataclass
from deepspeed import comm as dist


@dataclass
class fragment_address:
    """Position of a contiguous fragment inside a flattened tensor.

    Elsewhere in this file the two values are handed to ``narrow`` along
    dimension 0 as ``narrow(0, start, numel)``.
    """
    # Number of elements covered by the fragment.
    numel: int
    # Offset of the fragment's first element within the flattened tensor.
    start: int


@dataclass
class tensor_fragment:
    """Links a low-precision (lp) parameter fragment to its high-precision (hp) counterpart.

    Besides the two tensor fragments and their addresses, the record carries
    the matching optimizer-state fragments and the gradient containers that
    ``get_full_hp_grad`` reads from.
    """
    # Fragment of the low-precision parameter.
    lp_fragment: torch.Tensor
    # Where ``lp_fragment`` sits inside the flattened low-precision parameter.
    lp_fragment_address: fragment_address
    # Fragment of the flat high-precision partition.
    hp_fragment: torch.Tensor
    # Where ``hp_fragment`` sits inside the flat high-precision partition.
    hp_fragment_address: fragment_address
    # Optimizer state key -> fragment of that state tensor.
    optim_fragment: dict
    # Gradient container looked up by ``param_group_index`` when ``use_offload`` is false.
    gradient_dict: dict
    # Gradient container looked up by ``param_group_index`` when ``use_offload`` is true.
    offload_gradient_dict: dict
    # Selects ``offload_gradient_dict`` over ``gradient_dict``.
    use_offload: bool
    # Key used to index into the selected gradient container.
    param_group_index: int

    def update_hp(self):
        """Overwrite the hp fragment in place with the lp fragment's values."""
        self.hp_fragment.data.copy_(self.lp_fragment.data)

    def update_lp(self):
        """Overwrite the lp fragment in place with the hp fragment's values."""
        self.lp_fragment.data.copy_(self.hp_fragment.data)

    def get_optim_state_fragment(self, key):
        """Return the optimizer state fragment stored under ``key``.

        Raises:
            ValueError: if ``key`` is not present in ``optim_fragment``.
        """
        if key not in self.optim_fragment:
            raise ValueError(f'{key} not found in optimizer state fragment')
        return self.optim_fragment[key]

    def get_hp_fragment_address(self):
        """Return the address of the hp fragment."""
        return self.hp_fragment_address

    def get_optim_state_keys(self):
        """Return the optimizer state keys available for this fragment as a list."""
        return list(self.optim_fragment.keys())


def get_full_hp_param(self, optim_state_key=None):
    """Assemble the full fp32 tensor for this parameter across ``self._dp_group``.

    Each caller starts from a zero-filled fp32 buffer shaped like ``self``,
    writes its own fragment (if it holds one) at the lp fragment address, and
    then all-reduces the buffer over the group.

    Args:
        optim_state_key: When ``None`` the hp parameter fragment is used;
            otherwise the optimizer state fragment stored under this key.

    Returns:
        The reduced fp32 buffer reshaped like ``self``.

    Raises:
        ValueError: propagated from ``get_optim_state_fragment`` when
            ``optim_state_key`` is unknown to the local mapping.
    """
    full_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()

    mapping = self._hp_mapping
    if mapping is not None:
        # Only callers holding a fragment contribute; everyone else leaves zeros.
        address = mapping.lp_fragment_address
        destination = full_buffer.narrow(0, address.start, address.numel)

        if optim_state_key is None:
            source = mapping.hp_fragment
        else:
            source = mapping.get_optim_state_fragment(optim_state_key)

        destination.data.copy_(source.data)

    # NOTE(review): relies on all_reduce's default reduction op (presumably sum) - confirm.
    dist.all_reduce(full_buffer, group=self._dp_group)
    return full_buffer.reshape_as(self)


def get_full_hp_grad(self):
    """Assemble the full fp32 gradient for this parameter across ``self._dp_group``.

    Each caller starts from a zero-filled fp32 buffer shaped like ``self``,
    writes its locally held gradient (if it has an hp mapping), and then
    all-reduces the buffer over the group.

    Returns:
        The reduced fp32 buffer reshaped like ``self``.

    Raises:
        ValueError: if the selected gradient container has no (or a ``None``)
            entry for this parameter's group.
    """
    full_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()

    mapping = self._hp_mapping
    if mapping is not None:
        # Offloaded gradients live in a separate container.
        grads_by_group = mapping.offload_gradient_dict if mapping.use_offload else mapping.gradient_dict

        group_index = mapping.param_group_index
        group_grads = grads_by_group[group_index] if group_index in grads_by_group else None
        if group_grads is None:
            raise ValueError("Gradients are only available immediately after backward and before engine step")

        local_grad = group_grads[self._index_in_param_group].to(torch.float32).flatten()

        address = mapping.lp_fragment_address
        fragment_view = full_buffer.narrow(0, address.start, address.numel)

        # A gradient as large as the whole parameter fills the entire buffer;
        # anything else is written at the fragment's address.
        holds_whole_grad = self.view(-1).shape == local_grad.shape
        destination = full_buffer if holds_whole_grad else fragment_view
        destination.data.copy_(local_grad.data)

    # NOTE(review): relies on all_reduce's default reduction op (presumably sum) - confirm.
    dist.all_reduce(full_buffer, group=self._dp_group)
    return full_buffer.reshape_as(self)


def safe_get_full_fp32_param(param):
    """Assemble and return the fp32 parameter of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter

    Returns:
        Whatever the matching ``get_full_hp_param`` call returns, or ``None``
        when ``param`` has neither a ``ds_id`` nor an ``_hp_mapping`` attribute.
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_full_hp_param(param)

    # ZeRO stage 1, 2, and bf16_optimizer params
    if hasattr(param, '_hp_mapping'):
        return param.get_full_hp_param()

    return None
def safe_get_full_optimizer_state(param, optim_state_key):
    """Assemble and return the fp32 optimizer state of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
        optim_state_key: Key identifying the optimizer state entry to assemble;
            forwarded unchanged to ``get_full_hp_param``.

    Returns:
        Whatever the matching ``get_full_hp_param`` call returns, or ``None``
        when ``param`` has neither a ``ds_id`` nor an ``_hp_mapping`` attribute.
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_full_hp_param(param, optim_state_key)

    # ZeRO stage 1, 2, and bf16_optimizer params
    if hasattr(param, '_hp_mapping'):
        return param.get_full_hp_param(optim_state_key)

    return None
# TODO: Figure out the correct return dtype
def safe_get_full_grad(param):
    """Assemble and return the fp32 gradient of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter

    Returns:
        ``param.grad`` unchanged (no dtype conversion) when it is not ``None``;
        otherwise the result of the ZeRO stage 3 or ``get_full_hp_grad`` lookup,
        or ``None`` when ``param`` has neither a ``ds_id`` nor an
        ``_hp_mapping`` attribute.
    """
    if param.grad is not None:
        return param.grad

    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_fp32_grad_for_param(param)

    # ZeRO stage 1, 2, and bf16_optimizer params
    if hasattr(param, '_hp_mapping'):
        return param.get_full_hp_grad()

    return None
def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
                            param_group_index, partition_start, partition_size, optimizer_state_dict):
    """Build the ``tensor_fragment`` describing where ``lp_param`` overlaps the local hp partition.

    The lp parameter covers ``[lp_start, lp_start + lp_param.numel())`` and the
    hp partition covers ``[partition_start, partition_start + partition_size)``
    in the same flat index space; the fragment is their intersection.

    Args:
        lp_param: Low-precision parameter tensor.
        lp_start: Offset of ``lp_param`` in the flat index space.
        flat_hp_partition: Flat high-precision partition tensor.
        gradient_dict: Stored unchanged on the returned fragment.
        offload_gradient_dict: Stored unchanged on the returned fragment.
        use_offload: Stored unchanged on the returned fragment.
        param_group_index: Stored unchanged on the returned fragment.
        partition_start: Offset of the hp partition in the flat index space.
        partition_size: Number of elements in the hp partition.
        optimizer_state_dict: Mapping of optimizer state entries; only tensor
            values with the same shape as ``flat_hp_partition`` are fragmented.

    Returns:
        A ``tensor_fragment`` holding the lp/hp fragments, their addresses, and
        the matching optimizer state fragments.

    Raises:
        AssertionError: if the lp parameter and the hp partition do not overlap.
    """
    lp_end = lp_param.numel() + lp_start
    hp_start = partition_start
    hp_end = partition_start + partition_size

    # Intersection of the lp parameter range and the hp partition range.
    fragment_start = max(lp_start, hp_start)
    fragment_end = min(lp_end, hp_end)
    assert fragment_start < fragment_end, \
        f'fragment start {fragment_start} should be < fragment_end {fragment_end}'

    fragment_numel = fragment_end - fragment_start

    # Address relative to the start of the hp partition.
    hp_frag_address = fragment_address(start=fragment_start - hp_start, numel=fragment_numel)
    hp_fragment_tensor = flat_hp_partition.narrow(0, hp_frag_address.start, hp_frag_address.numel)
    optim_fragment = {
        key: value.narrow(0, hp_frag_address.start, hp_frag_address.numel)
        for key, value in optimizer_state_dict.items()
        if torch.is_tensor(value) and value.shape == flat_hp_partition.shape
    }

    # Address relative to the start of the lp parameter.
    lp_frag_address = fragment_address(start=fragment_start - lp_start, numel=fragment_numel)
    # NOTE(review): flatten() may return a copy for a non-contiguous lp_param, in which
    # case lp_fragment would not alias lp_param - confirm callers pass contiguous tensors.
    lp_fragment_tensor = lp_param.flatten().narrow(0, lp_frag_address.start, lp_frag_address.numel)

    return tensor_fragment(lp_fragment=lp_fragment_tensor,
                           lp_fragment_address=lp_frag_address,
                           hp_fragment=hp_fragment_tensor,
                           hp_fragment_address=hp_frag_address,
                           optim_fragment=optim_fragment,
                           gradient_dict=gradient_dict,
                           offload_gradient_dict=offload_gradient_dict,
                           use_offload=use_offload,
                           param_group_index=param_group_index)


'''
Logic for lp_param to hp_param mapping

lp      lp0 lp1 lp2         lp3  lp4            <-------  indices/names
lp      [  ][  ][          ][   ][         ]    <-------- tensors
flat_lp [                                  ]     <-------- flat lp params
flat_hp            [                 ]   <------------------ flat hp partition on current rank
full_hp [                                        ] <------- full flat hp params


lp2
  full numel = 16
  lp_frag
    numel = 12
    frag_start = 3
    frag_end = 15
  hp_frag
    numel = 12
    frag_start = 0
    frag_end = 11

  hp_frag.copy_(lp_frag)


lp3:
  full numel = 4
  lp_frag
    numel = 4
    start = 0
    end = 3
  hp_frag
    numel = 4
    start = 12
    end = 15


lp4:
  full numel = 12
  lp_frag
    numel = 4
    start = 0
    end = 3
  hp_frag
    numel = 4
    start = 16
    end = 19



Visual depiction of above

lp              {         }
flat_lp [                                   ]
flat_hp               (                 )


flat_lp [       {     (   }             )   ]
                lx    hx  ly            hy
                      ly-hx


lp                          {         }
flat_lp [                                   ]
flat_hp         (                 )


flat_lp [       (           {     )   }     ]
                hx          lx    hy  ly
                            hy-lx


lp                      {     }
flat_lp [                                   ]
flat_hp         (                     )


flat_lp [       (       {     }       )     ]
                hx      lx    ly      hy
                        ly-lx

lp -> (lx, hy)
flat_hp -> (hx, hy)
'''