Training API

deepspeed.initialize() returns a training engine of type DeepSpeedEngine as the first element of its return value. This engine is used to drive training:

for step, batch in enumerate(data_loader):
    # forward() method
    loss = model_engine(batch)

    # runs backpropagation
    model_engine.backward(loss)

    # weight update
    model_engine.step()
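
To make the forward / backward / step calling pattern concrete without requiring a GPU or a DeepSpeed install, here is a minimal pure-Python sketch. ToyEngine is a hypothetical stand-in written for this illustration; it is not DeepSpeed's implementation. It assumes an engine that accumulates scaled gradients in backward() and applies the weight update in step() only once every accumulation_steps micro-batches. The model, data, and learning rate are made up.

```python
class ToyEngine:
    """Hypothetical stand-in for a training engine (NOT DeepSpeedEngine).

    Fits y = weight * x with mean squared error and plain SGD.
    """

    def __init__(self, weight, lr, accumulation_steps):
        self.weight = weight
        self.lr = lr
        self.accumulation_steps = accumulation_steps
        self.grad = 0.0        # accumulated gradient
        self.micro_steps = 0   # number of step() calls so far
        self.updates = 0       # number of actual weight updates
        self._batch = None     # last batch seen by the forward pass

    def __call__(self, batch):
        # Forward pass: mean squared error over the micro-batch.
        xs, ys = batch
        self._batch = batch
        return sum((self.weight * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

    def backward(self, loss):
        # The toy computes the gradient analytically, so `loss` is unused.
        # The gradient is scaled by 1/accumulation_steps and accumulated.
        xs, ys = self._batch
        g = sum(2 * (self.weight * x - y) * x for x, y in zip(xs, ys)) / len(xs)
        self.grad += g / self.accumulation_steps

    def step(self):
        # Update the weight only at an accumulation boundary.
        self.micro_steps += 1
        if self.micro_steps % self.accumulation_steps == 0:
            self.weight -= self.lr * self.grad
            self.grad = 0.0
            self.updates += 1


# Eight identical micro-batches drawn from y = 2x (made-up data).
data_loader = [([1.0, 2.0], [2.0, 4.0])] * 8
engine = ToyEngine(weight=0.0, lr=0.1, accumulation_steps=4)

losses = []
for step, batch in enumerate(data_loader):
    loss = engine(batch)      # forward
    engine.backward(loss)     # accumulate gradients
    engine.step()             # updates weights on every 4th call only
    losses.append(loss)
```

The training loop has the same shape as the one above: the caller invokes all three methods on every micro-batch, and the engine decides internally when an update actually happens. Here eight micro-batches produce two weight updates.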

Forward Propagation

Backward Propagation

Optimizer Step

Gradient Accumulation

Model Saving

Additionally, when a DeepSpeed checkpoint is created, a script is placed in the checkpoint directory that can be used to reconstruct the fp32 master weights into a single PyTorch state_dict file.