Training API

deepspeed.initialize() returns a training engine of type DeepSpeedEngine as its first return value. This engine drives the training loop:

for step, batch in enumerate(data_loader):
    # forward propagation
    loss = model_engine(batch)

    # backward propagation
    model_engine.backward(loss)

    # weight update
    model_engine.step()
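
For context, the engine used above would typically be obtained as sketched below; the placeholder model and the ds_config.json path are illustrative assumptions, not part of this API reference:

import torch
import deepspeed

model = torch.nn.Linear(10, 10)  # placeholder model for illustration only

# deepspeed.initialize returns (engine, optimizer, training_dataloader, lr_scheduler);
# the engine is the DeepSpeedEngine driven in the loop above.
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="ds_config.json",  # assumed DeepSpeed config file
)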

Forward Propagation

deepspeed.DeepSpeedEngine.forward(self, *inputs, **kwargs)

Execute forward propagation

Parameters:
  • *inputs – Variable length input list
  • **kwargs – variable length keyword arguments
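
Positional and keyword arguments are passed through to the wrapped module's forward, so the engine is called exactly like the underlying model. A minimal sketch, assuming a model whose forward accepts input_ids and labels keywords:

# The keyword arguments here are assumptions about the wrapped model's forward().
outputs = model_engine(input_ids=batch["input_ids"], labels=batch["labels"])
loss = outputs.loss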

Backward Propagation

deepspeed.DeepSpeedEngine.backward(self, loss, allreduce_gradients=True, release_loss=False)

Execute backward pass on the loss

Parameters:
  • loss – Torch tensor on which to execute backward propagation
  • allreduce_gradients – deprecated and ignored; it will be removed in a future release
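
The engine's backward replaces a direct call to loss.backward(), since loss scaling and gradient handling for mixed precision and ZeRO happen inside this call. A minimal sketch:

loss = model_engine(batch)
# Run backpropagation through the engine rather than calling loss.backward()
# directly, so that loss scaling and gradient reduction are handled internally.
model_engine.backward(loss)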

Optimizer Step

deepspeed.DeepSpeedEngine.step(self, lr_kwargs=None)

Execute the weight update step after forward and backward propagation on effective_train_batch.
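
step() is called once per micro-batch; the engine decides internally whether the current micro-batch completes an effective batch and therefore applies the optimizer and learning rate scheduler update. A sketch of the tail of one micro-batch (note that the loop above has no explicit optimizer.zero_grad(), as the engine manages gradient zeroing itself):

model_engine.backward(loss)
# Applies the weight update only at a gradient accumulation boundary;
# otherwise it advances the engine's micro-step bookkeeping and returns.
model_engine.step()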

Gradient Accumulation

deepspeed.DeepSpeedEngine.is_gradient_accumulation_boundary(self)

Query whether the current micro-batch is at the boundary of gradient accumulation, and thus will trigger gradient reductions and an optimizer step.

Returns: True if the current step is a gradient accumulation boundary, False otherwise.
Return type: bool
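
A common use is to gate work that should happen once per effective batch, such as logging, on this query. A minimal sketch, where log_metrics is a hypothetical helper:

for step, batch in enumerate(data_loader):
    loss = model_engine(batch)
    model_engine.backward(loss)

    # Only the micro-batch that closes an effective batch triggers gradient
    # reduction and a weight update, so report metrics once per boundary.
    if model_engine.is_gradient_accumulation_boundary():
        log_metrics(step, loss.item())  # hypothetical helper

    model_engine.step()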

Model Saving

deepspeed.DeepSpeedEngine.save_fp16_model(self, save_dir, save_filename='pytorch_model.bin')

Save fp16 model weights

This method saves the fp16 model weights at the desired destination.

Parameters:
  • save_dir – Required. Directory for saving the model
  • save_filename – Optional. Filename to save to. Defaults to pytorch_model.bin

Important: all processes must call this method, not just the process with rank 0, because the processes need to synchronize to gather the weights. If it is called only on the rank 0 process, it will hang while waiting for the other processes.
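
A sketch of the collective call; the save directory is an illustrative assumption:

# Executed by every rank: all processes gather the weights together,
# and the resulting file is written once.
model_engine.save_fp16_model("checkpoints/fp16", save_filename="pytorch_model.bin")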

Additionally, when a DeepSpeed checkpoint is created, a zero_to_fp32.py script is added to the checkpoint directory; it can be used to reconstruct the fp32 master weights into a single PyTorch state_dict file.
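
Besides running that script standalone, the same reconstruction can be done in Python; a sketch assuming the get_fp32_state_dict_from_zero_checkpoint helper from deepspeed.utils.zero_to_fp32 is available in your DeepSpeed version, with an illustrative checkpoint path:

import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# Gather the ZeRO-partitioned fp32 master weights from a checkpoint directory
# (illustrative path) and write them out as a single state_dict file.
state_dict = get_fp32_state_dict_from_zero_checkpoint("checkpoints/my_run")
torch.save(state_dict, "pytorch_model_fp32.bin")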