Model Checkpointing
DeepSpeed provides routines for checkpointing model state during training.
Loading Training Checkpoints
- deepspeed.DeepSpeedEngine.load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True, load_module_only=False, custom_load_fn=None)
Load training checkpoint
- Parameters
load_dir – Required. Directory to load the checkpoint from
tag – Optional. Checkpoint tag used as a unique identifier for the checkpoint. If not provided, will attempt to load the tag from the ‘latest’ file.
load_module_strict – Optional. Boolean to strictly enforce that the keys in the module’s state_dict and the checkpoint match.
load_optimizer_states – Optional. Boolean to load the training optimizer states from the checkpoint, e.g., Adam’s momentum and variance.
load_lr_scheduler_states – Optional. Boolean to load the learning rate scheduler states from the checkpoint.
load_module_only – Optional. Boolean to load only the model weights from the checkpoint, e.g., for warm-starting.
custom_load_fn – Optional. Custom model load function.
- Returns
A tuple of load_path and client_state.
load_path: Path of the loaded checkpoint. None if loading the checkpoint failed.
client_state: State dictionary used for loading required training states in the client code.
Important: under ZeRO-3, a checkpoint cannot be loaded with engine.load_checkpoint() immediately after engine.save_checkpoint(), because engine.module is partitioned and load_checkpoint() expects a pristine model. If you must do so, reinitialize the engine before calling load_checkpoint().
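A minimal sketch of resuming training, assuming an engine created with deepspeed.initialize(); the checkpoint directory and the “step” entry in client_state are illustrative placeholders, not part of the API:
import deepspeed

model_engine, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=model, model_parameters=model.parameters(), config=ds_config)

# restore module, optimizer, and lr scheduler states from the latest checkpoint
load_path, client_state = model_engine.load_checkpoint("./checkpoints")
if load_path is None:
    start_step = 0  # no checkpoint found, start from scratch
else:
    start_step = client_state.get("step", 0)  # state stored via save_checkpoint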
Saving Training Checkpoints
- deepspeed.DeepSpeedEngine.save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True, exclude_frozen_parameters=False)
Save training checkpoint
- Parameters
save_dir – Required. Directory for saving the checkpoint
tag – Optional. Checkpoint tag used as a unique identifier for the checkpoint, global step is used if not provided. Tag name must be the same across all ranks.
client_state – Optional. State dictionary used for saving required training states in the client code.
save_latest – Optional. Save a file ‘latest’ pointing to the latest saved checkpoint.
exclude_frozen_parameters – Optional. Exclude frozen parameters from checkpointed state.
Important: all processes must call this method, not just the process with rank 0, because each process needs to save its master weights and scheduler+optimizer states. This method will hang, waiting to synchronize with the other processes, if it is called only on the rank 0 process.
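A minimal sketch of a periodic save in which every rank executes the call; the directory, tag format, save interval, and contents of client_state are illustrative placeholders:
# all ranks must reach this call, not just rank 0
if step % save_interval == 0:
    client_state = {"step": step}  # arbitrary training state to restore later
    model_engine.save_checkpoint("./checkpoints",
                                 tag=f"global_step{step}",
                                 client_state=client_state)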
ZeRO Checkpoint fp32 Weights Recovery
DeepSpeed provides routines for extracting fp32 weights from the saved ZeRO checkpoint’s optimizer states.
- deepspeed.utils.zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None)[source]
Convert a ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with load_state_dict() and used for training without DeepSpeed or shared with others, for example via a model hub.
- Parameters
checkpoint_dir – path to the desired checkpoint folder
tag – checkpoint tag used as a unique identifier for the checkpoint. If not provided, will attempt to load the tag from the ‘latest’ file, e.g., global_step14
- Returns
pytorch state_dict
Note: this approach may not work if your application doesn’t have sufficient free CPU memory; in that case you may need to use the offline approach with the zero_to_fp32.py script that is saved with the checkpoint.
A typical usage might be:
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
# do the training and checkpoint saving
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir)  # already on cpu
model = model.cpu()  # move to cpu
model.load_state_dict(state_dict)
# submit to model hub or save the model to share with others
In this example the model will no longer be usable in the deepspeed context of the same application, i.e., you will need to re-initialize the deepspeed engine, since model.load_state_dict(state_dict) will remove all the deepspeed magic from it.
If you want it all done for you, use load_state_dict_from_zero_checkpoint instead.
- deepspeed.utils.zero_to_fp32.load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None)[source]
1. Put the provided model to cpu
2. Convert a ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict
3. Load it into the provided model
- Parameters
model – the model object to update
checkpoint_dir – path to the desired checkpoint folder (one that contains the tag-folder, like global_step14)
tag – checkpoint tag used as a unique identifier for the checkpoint. If not provided, will attempt to load the tag from the file named latest in the checkpoint folder, e.g., global_step14
- Returns
modified model
- Return type
model
Make sure you have plenty of CPU memory available before you call this function. If you don’t have enough, use the zero_to_fp32.py utility to do the conversion. You will find it conveniently placed for you in the checkpoint folder.
A typical usage might be:
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
# submit to model hub or save the model to share with others
Note that once this has been run, the model will no longer be usable in the deepspeed context of the same application, i.e., you will need to re-initialize the deepspeed engine, since model.load_state_dict(state_dict) will remove all the deepspeed magic from it.
- deepspeed.utils.zero_to_fp32.convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None)[source]
Convert a ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict file that can be loaded with torch.load(file) + load_state_dict() and used for training without DeepSpeed.
- Parameters
checkpoint_dir – path to the desired checkpoint folder (one that contains the tag-folder, like global_step14)
output_file – path to the pytorch fp32 state_dict output file (e.g., path/pytorch_model.bin)
tag – checkpoint tag used as a unique identifier for the checkpoint. If not provided, will attempt to load the tag from the file named latest in the checkpoint folder, e.g., global_step14
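A minimal usage sketch; the checkpoint directory and output file path are illustrative placeholders:
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

# consolidate the sharded ZeRO optimizer states into a single fp32 file
convert_zero_checkpoint_to_fp32_state_dict("./checkpoints", "./checkpoints/pytorch_model.bin")
# the resulting file can later be loaded without DeepSpeed:
#   state_dict = torch.load("./checkpoints/pytorch_model.bin")
#   model.load_state_dict(state_dict)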
Avoiding ZeRO Checkpoint Bloat
ZeRO stage 1 and 2 checkpoints created using torch.save() can sometimes be larger than expected. This bloat is caused by the interaction of ZeRO’s tensor flattening and torch’s tensor storage management. You can avoid this problem by using the clone_tensors_for_torch_save utility of DeepSpeed, as illustrated below.
- deepspeed.checkpoint.utils.clone_tensors_for_torch_save(item, device=device(type='cpu'))[source]
Returns a copy of item with all enclosed tensors replaced by clones on a specified device. Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts.
- Parameters
item – tensor to clone or (possibly nested) container of tensors to clone.
device – target device (defaults to ‘cpu’)
- Returns
copy of item with cloned tensors on target device
The following code snippet illustrates this functionality for creating a HuggingFace model checkpoint:
import torch
import deepspeed
from transformers import AutoModelForCausalLM

ds_config = {
...
}
model = AutoModelForCausalLM.from_pretrained("facebook/opt-13b", torch_dtype=torch.float16)
ds_engine, _, _, _ = deepspeed.initialize(model=model, config_params=ds_config)
# clone the (possibly shared-storage) tensors so only the lean copies are serialized
lean_state_dict = deepspeed.checkpoint.utils.clone_tensors_for_torch_save(ds_engine.module.state_dict())
ds_engine.module.save_pretrained("lean_after", state_dict=lean_state_dict)
Universal Checkpoints (under development)
Parallelism techniques such as ZeRO data parallelism (DP), tensor parallelism (TP), and pipeline parallelism (PP), which shard model and/or optimizer states, make it difficult to resume training with a checkpoint that was created on a different number of GPUs. DeepSpeed provides the Universal Checkpoint mechanism to address this problem. Universal Checkpoints give users the flexibility of changing the number of GPUs when training with 3D (TP, PP, and DP) parallelism, and enable more efficient use of elastic training hardware. The easiest way to get started with Universal Checkpoints is to consult the Megatron-DeepSpeed and BLOOM examples.