Model Checkpointing

DeepSpeed provides routines for checkpointing model state during training.

Loading Training Checkpoints

deepspeed.DeepSpeedEngine.load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True, load_module_only=False)

Load training checkpoint

Parameters:
  • load_dir – Required. Directory to load the checkpoint from.
  • tag – Optional. Checkpoint tag used as a unique identifier for the checkpoint. If not provided, the tag is read from the ‘latest’ file in load_dir.
  • load_module_strict – Optional. Boolean to strictly enforce that the keys in the state_dict of the module and the checkpoint match.
  • load_optimizer_states – Optional. Boolean to load the training optimizer states from the checkpoint, e.g., Adam’s momentum and variance.
  • load_lr_scheduler_states – Optional. Boolean to load the learning rate scheduler states from the checkpoint.
  • load_module_only – Optional. Boolean to load only the model weights from the checkpoint, e.g., for warm-starting.
Returns:

A tuple of load_path and client_state.

  • load_path: Path of the loaded checkpoint. None if loading the checkpoint failed.
  • client_state: State dictionary used for loading required training states in the client code.
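As a sketch, resuming training might look like the helper below. The helper name and the ‘step’ key in client_state are illustrative conventions, not part of the DeepSpeed API; model_engine is assumed to be the engine returned by deepspeed.initialize().

```python
def resume_from_checkpoint(model_engine, load_dir):
    """Resume training state from load_dir, falling back to step 0.

    With no tag given, load_checkpoint() reads the tag from the
    'latest' file in load_dir.
    """
    load_path, client_state = model_engine.load_checkpoint(load_dir)
    if load_path is None:
        # Loading failed (e.g., no checkpoint present): start from scratch.
        return 0
    # 'step' is whatever the client code chose to store at save time.
    return client_state.get("step", 0)
```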

Saving Training Checkpoints

deepspeed.DeepSpeedEngine.save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True)

Save training checkpoint

Parameters:
  • save_dir – Required. Directory for saving the checkpoint.
  • tag – Optional. Checkpoint tag used as a unique identifier for the checkpoint; the global step is used if not provided. The tag name must be the same across all ranks.
  • client_state – Optional. State dictionary used for saving required training states in the client code.
  • save_latest – Optional. Save a file ‘latest’ pointing to the latest saved checkpoint.

Important: all processes must call this method, not just the process with rank 0, because each process needs to save its master weights and scheduler+optimizer states. This method will hang waiting to synchronize with the other processes if it is called only on the rank 0 process.
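A minimal sketch of the all-ranks requirement inside a training loop (the helper name, the checkpoint interval, and the ‘step’ entry in client_state are illustrative assumptions, not part of the DeepSpeed API):

```python
def checkpoint_every(model_engine, save_dir, step, interval=1000):
    """Call from EVERY rank inside the training loop.

    save_checkpoint() synchronizes across ranks, so guarding it with a
    rank-0 check would deadlock. The guard below depends only on the
    step counter, which is identical on all ranks.
    """
    if step % interval == 0:
        # Every rank makes this call; the tag defaults to the global step.
        model_engine.save_checkpoint(save_dir, client_state={"step": step})
```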

ZeRO Checkpoint fp32 Weights Recovery

DeepSpeed provides routines for extracting fp32 weights from the saved ZeRO checkpoint’s optimizer states.

deepspeed.utils.zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None)[source]

Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with load_state_dict() and used for training without DeepSpeed or shared with others, for example via a model hub.

Parameters:
  • checkpoint_dir – path to the desired checkpoint folder
  • tag – checkpoint tag used as a unique identifier for the checkpoint, e.g., global_step14. If not provided, the tag is read from the ‘latest’ file.
Returns:

  • pytorch state_dict

Note: this approach may not work if your application doesn’t have sufficient free CPU memory. In that case, use the offline approach via the zero_to_fp32.py script that is saved with the checkpoint.

A typical usage might be

from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
# do the training and checkpoint saving
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
model = model.cpu() # move to cpu
model.load_state_dict(state_dict)
# submit to model hub or save the model to share with others

In this example the model will no longer be usable in the DeepSpeed context of the same application; i.e., you will need to re-initialize the DeepSpeed engine, since model.load_state_dict(state_dict) removes all the DeepSpeed magic from it.

If you want it all done for you, use load_state_dict_from_zero_checkpoint instead.

deepspeed.utils.zero_to_fp32.load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None)[source]
  1. Put the provided model to cpu
  2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict
  3. Load it into the provided model
Parameters:
  • model – the model object to update
  • checkpoint_dir – path to the desired checkpoint folder (one that contains the tag folder, e.g., global_step14)
  • tag – checkpoint tag used as a unique identifier for the checkpoint. If not provided, the tag is read from the file named ‘latest’ in the checkpoint folder, e.g., global_step14
Returns:

modified model

Return type:

  • model

Make sure you have plenty of CPU memory available before you call this function. If you don’t have enough, use the zero_to_fp32.py utility to do the conversion offline. You will find it conveniently placed for you in the checkpoint folder.

A typical usage might be

from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
# submit to model hub or save the model to share with others

Note that once this has been run, the model will no longer be usable in the DeepSpeed context of the same application; i.e., you will need to re-initialize the DeepSpeed engine, since model.load_state_dict(state_dict) removes all the DeepSpeed magic from it.

deepspeed.utils.zero_to_fp32.convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None)[source]

Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict file that can be loaded with torch.load(file) + load_state_dict() and used for training without DeepSpeed.

Parameters:
  • checkpoint_dir – path to the desired checkpoint folder (one that contains the tag folder, e.g., global_step14)
  • output_file – path to the pytorch fp32 state_dict output file (e.g., path/pytorch_model.bin)
  • tag – checkpoint tag used as a unique identifier for the checkpoint. If not provided, the tag is read from the file named ‘latest’ in the checkpoint folder, e.g., global_step14
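A typical usage might look like the sketch below. The helper name and the optional reload step are illustrative assumptions; the imports are deferred inside the function so it can be defined without DeepSpeed installed.

```python
def export_fp32_model(checkpoint_dir, output_file, model=None):
    """Consolidate a ZeRO checkpoint into a single fp32 file on disk.

    Unlike get_fp32_state_dict_from_zero_checkpoint(), the result is
    written to output_file rather than held in memory, so it can be
    shared or uploaded to a model hub directly.
    """
    from deepspeed.utils.zero_to_fp32 import (
        convert_zero_checkpoint_to_fp32_state_dict,
    )
    convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file)
    if model is not None:
        # Optionally load the consolidated weights back into a plain
        # (non-DeepSpeed) model for use without DeepSpeed.
        import torch
        model.load_state_dict(torch.load(output_file, map_location="cpu"))
    return output_file
```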