Pipeline Parallelism

Model Specification

class deepspeed.pipe.PipelineModule(layers, num_stages=None, topology=None, loss_fn=None, seed_layers=False, seed_fn=None, base_seed=1234, partition_method='parameters', activation_checkpoint_interval=0, activation_checkpoint_func=<function checkpoint>, checkpointable_layers=None)

Modules to be parallelized with pipeline parallelism.

The key constraint that enables pipeline parallelism is the representation of the forward pass as a sequence of layers with a simple interface between them. The forward pass is implicitly defined by the module layers: the output of each layer is assumed to be directly usable as the input to the next, as in a torch.nn.Sequential. The forward pass is implicitly:

def forward(self, inputs):
    x = inputs
    for layer in self.layers:
        x = layer(x)
    return x

Note

Pipeline parallelism is not compatible with ZeRO-2 and ZeRO-3.

Parameters
  • layers (Iterable) – A sequence of layers defining pipeline structure. Can be a torch.nn.Sequential module.

  • num_stages (int, optional) – The degree of pipeline parallelism. If not specified, topology must be provided.

  • topology (deepspeed.runtime.pipe.ProcessTopology, optional) – Defines the axes of parallelism for training. Must be provided if num_stages is None.

  • loss_fn (callable, optional) – The loss function. The loss is computed as loss = loss_fn(outputs, label).

  • seed_layers (bool, optional) – Use a different seed for each layer. Defaults to False.

  • seed_fn (callable, optional) – A custom seed-generating function. Defaults to the random seed generator.

  • base_seed (int, optional) – The starting seed. Defaults to 1234.

  • partition_method (str, optional) – The method used to partition the layers among pipeline stages. Defaults to 'parameters'.

  • activation_checkpoint_interval (int, optional) – The granularity of activation checkpointing, in number of layers. 0 disables activation checkpointing.

  • activation_checkpoint_func (callable, optional) – The function to use for activation checkpointing. Defaults to deepspeed.checkpointing.checkpoint.

  • checkpointable_layers (list, optional) – Class names of the layers that may be activation-checkpointed; layers whose class is not listed are not checkpointed. Defaults to None, which applies no additional filtering.
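How activation_checkpoint_interval and checkpointable_layers combine can be sketched in plain Python. This is a rough illustration under stated assumptions, not DeepSpeed's implementation: it assumes layers are grouped into consecutive chunks of activation_checkpoint_interval layers, and that a chunk is checkpointed only when every layer's class name appears in checkpointable_layers. The helper plan_checkpoint_chunks and the layer classes are hypothetical, and the None case is simplified to "checkpoint every chunk".

```python
from typing import List, Optional, Sequence, Tuple


class Embedding:  # hypothetical stand-in layer classes
    pass


class TransformerBlock:
    pass


class Head:
    pass


def plan_checkpoint_chunks(
    layers: Sequence[object],
    interval: int,
    checkpointable_layers: Optional[List[str]] = None,
) -> List[Tuple[int, int, bool]]:
    """Return (start, end, checkpointed) for each chunk of consecutive layers."""
    if interval == 0:
        # 0 disables activation checkpointing: run all layers as one plain chunk.
        return [(0, len(layers), False)]
    plan = []
    for start in range(0, len(layers), interval):
        end = min(start + interval, len(layers))
        chunk = layers[start:end]
        if checkpointable_layers is None:
            # Simplification for this sketch: no filtering, checkpoint every chunk.
            checkpointed = True
        else:
            # Checkpoint only if every layer's class name is listed.
            checkpointed = all(type(layer).__name__ in checkpointable_layers for layer in chunk)
        plan.append((start, end, checkpointed))
    return plan


layers = [Embedding(), TransformerBlock(), TransformerBlock(), TransformerBlock(), Head()]
for start, end, ckpt in plan_checkpoint_chunks(layers, interval=2,
                                               checkpointable_layers=["TransformerBlock"]):
    print(f"layers [{start}, {end}) checkpointed={ckpt}")
# layers [0, 2) checkpointed=False
# layers [2, 4) checkpointed=True
# layers [4, 5) checkpointed=False
```

Because chunk boundaries are fixed by the interval, a single non-listed layer (here Embedding or Head) is enough to keep its whole chunk from being checkpointed.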

forward(forward_input)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

allreduce_tied_weight_gradients()

All-reduce the gradients of the tied weights between tied stages.

topology()

Return the ProcessTopology object used to query process mappings.

ckpt_prefix(checkpoints_path, tag)

Build a prefix for all checkpoint files written by this module.

ckpt_layer_path(ckpt_dir, local_layer_idx)

Customize a prefix for a specific pipeline module layer.

ckpt_layer_path_list(ckpt_dir, local_layer_idx)

Get the list of all checkpoint files for a specific pipeline module layer.

class deepspeed.pipe.LayerSpec(typename, *module_args, **module_kwargs)

Building block for specifying pipeline-parallel modules.

LayerSpec stores the type information and parameters for each layer in a PipelineModule. For example:

torch.nn.Sequential(
    torch.nn.Linear(self.in_dim, self.hidden_dim, bias=False),
    torch.nn.Linear(self.hidden_dim, self.out_dim)
)

becomes

layer_specs = [
    LayerSpec(torch.nn.Linear, self.in_dim, self.hidden_dim, bias=False),
    LayerSpec(torch.nn.Linear, self.hidden_dim, self.out_dim)
]

build(log=False)

Build the stored specification.
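The point of a spec is deferred construction: it records a class and its arguments, and the module is only instantiated when build() is called, so each pipeline stage can build just the layers it owns. The sketch below illustrates that idea without any dependencies; MiniLayerSpec and Linear are hypothetical stand-ins, not the DeepSpeed or PyTorch classes.

```python
class MiniLayerSpec:
    """Hypothetical sketch: store a constructor and its arguments, build later."""

    def __init__(self, typename, *module_args, **module_kwargs):
        if not isinstance(typename, type):
            raise TypeError("typename must be a class")
        self.typename = typename
        self.module_args = module_args
        self.module_kwargs = module_kwargs

    def build(self):
        # Construction happens here, not in __init__.
        return self.typename(*self.module_args, **self.module_kwargs)


class Linear:
    """Stand-in for torch.nn.Linear that counts how many instances were built."""

    instances = 0

    def __init__(self, in_features, out_features, bias=True):
        Linear.instances += 1
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias


specs = [MiniLayerSpec(Linear, 8, 16, bias=False), MiniLayerSpec(Linear, 16, 4)]

# A stage that owns only the second layer builds only that one.
owned = [spec.build() for spec in specs[1:]]
print(Linear.instances, owned[0].in_features, owned[0].out_features)  # 1 16 4
```

Creating the list of specs allocates nothing; only the stage's own slice is materialized.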

class deepspeed.pipe.TiedLayerSpec(key, typename, *module_args, forward_fn=None, tied_weight_attr=['weight'], **module_kwargs)

class deepspeed.runtime.pipe.ProcessTopology(axes, dims)

Manages the mapping of n-dimensional Cartesian coordinates to linear indices. This mapping is used to map the rank of processes to the grid for various forms of parallelism.

Each axis of the tensor is accessed by its name. The provided ordering of the axes defines the layout of the topology. ProcessTopology uses a “row-major” layout of the tensor axes, and so axes=['x', 'y'] would map coordinates (x, y) and (x, y+1) to adjacent linear indices. If instead axes=['y', 'x'] was used, coordinates (x, y) and (x+1, y) would be adjacent.

Some methods return ProcessCoord namedtuples.
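The row-major rule can be made concrete with strides: the last axis varies fastest, so its stride is 1, and each earlier axis has a stride equal to the product of the dimensions after it. The sketch below is a hypothetical standalone re-creation of that mapping (MiniTopology is not the DeepSpeed class); it reproduces the get_rank and get_coord examples documented for ProcessTopology.

```python
from collections import namedtuple
from math import prod


class MiniTopology:
    """Hypothetical row-major mapping between named coordinates and linear ranks."""

    def __init__(self, axes, dims):
        assert len(axes) == len(dims)
        self.axes, self.dims = list(axes), list(dims)
        self.Coord = namedtuple("ProcessCoord", self.axes)
        # Stride of axis i = product of the dimensions of all later axes.
        self.strides = [prod(self.dims[i + 1:]) for i in range(len(self.dims))]

    def world_size(self):
        return prod(self.dims)

    def get_rank(self, **coord_kwargs):
        assert set(coord_kwargs) == set(self.axes), "every axis must be given"
        return sum(coord_kwargs[a] * s for a, s in zip(self.axes, self.strides))

    def get_coord(self, rank):
        assert 0 <= rank < self.world_size()
        return self.Coord(*((rank // s) % d for s, d in zip(self.strides, self.dims)))


xy = MiniTopology(axes=["x", "y"], dims=[2, 3])
yx = MiniTopology(axes=["y", "x"], dims=[3, 2])

print(xy.get_rank(x=0, y=1), xy.get_coord(1))        # 1 ProcessCoord(x=0, y=1)
# axes=['x', 'y']: (x, y) and (x, y+1) are adjacent.
print(xy.get_rank(x=1, y=0), xy.get_rank(x=1, y=1))  # 3 4
# axes=['y', 'x']: (x, y) and (x+1, y) are adjacent.
print(yx.get_rank(x=0, y=2), yx.get_rank(x=1, y=2))  # 4 5
```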

get_rank(**coord_kwargs)

Return the global rank of a process via its coordinates.

Coordinates are specified as kwargs. For example:

>>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3])
>>> X.get_rank(x=0, y=1)
1

get_axis_names()

Return a list of the axis names in the ordering of the topology.

get_rank_repr(rank, omit_axes=['data', 'pipe'], inner_sep='_', outer_sep='-')

Return a string representation of a rank.

This method is primarily used for checkpointing model data.

For example:

>>> topo = ProcessTopology(axes=['a', 'b'], dims=[2, 2])
>>> topo.get_rank_repr(rank=3)
'a_01-b_01'
>>> topo.get_rank_repr(rank=3, omit_axes=['a'])
'b_01'

Parameters
  • rank (int) – A rank in the topology.

  • omit_axes (list, optional) – Axes that should not be in the representation. Defaults to ['data', 'pipe'].

  • inner_sep (str, optional) – Separator between an axis name and its coordinate. Defaults to '_'.

  • outer_sep (str, optional) – Separator between the per-axis entries. Defaults to '-'.

Returns

A string representation of the coordinate owned by rank.

Return type

str
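The string format implied by the examples above can be sketched as follows. The format is inferred from the documented output ('a_01-b_01') and is an assumption: each kept axis is rendered as its name, inner_sep, and a zero-padded two-digit coordinate, and the pieces are joined with outer_sep. rank_repr_sketch is a hypothetical helper, not the DeepSpeed method.

```python
import itertools


def rank_repr_sketch(axes, dims, rank, omit_axes=("data", "pipe"), inner_sep="_", outer_sep="-"):
    """Hypothetical re-creation of a ProcessTopology-style rank string."""
    # itertools.product enumerates coordinates in row-major order, so the
    # position of a coordinate in this list is its linear rank.
    coords = list(itertools.product(*(range(d) for d in dims)))
    coord = coords[rank]
    parts = [f"{axis}{inner_sep}{idx:02d}"
             for axis, idx in zip(axes, coord) if axis not in omit_axes]
    return outer_sep.join(parts)


print(rank_repr_sketch(["a", "b"], [2, 2], rank=3))                    # a_01-b_01
print(rank_repr_sketch(["a", "b"], [2, 2], rank=3, omit_axes=["a"]))   # b_01
print(rank_repr_sketch(["pipe", "data", "model"], [2, 2, 2], rank=5))  # model_01
```

With the default omit_axes, only the model-parallel coordinate survives in a pipe/data/model topology, which is why such strings suit per-layer checkpoint file names.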

get_dim(axis)

Return the number of processes along the given axis.

For example:

>>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3])
>>> X.get_dim('y')
3

get_coord(rank)

Return the coordinate owned by a process rank.

The axes of the returned namedtuple can be directly accessed as members. For example:

>>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3])
>>> coord = X.get_coord(rank=1)
>>> coord.x
0
>>> coord.y
1

get_axis_comm_lists(axis)

Construct lists suitable for a communicator group along the given axis.

Example

>>> topo = ProcessTopology(axes=['pipe', 'data', 'model'], dims=[2, 2, 2])
>>> topo.get_axis_comm_lists('pipe')
[
    [0, 4], # data=0, model=0
    [1, 5], # data=0, model=1
    [2, 6], # data=1, model=0
    [3, 7], # data=1, model=1
]

Returns

A list of lists whose coordinates match in all axes except axis.
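Conceptually, each communicator group along an axis collects the ranks that agree on every other axis. The hypothetical sketch below (comm_lists_sketch is not a DeepSpeed function) groups ranks this way and reproduces the 'pipe' example above; each inner list is the kind of rank list one would hand to a process-group constructor such as torch.distributed.new_group.

```python
import itertools
from collections import defaultdict


def comm_lists_sketch(axes, dims, axis):
    """Group ranks that share every coordinate except `axis` (hypothetical helper)."""
    axis_idx = axes.index(axis)
    groups = defaultdict(list)
    # Row-major enumeration: the i-th coordinate tuple belongs to rank i.
    for rank, coord in enumerate(itertools.product(*(range(d) for d in dims))):
        key = coord[:axis_idx] + coord[axis_idx + 1:]  # drop the varying axis
        groups[key].append(rank)
    # Order groups by their remaining coordinates for a stable result.
    return [groups[key] for key in sorted(groups)]


print(comm_lists_sketch(["pipe", "data", "model"], [2, 2, 2], "pipe"))
# [[0, 4], [1, 5], [2, 6], [3, 7]]
print(comm_lists_sketch(["pipe", "data", "model"], [2, 2, 2], "data"))
# [[0, 2], [1, 3], [4, 6], [5, 7]]
```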

filter_match(**filter_kwargs)

Return the list of ranks whose coordinates match the provided criteria.

Example

>>> X = ProcessTopology(axes=['pipe', 'data', 'model'], dims=[2, 2, 2])
>>> X.filter_match(pipe=0, data=1)
[2, 3]
>>> [X.get_coord(rank) for rank in X.filter_match(pipe=0, data=1)]
[ProcessCoord(pipe=0, data=1, model=0), ProcessCoord(pipe=0, data=1, model=1)]

Parameters

**filter_kwargs (dict) – criteria used to select coordinates.

Returns

The list of ranks whose coordinates match filter_kwargs.

get_axis_list(axis, idx)

Returns the list of global ranks whose coordinate in an axis is idx.

For example:
>>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3])
>>> X.get_axis_list(axis='x', idx=0)
[0, 1, 2]
>>> X.get_axis_list(axis='y', idx=0)
[0, 3]

Training

class deepspeed.runtime.pipe.engine.PipelineEngine(has_bool_tensors=False, *super_args, **super_kwargs)

A training engine for hybrid pipeline, data, and model parallel training.

This engine is created by deepspeed.initialize() when a PipelineModule is provided.

reset_activation_shape()

Reset the buffers when the shapes of activations and gradients change. For example, with curriculum learning that changes the sequence length of each sample, call this method whenever the sequence length is about to change.

train_batch(data_iter=None)

Progress the pipeline to train the next batch of data. The engine will ingest self.train_batch_size() total samples collectively across all workers.

An iterator over training data should be provided as an argument unless deepspeed.initialize() was provided a training set. In that case, the training data will be read automatically.

Warning

A total of self.gradient_accumulation_steps() entries will be pulled from data_iter by each pipeline. There must be sufficient data left in data_iter or else a StopIteration will halt training.

DeepSpeed provides a convenience class deepspeed.utils.RepeatingLoader that wraps data loaders to automatically restart upon a StopIteration.
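The restart-on-StopIteration behavior described for deepspeed.utils.RepeatingLoader can be illustrated with a small stand-in. RepeatingLoaderSketch is hypothetical and uses a plain list in place of a torch DataLoader; micro_batches_per_step plays the role of self.gradient_accumulation_steps().

```python
class RepeatingLoaderSketch:
    """Hypothetical wrapper that restarts an iterable when it is exhausted."""

    def __init__(self, loader):
        self.loader = loader
        self.data_iter = iter(loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.data_iter)
        except StopIteration:
            # Start a fresh pass over the underlying data instead of stopping.
            self.data_iter = iter(self.loader)
            return next(self.data_iter)


micro_batches = ["mb0", "mb1", "mb2"]  # only 3 micro-batches of data
micro_batches_per_step = 2             # stands in for gradient_accumulation_steps()
data_iter = RepeatingLoaderSketch(micro_batches)

# Each training step pulls micro_batches_per_step entries; without the wrapper,
# the second step would hit StopIteration right after "mb2".
steps = [[next(data_iter) for _ in range(micro_batches_per_step)] for _ in range(3)]
print(steps)  # [['mb0', 'mb1'], ['mb2', 'mb0'], ['mb1', 'mb2']]
```

Note that a step may straddle two passes over the data, as the second step does here.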

Parameters

data_iter (Iterator, optional) – Iterator of training data.

Returns

The arithmetic mean of the losses computed over this batch.

eval_batch(data_iter, return_logits=False, compute_loss=True, reduce_output='avg', bcast_loss=True, num_micro_batches=None)

Evaluate the pipeline on a batch of data from data_iter. The engine will evaluate self.train_batch_size() total samples collectively across all workers.

This method is equivalent to:

module.eval()
with torch.no_grad():
    output = module(batch)

Warning

A total of self.gradient_accumulation_steps() entries will be pulled from data_iter by each pipeline. There must be sufficient data left in data_iter or else a StopIteration will halt evaluation.

DeepSpeed provides a convenience class deepspeed.utils.RepeatingLoader that wraps data loaders to automatically restart upon a StopIteration.

Parameters

data_iter (Iterator) – Iterator of data to evaluate.

Returns

The arithmetic mean of the losses computed over this batch.

set_train_batch_size(train_batch_size)

Adjust the global batch size by increasing or decreasing the number of micro-batches (i.e., gradient accumulation steps). The size of each micro-batch (i.e., train_micro_batch_size_per_gpu) is not changed.

Parameters

train_batch_size (int) – The new global batch size for training.

Raises

ValueError – if train_batch_size is not divisible by the configured micro-batch size and data parallelism.
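The divisibility requirement follows from the batch-size identity train_batch_size = train_micro_batch_size_per_gpu × gradient_accumulation_steps × data-parallel world size. The helper below is a hypothetical sketch of that arithmetic, not the engine's code: holding the micro-batch size and data-parallel degree fixed, it derives the new number of micro-batches or raises ValueError.

```python
def micro_batches_for(train_batch_size: int, micro_batch_size: int, dp_world_size: int) -> int:
    """Hypothetical sketch: gradient accumulation steps needed for a global batch size."""
    samples_per_micro_step = micro_batch_size * dp_world_size  # one micro-batch on every DP rank
    if train_batch_size % samples_per_micro_step != 0:
        raise ValueError(
            f"train_batch_size={train_batch_size} is not divisible by "
            f"micro_batch_size * dp_world_size = {samples_per_micro_step}")
    return train_batch_size // samples_per_micro_step


# Micro-batch size 4 per GPU and 8 data-parallel ranks: 32 samples per micro-step.
print(micro_batches_for(256, micro_batch_size=4, dp_world_size=8))  # 8
print(micro_batches_for(512, micro_batch_size=4, dp_world_size=8))  # 16
```

Doubling the global batch size doubles the number of micro-batches per step, which in a pipeline also lengthens the steady-state phase of the schedule.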

is_first_stage()

True if this process is in the first stage in the pipeline.

is_last_stage()

True if this process is in the last stage in the pipeline.

set_dataiterator(iterator)

Store an iterator to sample for training data.

set_batch_fn(fn)

Execute a post-processing function on input data.

Parameters

fn (function) – The function to run.

is_gradient_accumulation_boundary()

True if the engine is executing a gradient reduction or optimizer step instruction.

This is overridden from DeepSpeedEngine to force reductions and steps when the pipeline engine is instructed to do so.

Returns

whether reductions and optimizer steps should occur.

Return type

bool

forward(*args, **kwargs)

Disabled for pipeline parallel training. See train_batch().

backward(*args, **kwargs)

Disabled for pipeline parallel training. See train_batch().

step(*args, **kwargs)

Disabled for pipeline parallel training. See train_batch().

module_state_dict(exclude_frozen_parameters=False)

Override hack to save a pipe model layer by layer to the current checkpoint directory instead of returning a state dict.

This method should only be called by DeepSpeed’s save_checkpoint(). The recommended way of saving a PipelineModule outside of save_checkpoint() is save_state_dict().

Returns

None

load_module_state_dict(checkpoint, strict=True, custom_load_fn=None, fetch_z3_params=False)

Override hack to instead use a directory path.

This is important because pipeline models checkpoint by layer instead of rank.

If state_dict is neither None nor a str, we revert to super() and expect a dict.

Parameters
  • state_dict (str, None) – unused

  • strict (bool, optional) – Strict state loading. Defaults to True.

Extending Pipeline Parallelism

class deepspeed.runtime.pipe.schedule.PipeSchedule(micro_batches, stages, stage_id)

Directs the execution of a pipeline engine by generating sequences of PipeInstruction.

Schedules are generators that yield sequences of PipeInstruction to process the micro-batches in one batch. Each yielded step is atomic in the sense that a barrier synchronization can be placed between successive steps without deadlock.

Below is an example schedule that implements data parallelism with gradient accumulation:

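from deepspeed.runtime.pipe.schedule import (PipeSchedule, LoadMicroBatch, ForwardPass,
                                             BackwardPass, ReduceGrads, OptimizerStep)


class DataParallelSchedule(PipeSchedule):
    def steps(self):
        # Each micro-batch is loaded, run forward, and run backward in buffer 0.
        for step_id in range(self.micro_batches):
            cmds = [
                LoadMicroBatch(buffer_id=0),
                ForwardPass(buffer_id=0),
                BackwardPass(buffer_id=0),
            ]
            # After the last micro-batch, reduce gradients and step the optimizer.
            if step_id == self.micro_batches - 1:
                cmds.extend([
                    ReduceGrads(),
                    OptimizerStep(),
                ])
            yield cmds

    def num_pipe_buffers(self):
        # Micro-batches are processed one at a time, so a single buffer suffices.
        return 1

Parameters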
  • micro_batches (int) – The number of micro-batches that comprise a batch.

  • stages (int) – The number of pipeline stages.

  • stage_id (int) – The pipe stage that will execute the generated schedule.
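To show how an engine might consume such a generator, the sketch below pairs a copy of the DataParallelSchedule step logic with a toy interpreter. Everything here is a hypothetical stand-in in plain Python: the instruction classes mirror the documented names but are not the DeepSpeed classes, and run_schedule is not the engine. It only demonstrates the protocol: each yielded list is one atomic step whose instructions run in order.

```python
class Instruction:
    """Stand-in for PipeInstruction: keyword arguments become attributes."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __repr__(self):
        args = ", ".join(f"{k}={v}" for k, v in self.kwargs.items())
        return f"{type(self).__name__}({args})"


class LoadMicroBatch(Instruction):
    """Load a micro-batch into a buffer."""


class ForwardPass(Instruction):
    """Run the forward pass on a buffer."""


class BackwardPass(Instruction):
    """Run the backward pass on a buffer and accumulate gradients."""


class ReduceGrads(Instruction):
    """Reduce gradients across data-parallel ranks."""


class OptimizerStep(Instruction):
    """Apply the optimizer and zero the gradients."""


class MiniDataParallelSchedule:
    """Same step logic as the DataParallelSchedule example above."""

    def __init__(self, micro_batches, stages, stage_id):
        self.micro_batches = micro_batches
        self.stages = stages
        self.stage_id = stage_id

    def steps(self):
        for step_id in range(self.micro_batches):
            cmds = [LoadMicroBatch(buffer_id=0), ForwardPass(buffer_id=0), BackwardPass(buffer_id=0)]
            if step_id == self.micro_batches - 1:
                cmds.extend([ReduceGrads(), OptimizerStep()])
            yield cmds

    def num_pipe_buffers(self):
        return 1


def run_schedule(schedule):
    """Toy interpreter: run every instruction of every step, in order."""
    executed = []
    pending_grads = 0      # micro-batches whose gradients have not been applied yet
    optimizer_steps = 0
    for step in schedule.steps():      # each yielded list is one atomic step
        for cmd in step:
            executed.append(repr(cmd))
            if isinstance(cmd, BackwardPass):
                pending_grads += 1
            elif isinstance(cmd, OptimizerStep):
                optimizer_steps += 1
                pending_grads = 0      # the step consumes and zeros the gradients
    return executed, pending_grads, optimizer_steps


executed, pending, opt_steps = run_schedule(MiniDataParallelSchedule(micro_batches=3, stages=1, stage_id=0))
print(len(executed), pending, opt_steps)  # 11 0 1
print(executed[-2:])                      # ['ReduceGrads()', 'OptimizerStep()']
```

Keeping the schedule a pure generator of instruction lists separates what to do (the schedule) from how to do it (the interpreter), which is what lets new schedules be written without touching the engine.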

abstract steps()

Yield a list of PipeInstruction for each step in the schedule.

Note

Schedules must implement steps() to define the schedule.

Returns

Instructions to be executed as one step of the pipeline

num_pipe_buffers()

The number of pipeline buffers that will be used by this stage.

Note

Schedules should specialize num_pipe_buffers() for memory savings at scale.

Returns

The number of buffers for the engine to allocate.

property stage

Stage index used to configure this schedule.

property num_stages

The number of total pipeline stages used to configure this schedule.

property num_micro_batches

The number of total micro_batches used to configure this schedule.

property is_first_stage

True if the configured stage_id is the first stage in the pipeline.

property is_last_stage

True if the configured stage_id is the last stage in the pipeline.

class deepspeed.runtime.pipe.schedule.InferenceSchedule(micro_batches, stages, stage_id)

A schedule for inferencing batches using pipeline parallelism.

num_pipe_buffers()

Only two pipeline buffers are required for inferencing.

Returns

2

class deepspeed.runtime.pipe.schedule.TrainSchedule(micro_batches, stages, stage_id)

A schedule for training a batch using hybrid parallelism.

Pipeline parallelism is extracted through gradient accumulation and thus convergence follows that of a data parallel approach with the same batch size.
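The convergence claim rests on a simple identity: for a loss averaged over the batch, the mean of the gradients of equally sized micro-batches equals the full-batch gradient. The sketch below checks this numerically for a one-parameter least-squares model; the data and the model are made up for illustration.

```python
import math


def grad_mse(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over the given samples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ys = [2.0, 3.5, 6.5, 8.0, 9.5, 12.5]
w = 0.5

full_batch = grad_mse(w, xs, ys)

# Split the batch into 3 equally sized micro-batches and average their gradients,
# as gradient accumulation does before the optimizer step.
micro = 2
micro_grads = [grad_mse(w, xs[i:i + micro], ys[i:i + micro]) for i in range(0, len(xs), micro)]
accumulated = sum(micro_grads) / len(micro_grads)

print(full_batch, accumulated, math.isclose(full_batch, accumulated))
```

With unequal micro-batch sizes, the average would have to be weighted by micro-batch size to preserve the identity.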

num_pipe_buffers()

Return the number of pipeline buffers required for this stage.

This is equivalent to the maximum number of in-flight forward passes, since we need to remember the activations of forward passes in order to run backpropagation. For synchronous 1F1B, this is at most one more than the index difference between this stage and the last stage (i.e., stages - stage_id).
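The buffer count can be derived from the order of operations. The sketch below generates a textbook one-forward-one-backward (1F1B) order for a single stage (warm-up forwards, a steady state alternating one forward and one backward, then a cool-down of the remaining backwards) and counts the peak number of forward passes whose activations are still held. This is a generic 1F1B model, not DeepSpeed's TrainSchedule code; one_f_one_b_order and peak_in_flight are hypothetical helpers.

```python
def one_f_one_b_order(stages, stage_id, micro_batches):
    """Return the per-stage order of ('F', i) and ('B', i) operations under 1F1B."""
    # Warm-up: earlier stages run extra forwards to fill the pipeline.
    warmup = min(stages - stage_id - 1, micro_batches)
    order = [("F", i) for i in range(warmup)]
    next_fwd, next_bwd = warmup, 0
    # Steady state: one forward followed by one backward.
    while next_fwd < micro_batches:
        order.append(("F", next_fwd))
        next_fwd += 1
        order.append(("B", next_bwd))
        next_bwd += 1
    # Cool-down: drain the remaining backward passes.
    while next_bwd < micro_batches:
        order.append(("B", next_bwd))
        next_bwd += 1
    return order


def peak_in_flight(order):
    """Peak number of forwards whose activations are held awaiting a backward."""
    in_flight = peak = 0
    for kind, _ in order:
        in_flight += 1 if kind == "F" else -1
        peak = max(peak, in_flight)
    return peak


for stage_id in range(4):
    order = one_f_one_b_order(stages=4, stage_id=stage_id, micro_batches=8)
    print(stage_id, peak_in_flight(order))
# 0 4
# 1 3
# 2 2
# 3 1
```

In this model the peak is min(stages - stage_id, micro_batches): the first stage holds the most activations and the last stage holds only one.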

class deepspeed.runtime.pipe.schedule.DataParallelSchedule(micro_batches, stages, stage_id)

An example schedule that trains using traditional data parallelism with gradient accumulation.

num_pipe_buffers()

Only one pipeline buffer needed.

class deepspeed.runtime.pipe.schedule.PipeInstruction(**kwargs)

Base class for all instructions to be executed by the pipeline engine.

All keyword arguments are stored as members similar to a namedtuple. These are then accessible to the PipelineEngine during execution.

Parameters

kwargs (optional) – keyword arguments to store as members
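Storing keyword arguments as members turns each instruction into a cheap, self-describing record. A hypothetical sketch of the pattern, including a subclass with a required buffer_id in the style of BufferOpInstruction documented below (InstructionSketch and BufferOpSketch are illustrative names, not DeepSpeed classes):

```python
class InstructionSketch:
    """Keyword arguments are stored both as a dict and as attributes."""

    def __init__(self, **kwargs):
        self.name = type(self).__name__
        self.kwargs = kwargs
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __repr__(self):
        args = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items())
        return f"{self.name}({args})"


class BufferOpSketch(InstructionSketch):
    """Instruction that must name the pipeline buffer it operates on."""

    def __init__(self, buffer_id, **kwargs):
        super().__init__(buffer_id=buffer_id, **kwargs)


cmd = BufferOpSketch(buffer_id=1, tag="fwd")
print(cmd, cmd.buffer_id, cmd.tag)  # BufferOpSketch(buffer_id=1, tag='fwd') 1 fwd
```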

class deepspeed.runtime.pipe.schedule.OptimizerStep(**kwargs)

Performs one step with the optimizer and zeros gradients.

Note

Should be issued after ReduceGrads and ReduceTiedGrads.

Note

Can be a synchronization point among data-parallel ranks.

class deepspeed.runtime.pipe.schedule.ReduceGrads(**kwargs)

Reduce the computed gradients among data-parallel processes within the stage.

class deepspeed.runtime.pipe.schedule.ReduceTiedGrads(**kwargs)

Reduce the computed gradients of tied modules within a pipeline-parallel group.

Warning

The stages included in this synchronization point are not known until the model is partitioned among pipeline stages. In the worst case, it includes all pipeline stages. This instruction should be scheduled carefully to avoid deadlocks.

class deepspeed.runtime.pipe.schedule.BufferOpInstruction(buffer_id, **kwargs)

A pipeline instruction that operates on pipeline buffer(s).

Parameters

buffer_id (int) – The index of the pipeline buffer to modify.

class deepspeed.runtime.pipe.schedule.LoadMicroBatch(buffer_id, **kwargs)

Load a micro-batch into a buffer.

Roughly:

buffers['inputs'][buffer_id] = next(data_iter)

class deepspeed.runtime.pipe.schedule.ForwardPass(buffer_id, **kwargs)

Compute a forward pass.

Roughly:

buffers['outputs'][buffer_id] = forward(buffers['inputs'][buffer_id])

class deepspeed.runtime.pipe.schedule.BackwardPass(buffer_id, **kwargs)

Compute a backward pass and accumulate gradients.

Roughly:

outputs = buffers['outputs'][buffer_id]
gradients = buffers['gradients'][buffer_id]
torch.autograd.backward(tensors=outputs,
                        grad_tensors=gradients)

class deepspeed.runtime.pipe.schedule.SendActivation(buffer_id, **kwargs)

Send activations to the next stage in the pipeline.

Roughly:

send(buffers['outputs'][buffer_id])

Note

The communication is blocking and must be paired with a RecvActivation on the next pipeline stage to avoid deadlock.

class deepspeed.runtime.pipe.schedule.RecvActivation(buffer_id, **kwargs)

Receive activations from the previous stage in the pipeline.

Roughly:

buffers['inputs'][buffer_id] = recv()

Note

The communication is blocking and must be paired with a SendActivation on the previous pipeline stage to avoid deadlock.

class deepspeed.runtime.pipe.schedule.SendGrad(buffer_id, **kwargs)

Send the gradients computed with respect to the received activations to the previous pipeline stage.

Note

Only received tensors with requires_grad==True will produce gradients. Missing gradients will be replaced with None on the receiving stage.

Note

The communication is blocking and must be paired with a RecvGrad on the previous pipeline stage to avoid deadlock.

class deepspeed.runtime.pipe.schedule.RecvGrad(buffer_id, **kwargs)

Receive computed gradients from the next pipeline stage.

Note

Only activations with requires_grad==True will produce gradients. Missing gradients will be replaced with None.

Note

The communication is blocking and must be paired with a SendGrad on the next pipeline stage to avoid deadlock.
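The pairing rule for SendActivation/RecvActivation and SendGrad/RecvGrad can be illustrated with two threads connected by queues. This is a hypothetical single-process sketch: queue.Queue stands in for point-to-point communication (its get() blocks until a matching put()), stage 0 computes y = 2x, and stage 1 uses loss = y so that dL/dy = 1. Removing either half of a pair would leave the peer blocked in get().

```python
import queue
import threading


def run_two_stage_pipeline(inputs):
    activations = queue.Queue()  # stage 0 -> stage 1 (SendActivation / RecvActivation)
    gradients = queue.Queue()    # stage 1 -> stage 0 (SendGrad / RecvGrad)
    input_grads, losses = [], []

    def stage0():
        for x in inputs:
            y = 2 * x                   # ForwardPass on stage 0
            activations.put(y)          # SendActivation
            dy = gradients.get()        # RecvGrad: blocks until stage 1 sends
            input_grads.append(2 * dy)  # BackwardPass: dL/dx = dL/dy * dy/dx

    def stage1():
        for _ in inputs:
            y = activations.get()       # RecvActivation: blocks until stage 0 sends
            losses.append(y)            # ForwardPass on stage 1: loss = y
            gradients.put(1.0)          # SendGrad: dL/dy = 1

    threads = [threading.Thread(target=stage0, daemon=True),
               threading.Thread(target=stage1, daemon=True)]
    for t in threads:
        t.start()
    for t in threads:
        t.join(timeout=5)
    if any(t.is_alive() for t in threads):
        raise RuntimeError("a send/recv pair was left unmatched")
    return losses, input_grads


print(run_two_stage_pipeline([1.0, 2.0, 3.0]))  # ([2.0, 4.0, 6.0], [2.0, 2.0, 2.0])
```

Because each side issues its sends and receives in the same order as its peer, every blocking receive is eventually satisfied; this is the property a schedule must preserve across stages.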