Training API

deepspeed.initialize() returns a tuple whose first element is the training engine, of type DeepSpeedEngine. This engine drives training:

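# model_engine is the first element returned by deepspeed.initialize(), e.g.
# model_engine, optimizer, _, _ = deepspeed.initialize(args=args, model=model,
#                                                      model_parameters=model.parameters())
for step, batch in enumerate(data_loader):
    # forward() method
    loss = model_engine(batch)

    # runs backpropagation
    model_engine.backward(loss)

    # weight update
    model_engine.step()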

Forward Propagation

deepspeed.DeepSpeedEngine.forward(*args, **kwargs)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note: although the forward pass is defined in this method, call the engine instance itself (for example, loss = model_engine(batch)) rather than calling forward() directly. Calling the instance runs any registered hooks; calling forward() directly silently skips them.
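The difference between calling the instance and calling forward() directly can be illustrated with a toy class. This is a hypothetical, dependency-free sketch that mimics how torch.nn.Module dispatches through __call__; ToyModule and its methods are invented for illustration and are not DeepSpeed or PyTorch APIs.

```python
from typing import Callable, List


class ToyModule:
    """Hypothetical stand-in for torch.nn.Module's hook dispatch (illustration only)."""

    def __init__(self) -> None:
        self._forward_hooks: List[Callable] = []

    def register_forward_hook(self, hook: Callable) -> None:
        # Hooks receive (module, inputs, output), mirroring PyTorch's hook arguments.
        self._forward_hooks.append(hook)

    def forward(self, x: float) -> float:
        # The "recipe" for the forward pass lives here.
        return 2 * x

    def __call__(self, x: float) -> float:
        # Calling the instance runs forward() and then every registered hook.
        output = self.forward(x)
        for hook in self._forward_hooks:
            hook(self, (x,), output)
        return output


recorded = []
module = ToyModule()
module.register_forward_hook(lambda mod, inputs, out: recorded.append(out))

module(3.0)          # goes through __call__, so the hook records 6.0
module.forward(3.0)  # bypasses __call__, so the hook does not run
```

In the real engine, the same reasoning is why the training loop writes model_engine(batch) rather than model_engine.forward(batch).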

Backward Propagation

deepspeed.DeepSpeedEngine.backward(*args, **kwargs)

Execute the backward pass on the loss returned by the forward pass.

Optimizer Step

deepspeed.DeepSpeedEngine.step(self, lr_kwargs=None)

Execute the weight update step after forward and backward propagation on effective_train_batch.
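Assuming effective_train_batch corresponds to DeepSpeed's train_batch_size configuration setting, one step() consumes train_micro_batch_size_per_gpu × gradient_accumulation_steps × the number of GPUs (data-parallel ranks) samples. The sketch below checks that relationship for a config written as a Python dict; effective_train_batch() and check_batch_config() are hypothetical helpers, not DeepSpeed functions, and the numbers are placeholders.

```python
def effective_train_batch(micro_batch_per_gpu: int,
                          gradient_accumulation_steps: int,
                          world_size: int) -> int:
    """Samples consumed by one optimizer step (hypothetical helper)."""
    return micro_batch_per_gpu * gradient_accumulation_steps * world_size


def check_batch_config(config: dict, world_size: int) -> None:
    """Raise if train_batch_size disagrees with its three factors (hypothetical helper)."""
    expected = effective_train_batch(
        config["train_micro_batch_size_per_gpu"],
        config["gradient_accumulation_steps"],
        world_size,
    )
    if config["train_batch_size"] != expected:
        raise ValueError(
            f"train_batch_size={config['train_batch_size']} but "
            f"micro_batch * grad_accum * world_size = {expected}"
        )


ds_config = {
    "train_batch_size": 64,
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 2,
}
check_batch_config(ds_config, world_size=8)  # 4 * 2 * 8 == 64, so no error
```

With 8 GPUs each processing 4 samples per micro-batch and 2 accumulation steps, each call to step() that updates the weights has seen 64 samples.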

Gradient Accumulation


Query whether the current micro-batch is at the boundary of gradient accumulation, and thus will trigger gradient reductions and an optimizer step.


Returns: True if the current step is a gradient accumulation boundary, otherwise False (return type: bool).
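The boundary logic can be sketched as a counter over micro-steps. This sketch assumes the boundary falls on every gradient_accumulation_steps-th micro-batch, i.e. (micro_steps + 1) % gradient_accumulation_steps == 0, and that step() advances the micro-step counter on every call but only updates weights at a boundary. ToyAccumulator is hypothetical and is not DeepSpeed's class.

```python
class ToyAccumulator:
    """Hypothetical model of micro-step bookkeeping (not DeepSpeed's implementation)."""

    def __init__(self, gradient_accumulation_steps: int) -> None:
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.micro_steps = 0      # micro-batches processed so far
        self.optimizer_steps = 0  # weight updates actually applied

    def is_gradient_accumulation_boundary(self) -> bool:
        # True when the current micro-batch is the last one of its accumulation window.
        return (self.micro_steps + 1) % self.gradient_accumulation_steps == 0

    def step(self) -> None:
        # Gradients are reduced and weights updated only at a boundary;
        # otherwise step() only advances the micro-step counter.
        if self.is_gradient_accumulation_boundary():
            self.optimizer_steps += 1
        self.micro_steps += 1


acc = ToyAccumulator(gradient_accumulation_steps=4)
boundaries = []
for _ in range(10):
    boundaries.append(acc.is_gradient_accumulation_boundary())
    acc.step()
```

After 10 micro-batches with 4 accumulation steps, the boundary is reached at the 4th and 8th micro-batches, so only 2 weight updates occur.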


Model Saving

deepspeed.DeepSpeedEngine.save_16bit_model(self, save_dir, save_filename='pytorch_model.bin', exclude_frozen_parameters=False)

Save 16-bit model weights.

This method saves the 16-bit model weights to the given directory.

Parameters:

  • save_dir – Required. Directory for saving the model.

  • save_filename – Optional. Filename to save to. Defaults to pytorch_model.bin.

  • exclude_frozen_parameters – Optional. Exclude frozen parameters from the checkpointed state.

Returns: True when the model has been saved, False otherwise. When ZeRO stage 3 is used, the model is not saved if stage3_gather_16bit_weights_on_model_save is False.
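A sketch of a DeepSpeed config, built as a Python dict and serialized to JSON, that enables stage3_gather_16bit_weights_on_model_save under ZeRO stage 3 so that save_16bit_model() has consolidated weights to write. The batch-size and bf16 values are placeholders chosen for illustration.

```python
import json

ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        # Gather the partitioned 16-bit weights on save; without this,
        # save_16bit_model() returns False under ZeRO stage 3.
        "stage3_gather_16bit_weights_on_model_save": True,
    },
}

# JSON text suitable for a DeepSpeed config file.
config_text = json.dumps(ds_config, indent=2)
print(config_text)
```

Gathering the full weights costs extra memory and communication at save time, which is presumably why the flag is opt-in.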

Important: all processes must call this method, not just the process with rank 0, because the processes must work in sync to gather the weights. If only rank 0 calls it, the method hangs waiting to synchronize with the other processes.
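Why a rank-0-only call hangs can be illustrated with an analogy: gathering sharded weights is a collective operation, much like a barrier that every participant must reach. The sketch below uses Python threads and threading.Barrier as stand-ins for processes and collective communication; it is not how DeepSpeed is implemented, save_16bit_model_sim() is a hypothetical name, and a timeout replaces the indefinite hang.

```python
import threading

WORLD_SIZE = 4


def run(ranks_that_call_save, timeout=0.5):
    """Return the set of ranks whose simulated save completed (hypothetical)."""
    barrier = threading.Barrier(WORLD_SIZE, timeout=timeout)
    finished = set()
    lock = threading.Lock()

    def save_16bit_model_sim(rank):
        try:
            # Every rank must arrive before the weights can be gathered.
            barrier.wait()
            with lock:
                finished.add(rank)
        except threading.BrokenBarrierError:
            pass  # the collective never completed; a real job would hang here

    threads = [threading.Thread(target=save_16bit_model_sim, args=(r,))
               for r in ranks_that_call_save]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return finished


print(sorted(run(range(WORLD_SIZE))))  # all ranks call: every save completes
print(sorted(run([0])))                # only rank 0 calls: nothing completes
```

In practice this means the call belongs outside any `if rank == 0:` guard; the method itself decides which process writes the file.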

Additionally, when a DeepSpeed checkpoint is created, a script is added to it that can reconstruct the fp32 master weights into a single PyTorch state_dict file.