Model Checkpointing

DeepSpeed provides routines for checkpointing model state during training.

Loading Training Checkpoints

Saving Training Checkpoints

ZeRO Checkpoint fp32 Weights Recovery

DeepSpeed provides routines for extracting fp32 weights from the saved ZeRO checkpoint’s optimizer states.

deepspeed.utils.zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False, lazy_mode=False)[source]

Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with load_state_dict() and used for training without DeepSpeed or shared with others, for example via a model hub.

Parameters
  • checkpoint_dir (-) – path to the desired checkpoint folder

  • tag (-) – checkpoint tag used as a unique identifier for the checkpoint, e.g., global_step14. If not provided, the tag is read from the file named latest in the checkpoint folder

  • exclude_frozen_parameters (-) – exclude frozen parameters

  • lazy_mode (-) – get the state_dict in lazy mode. Returns a dict of pseudo tensors instead of torch tensors, which is more memory efficient. Convert a pseudo tensor to a torch tensor with .contiguous()

Returns

  • pytorch state_dict
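The tag fallback described in the parameter list can be sketched in plain Python. This is only an illustration of the documented behavior, not DeepSpeed's own code: resolve_checkpoint_tag is a hypothetical helper name, and the temporary folder stands in for a real checkpoint directory.

```python
import os
import tempfile
from typing import Optional


def resolve_checkpoint_tag(checkpoint_dir: str, tag: Optional[str] = None) -> str:
    """Hypothetical helper: return `tag`, or read it from the `latest` file."""
    if tag is not None:
        return tag
    latest_path = os.path.join(checkpoint_dir, "latest")
    if not os.path.isfile(latest_path):
        raise ValueError(f"Unable to find 'latest' file at {latest_path}")
    with open(latest_path, "r") as fd:
        # The file holds a single tag name such as global_step14.
        return fd.read().strip()


# Build a stand-in checkpoint folder: <checkpoint_dir>/latest and <checkpoint_dir>/global_step14/
with tempfile.TemporaryDirectory() as checkpoint_dir:
    os.makedirs(os.path.join(checkpoint_dir, "global_step14"))
    with open(os.path.join(checkpoint_dir, "latest"), "w") as fd:
        fd.write("global_step14\n")

    resolved = resolve_checkpoint_tag(checkpoint_dir)  # falls back to `latest`
    explicit = resolve_checkpoint_tag(checkpoint_dir, tag="global_step7")  # tag wins
```

Passing tag explicitly is useful when the folder holds several checkpoints and the one you want is not the most recent.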

A typical usage might be

from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
# do the training and checkpoint saving
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
model = model.cpu() # move to cpu
model.load_state_dict(state_dict)
# submit to model hub or save the model to share with others

In this example the model will no longer be usable in the DeepSpeed context of the same application, i.e., you will need to re-initialize the DeepSpeed engine, since model.load_state_dict(state_dict) will remove all the DeepSpeed magic from it.

If you want it all done for you, use load_state_dict_from_zero_checkpoint instead.

Note: the above usage may not work if your application doesn’t have sufficient free CPU memory. In that case, use the offline approach with the zero_to_fp32.py script that is saved with the checkpoint, or load the state_dict in lazy mode:

from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
for name, lazy_tensor in state_dict.items():
    tensor = lazy_tensor.contiguous()  # to cpu
    print(name, tensor)
    # del tensor to release memory if it is no longer in use

deepspeed.utils.zero_to_fp32.load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None)[source]

  1. Move the provided model to CPU

  2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict

  3. Load it into the provided model

Parameters
  • model (-) – the model object to update

  • checkpoint_dir (-) – path to the desired checkpoint folder. (one that contains the tag-folder, like global_step14)

  • tag (-) – checkpoint tag used as a unique identifier for the checkpoint, e.g., global_step14. If not provided, the tag is read from the file named latest in the checkpoint folder

Returns

modified model

Return type

  • model

Make sure you have plenty of CPU memory available before you call this function. If you don’t have enough, use the zero_to_fp32.py utility to do the conversion. It is saved alongside the checkpoint, in the checkpoint folder.

A typical usage might be

from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
# submit to model hub or save the model to share with others

Note that once this has run, the model will no longer be usable in the DeepSpeed context of the same application, i.e., you will need to re-initialize the DeepSpeed engine, since model.load_state_dict(state_dict) will remove all the DeepSpeed magic from it.

deepspeed.utils.zero_to_fp32.convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_dir, max_shard_size='5GB', safe_serialization=False, tag=None, exclude_frozen_parameters=False)[source]

Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict file that can be loaded with torch.load(file) + load_state_dict() and used for training without DeepSpeed.

Parameters
  • checkpoint_dir (-) – path to the desired checkpoint folder. (one that contains the tag-folder, like global_step14)

  • output_dir (-) – directory for the PyTorch fp32 state_dict output files

  • max_shard_size (-) – the maximum size for a checkpoint before being sharded, default value is 5GB

  • safe_serialization (-) – whether to save the model using safetensors or the traditional PyTorch way (that uses pickle).

  • tag (-) – checkpoint tag used as a unique identifier for the checkpoint, e.g., global_step14. If not provided, the tag is read from the file named latest in the checkpoint folder

  • exclude_frozen_parameters (-) – exclude frozen parameters
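To give a feel for what max_shard_size controls, the sketch below plans shards greedily from per-tensor byte counts. It is a conceptual illustration only, not the routine DeepSpeed uses: parse_size and plan_shards are hypothetical helpers, and treating "5GB" as 5 × 10^9 bytes is an assumption.

```python
from typing import Dict, List


def parse_size(size: str) -> int:
    """Hypothetical parser: '5GB' -> bytes, assuming decimal (10**9) units."""
    units = {"KB": 10**3, "MB": 10**6, "GB": 10**9}
    text = size.strip().upper()
    for suffix, factor in units.items():
        if text.endswith(suffix):
            return int(float(text[: -len(suffix)]) * factor)
    return int(text)  # plain byte count


def plan_shards(tensor_bytes: Dict[str, int], max_shard_size: str = "5GB") -> List[List[str]]:
    """Greedily group tensor names so each shard stays within the limit.

    A tensor larger than the limit still gets a shard of its own.
    """
    limit = parse_size(max_shard_size)
    shards: List[List[str]] = []
    current: List[str] = []
    used = 0
    for name, nbytes in tensor_bytes.items():
        if current and used + nbytes > limit:
            shards.append(current)  # close the shard that is full
            current, used = [], 0
        current.append(name)
        used += nbytes
    if current:
        shards.append(current)
    return shards


# Illustrative sizes only (names and byte counts are made up).
sizes = {"embed.weight": 3 * 10**9, "layer0.weight": 3 * 10**9, "head.weight": 1 * 10**9}
plan = plan_shards(sizes, max_shard_size="5GB")
```

With these made-up sizes the first two tensors cannot share a 5GB shard, so the plan holds two shards; a larger limit would keep everything in one file.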

Avoiding ZeRO Checkpoint Bloat

ZeRO stage 1 and 2 checkpoints created using torch.save() can sometimes be larger than expected. This bloat is caused by the interaction of ZeRO’s tensor flattening and torch’s tensor storage management. You can avoid this problem by using the clone_tensors_for_torch_save utility of DeepSpeed as illustrated below.

deepspeed.checkpoint.utils.clone_tensors_for_torch_save(item, device=device(type='cpu'))[source]

Returns a copy of item with all enclosed tensors replaced by clones on a specified device. Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts.

Parameters
  • item (-) – tensor to clone or (possibly nested) container of tensors to clone.

  • device (-) – target device (defaults to ‘cpu’)

Returns

  • copy of item with cloned tensors on target device

The following code snippet illustrates this functionality for creating a HuggingFace model checkpoint:

import deepspeed
import torch
from transformers import AutoModelForCausalLM

ds_config = {
 ...
}
model = AutoModelForCausalLM.from_pretrained("facebook/opt-13b", torch_dtype=torch.float16)
ds_engine, _, _, _ = deepspeed.initialize(model=model, config_params=ds_config)
# Clone the tensors so each one owns its storage before it is serialized.
lean_state_dict = deepspeed.checkpoint.utils.clone_tensors_for_torch_save(ds_engine.module.state_dict())
ds_engine.module.save_pretrained("lean_after", state_dict=lean_state_dict)
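
The storage effect behind the bloat can be reproduced with plain PyTorch, without DeepSpeed. In the sketch below a large flat tensor stands in for ZeRO's flattened buffer (an assumption made for illustration), and serialized_size is a hypothetical helper. Saving a small view writes the whole underlying storage, while saving a clone writes only the view's own elements.

```python
import io

import torch


def serialized_size(obj) -> int:
    """Hypothetical helper: number of bytes torch.save() writes for obj."""
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    return buffer.getbuffer().nbytes


# One large flat buffer, standing in for a flattened ZeRO tensor (about 4 MB of fp32).
flat_buffer = torch.zeros(1_000_000)
# A small "parameter" that is only a view into that buffer.
weight_view = flat_buffer[:10]

# The view shares storage with flat_buffer, so the full storage is serialized.
bloated_bytes = serialized_size({"weight": weight_view})
# The clone owns a fresh 10-element storage, so only that is serialized.
lean_bytes = serialized_size({"weight": weight_view.clone()})

print(f"view: {bloated_bytes} bytes, clone: {lean_bytes} bytes")
```

clone_tensors_for_torch_save applies the same idea to every tensor in a nested container, which is why the resulting state_dict serializes to a lean file.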

Universal Checkpoints (under development)

Parallelism techniques such as ZeRO data parallelism (DP), tensor parallelism (TP), and pipeline parallelism (PP), which shard model and/or optimizer states, make it difficult to resume training with a checkpoint that was created on a different number of GPUs. DeepSpeed provides the Universal Checkpoint mechanism to address this problem. Universal Checkpoints give users the flexibility of changing the number of GPUs when training with 3D (TP, PP, and DP) parallelism, and enable more efficient use of elastic training hardware. The easiest way to get started with Universal Checkpoints is to consult the Megatron-DeepSpeed and BLOOM examples.
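
Why a changed GPU count is a problem can be seen in a toy model of sharded state. The sketch below is a conceptual illustration in plain Python and does not reflect the Universal Checkpoint format or any DeepSpeed API: shard, consolidate, and reshard are hypothetical helpers, and a list of floats stands in for a flattened parameter buffer.

```python
from typing import List


def shard(flat: List[float], world_size: int) -> List[List[float]]:
    """Split a flat buffer into world_size contiguous, near-equal partitions."""
    base, extra = divmod(len(flat), world_size)
    partitions, start = [], 0
    for rank in range(world_size):
        size = base + (1 if rank < extra else 0)
        partitions.append(flat[start:start + size])
        start += size
    return partitions


def consolidate(partitions: List[List[float]]) -> List[float]:
    """Concatenate per-rank partitions back into one flat buffer."""
    return [value for part in partitions for value in part]


def reshard(partitions: List[List[float]], new_world_size: int) -> List[List[float]]:
    """Re-partition saved shards for a different number of ranks."""
    return shard(consolidate(partitions), new_world_size)


params = [float(i) for i in range(10)]
saved_on_4 = shard(params, 4)         # what 4 ranks would each have written
resumed_on_3 = reshard(saved_on_4, 3) # what 3 ranks need in order to resume
```

The per-rank files written by 4 ranks do not line up with the partitions 3 ranks expect, so resuming needs an intermediate, layout-independent representation; providing that is the role the text assigns to Universal Checkpoints.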