Pipeline Parallelism¶
Model Specification¶
- class deepspeed.pipe.PipelineModule(layers, num_stages=None, topology=None, loss_fn=None, seed_layers=False, seed_fn=None, base_seed=1234, partition_method='parameters', activation_checkpoint_interval=0, activation_checkpoint_func=<function checkpoint>, checkpointable_layers=None)[source]¶
Modules to be parallelized with pipeline parallelism.
The key constraint that enables pipeline parallelism is the representation of the forward pass as a sequence of layers and the enforcement of a simple interface between them. The forward pass is implicitly defined by the module layers. The key assumption is that the output of each layer can be directly fed as input to the next, like a torch.nn.Sequential. The forward pass is implicitly:

def forward(self, inputs):
    x = inputs
    for layer in self.layers:
        x = layer(x)
    return x
Note
Pipeline parallelism is not compatible with ZeRO-2 and ZeRO-3.
- Parameters
layers (Iterable) – A sequence of layers defining the pipeline structure. Can be a torch.nn.Sequential module.
num_stages (int, optional) – The degree of pipeline parallelism. If not specified, topology must be provided.
topology (deepspeed.runtime.pipe.ProcessTopology, optional) – Defines the axes of parallelism for training. Must be provided if num_stages is None.
loss_fn (callable, optional) – Loss is computed as loss = loss_fn(outputs, label).
seed_layers (bool, optional) – Use a different seed for each layer. Defaults to False.
seed_fn (type, optional) – A custom seed generating function. Defaults to a random seed generator.
base_seed (int, optional) – The starting seed. Defaults to 1234.
partition_method (str, optional) – The method by which the layers are partitioned across stages. Defaults to 'parameters'.
activation_checkpoint_interval (int, optional) – The granularity of activation checkpointing in terms of number of layers. 0 disables activation checkpointing.
activation_checkpoint_func (callable, optional) – The function to use for activation checkpointing. Defaults to deepspeed.checkpointing.checkpoint.
checkpointable_layers (list, optional) – If provided, only layers whose class names appear in this list are activation checkpointed. Defaults to None, which applies no additional filtering.
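As a hedged illustration of these parameters, the sketch below builds a small PipelineModule from LayerSpec entries. The layer dimensions and stage count are hypothetical, and a distributed launcher plus deepspeed.init_distributed() are assumed to have set up the process group already.

import torch.nn as nn
from deepspeed.pipe import PipelineModule, LayerSpec

# Hypothetical dimensions and stage count for illustration only.
in_dim, hidden_dim, out_dim = 1024, 4096, 10

layer_specs = [
    LayerSpec(nn.Linear, in_dim, hidden_dim, bias=False),
    LayerSpec(nn.ReLU),
    LayerSpec(nn.Linear, hidden_dim, out_dim),
]

# Partition the layers across 2 pipeline stages, balancing by parameter count.
net = PipelineModule(layers=layer_specs,
                     num_stages=2,
                     loss_fn=nn.CrossEntropyLoss(),
                     partition_method='parameters')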
- forward(forward_input)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- allreduce_tied_weight_gradients()[source]¶
All reduce the gradients of the tied weights between tied stages
- ckpt_prefix(checkpoints_path, tag)[source]¶
Build a prefix for all checkpoint files written by this module.
- class deepspeed.pipe.LayerSpec(typename, *module_args, **module_kwargs)[source]¶
Building block for specifying pipeline-parallel modules.
LayerSpec stores the type information and parameters for each stage in a PipelineModule. For example:
nn.Sequential(
    torch.nn.Linear(self.in_dim, self.hidden_dim, bias=False),
    torch.nn.Linear(self.hidden_dim, self.out_dim)
)

becomes

layer_specs = [
    LayerSpec(torch.nn.Linear, self.in_dim, self.hidden_dim, bias=False),
    LayerSpec(torch.nn.Linear, self.hidden_dim, self.out_dim),
]
- class deepspeed.pipe.TiedLayerSpec(key, typename, *module_args, forward_fn=None, tied_weight_attr='weight', **module_kwargs)[source]¶
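TiedLayerSpec extends LayerSpec with a key that ties parameters between the stages owning layers that share the key. A minimal sketch, assuming a GPT-style model whose input embedding and output projection share one weight matrix; the sizes are hypothetical, and the forward_fn signature is assumed to receive the tied module followed by the stage inputs:

import torch.nn as nn
from deepspeed.pipe import LayerSpec, TiedLayerSpec

vocab_size, hidden_dim = 50257, 1024   # hypothetical sizes

layer_specs = [
    # Both specs share key 'embed', so their 'weight' attributes are tied
    # across the stages that own them.
    TiedLayerSpec('embed', nn.Embedding, vocab_size, hidden_dim),
    # ... transformer blocks as ordinary LayerSpec entries ...
    TiedLayerSpec('embed', nn.Embedding, vocab_size, hidden_dim,
                  forward_fn=lambda module, x: x @ module.weight.t()),
]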
- class deepspeed.runtime.pipe.ProcessTopology(axes, dims)[source]¶
Manages the mapping of n-dimensional Cartesian coordinates to linear indices. This mapping is used to map the rank of processes to the grid for various forms of parallelism.
Each axis of the tensor is accessed by its name. The provided ordering of the axes defines the layout of the topology. ProcessTopology uses a “row-major” layout of the tensor axes, and so axes=[‘x’, ‘y’] would map coordinates (x,y) and (x,y+1) to adjacent linear indices. If instead axes=[‘y’, ‘x’] was used, coordinates (x,y) and (x+1,y) would be adjacent.
Some methods return ProcessCoord namedtuples.
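A brief sketch of the row-major mapping described above, assuming ProcessTopology is importable from deepspeed.runtime.pipe as documented here:

from deepspeed.runtime.pipe import ProcessTopology

# Row-major layout: the last axis varies fastest.
topo = ProcessTopology(axes=['x', 'y'], dims=[2, 3])
# Ranks 0..5 correspond to (x, y) = (0,0), (0,1), (0,2), (1,0), (1,1), (1,2).
assert topo.get_rank(x=1, y=0) == 3
assert topo.get_coord(rank=4).y == 1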
- get_rank(**coord_kwargs)[source]¶
Return the global rank of a process via its coordinates.
Coordinates are specified as kwargs. For example:
>>> X = ProcessTopology(axes=['x', 'y'], dims=[2, 3])
>>> X.get_rank(x=0, y=1)
1
- get_rank_repr(rank, omit_axes=['data', 'pipe'], inner_sep='_', outer_sep='-')[source]¶
Return a string representation of a rank.
This method is primarily used for checkpointing model data.
- For example:
>>> topo = Topo(axes=['a', 'b'], dims=[2, 2])
>>> topo.get_rank_repr(rank=3)
'a_01-b_01'
>>> topo.get_rank_repr(rank=3, omit_axes=['a'])
'b_01'
- Parameters
rank (int) – A rank in the topology.
omit_axes (list, optional) – Axes that should not be in the representation. Defaults to [‘data’, ‘pipe’].
inner_sep (str, optional) – Separator between an axis name and its coordinate. Defaults to '_'.
outer_sep (str, optional) – Separator between axis representations. Defaults to '-'.
- Returns
A string representation of the coordinate owned by rank.
- Return type
str
- get_dim(axis)[source]¶
Return the number of processes along the given axis.
- For example:
>>> X = ProcessTopology(axes=['x', 'y'], dims=[2, 3])
>>> X.get_dim('y')
3
- get_coord(rank)[source]¶
Return the coordinate owned by a process rank.
The axes of the returned namedtuple can be directly accessed as members. For example:
>>> X = ProcessTopology(axes=['x', 'y'], dims=[2, 3])
>>> coord = X.get_coord(rank=1)
>>> coord.x
0
>>> coord.y
1
- get_axis_comm_lists(axis)[source]¶
Construct lists suitable for a communicator group along axis axis.
Example
>>> topo = Topo(axes=['pipe', 'data', 'model'], dims=[2, 2, 2])
>>> topo.get_axis_comm_lists('pipe')
[
    [0, 4],  # data=0, model=0
    [1, 5],  # data=0, model=1
    [2, 6],  # data=1, model=0
    [3, 7],  # data=1, model=1
]
- Returns
A list of lists whose coordinates match in all axes except axis.
- filter_match(**filter_kwargs)[source]¶
Return the list of ranks whose coordinates match the provided criteria.
Example
>>> X = ProcessTopology(axes=['pipe', 'data', 'model'], dims=[2, 2, 2])
>>> X.filter_match(pipe=0, data=1)
[2, 3]
>>> [X.get_coord(rank) for rank in X.filter_match(pipe=0, data=1)]
[ProcessCoord(pipe=0, data=1, model=0), ProcessCoord(pipe=0, data=1, model=1)]
- Parameters
**filter_kwargs (dict) – criteria used to select coordinates.
- Returns
The list of ranks whose coordinates match filter_kwargs.
Training¶
- class deepspeed.runtime.pipe.engine.PipelineEngine(has_bool_tensors=False, *super_args, **super_kwargs)[source]¶
A training engine for hybrid pipeline, data, and model parallel training.
This engine is created by deepspeed.initialize() when a PipelineModule is provided.
- reset_activation_shape()[source]¶
Reset the buffers when the shape of activation and gradient change. For example, for curriculum learning that changes the seqlen of each sample, we need to call this whenever the seqlen is going to change.
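A hedged sketch of the curriculum-learning case mentioned above; curriculum_loaders, engine, and the loop structure are hypothetical:

# Hypothetical curriculum loop: reset the communication buffers whenever
# the sequence length of the incoming samples changes.
for seqlen, loader in curriculum_loaders:   # assumed to exist
    engine.reset_activation_shape()
    loss = engine.train_batch(data_iter=iter(loader))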
- train_batch(data_iter=None)[source]¶
Progress the pipeline to train the next batch of data. The engine will ingest self.train_batch_size() total samples collectively across all workers.
An iterator over training data should be provided as an argument unless deepspeed.initialize() was provided a training set. In that event, the training data will automatically be read.
Warning
A total of self.gradient_accumulation_steps() entries will be pulled from data_iter by each pipeline. There must be sufficient data left in data_iter or else a StopIteration will halt training.
DeepSpeed provides a convenience class deepspeed.utils.RepeatingLoader that wraps data loaders to automatically restart upon a StopIteration.
- Parameters
data_iter (Iterator, optional) – Iterator of training data.
- Returns
The arithmetic mean of the losses computed this batch.
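A minimal training-loop sketch, assuming args, net (a PipelineModule), train_loader, and num_steps already exist:

import deepspeed
from deepspeed.utils import RepeatingLoader

engine, _, _, _ = deepspeed.initialize(args=args,
                                       model=net,
                                       model_parameters=net.parameters())

# RepeatingLoader restarts the loader on StopIteration, so train_batch()
# can always pull gradient_accumulation_steps() micro-batches per call.
data_iter = iter(RepeatingLoader(train_loader))

for step in range(num_steps):
    loss = engine.train_batch(data_iter=data_iter)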
- eval_batch(data_iter, return_logits=False, compute_loss=True, reduce_output='avg')[source]¶
Evaluate the pipeline on a batch of data from data_iter. The engine will evaluate self.train_batch_size() total samples collectively across all workers.
This method is equivalent to:

module.eval()
with torch.no_grad():
    output = module(batch)

Warning
A total of self.gradient_accumulation_steps() entries will be pulled from data_iter by each pipeline. There must be sufficient data left in data_iter or else a StopIteration will halt training.
DeepSpeed provides a convenience class deepspeed.utils.RepeatingLoader that wraps data loaders to automatically restart upon a StopIteration.
- Parameters
data_iter (Iterator) – Iterator of data to evaluate.
- Returns
The arithmetic mean of the losses computed this batch.
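A short usage sketch, reusing the engine from the training sketch above and assuming a val_loader exists:

# eval_batch() handles eval() and torch.no_grad() internally; only an
# iterator over the evaluation data is needed.
val_iter = iter(val_loader)
val_loss = engine.eval_batch(data_iter=val_iter)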
- set_train_batch_size(train_batch_size)[source]¶
Adjust the global batch size by increasing or decreasing the number of micro-batches (i.e., gradient accumulation steps). The size of each micro-batch (i.e., train_micro_batch_size_per_gpu) is not changed.
- Parameters
train_batch_size (int) – The new global batch size for training.
- Raises
ValueError – if train_batch_size is not divisible by the configured micro-batch size and data parallelism.
- set_batch_fn(fn)[source]¶
Execute a post-processing function on input data.
- Parameters
fn (function) – The function to run.
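A hedged sketch of a post-processing function, assuming the dataloader yields dictionaries while the model expects (inputs, labels) tuples:

# Hypothetical conversion from a dict-style batch to the tuple format the
# pipeline consumes.
def to_pipe_format(batch):
    return (batch['input_ids'], batch['labels'])

engine.set_batch_fn(to_pipe_format)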
- is_gradient_accumulation_boundary()[source]¶
True if the engine is executing a gradient reduction or optimizer step instruction.
This is overridden from DeepSpeedEngine to force reductions and steps when the pipeline engine is instructed to do so.
- Returns
whether reductions and optimizer steps should occur.
- Return type
bool
- module_state_dict(exclude_frozen_parameters=False)[source]¶
Override hack to save a pipe model and return the directory path of the save.
This method should only be called by DeepSpeed’s save_checkpoint(). The recommended way of saving a PipelineModule outside of save_checkpoint() is save_state_dict().
- Returns
None
- load_module_state_dict(checkpoint, strict=True, custom_load_fn=None, fetch_z3_params=False)[source]¶
Override hack to instead use a directory path.
This is important because pipeline models checkpoint by layer instead of rank.
If state_dict is not None or a str, we revert to super() expecting a dict.
- Parameters
state_dict (str, None) – unused
strict (bool, optional) – Strict state loading. Defaults to True.
Extending Pipeline Parallelism¶
- class deepspeed.runtime.pipe.schedule.PipeSchedule(micro_batches, stages, stage_id)[source]¶
Directs the execution of a pipeline engine by generating sequences of PipeInstruction.
Schedules are generators that yield sequences of PipeInstruction to process the micro-batches in one batch. Each yielded step is atomic in the sense that a barrier synchronization can be placed between successive steps without deadlock.
Below is an example schedule that implements data parallelism with gradient accumulation:

class DataParallelSchedule(PipeSchedule):
    def steps(self):
        for step_id in range(self.micro_batches):
            cmds = [
                LoadMicroBatch(buffer_id=0),
                ForwardPass(buffer_id=0),
                BackwardPass(buffer_id=0),
            ]
            if step_id == self.micro_batches - 1:
                cmds.extend([
                    ReduceGrads(),
                    OptimizerStep(),
                ])
            yield cmds

    def num_pipe_buffers(self):
        return 1
- Parameters
micro_batches (int) – The number of micro-batches that comprise a batch.
stages (int) – The number of pipeline stages.
stage_id (int) – The pipe stage that will execute the generated schedule.
- abstract steps()[source]¶
Yield a list of PipeInstruction for each step in the schedule.
Note
Schedules must implement steps() to define the schedule.
- Returns
Instructions to be executed as one step of the pipeline
- num_pipe_buffers()[source]¶
The number of pipeline buffers that will be used by this stage.
Note
Schedules should specialize num_pipe_buffers() for memory savings at scale.
- Returns
The number of buffers for the engine to allocate.
- property stage¶
Stage index used to configure this schedule.
- property num_stages¶
The number of total pipeline stages used to configure this schedule.
- property num_micro_batches¶
The number of total micro_batches used to configure this schedule.
- property is_first_stage¶
True if the configured
stage_id
is the first stage in the pipeline.
- property is_last_stage¶
True if the configured
stage_id
is the last stage in the pipeline.
- class deepspeed.runtime.pipe.schedule.InferenceSchedule(micro_batches, stages, stage_id)[source]¶
A schedule for inferencing batches using pipeline parallelism.
- class deepspeed.runtime.pipe.schedule.TrainSchedule(micro_batches, stages, stage_id)[source]¶
A schedule for training a batch using hybrid parallelism.
Pipeline parallelism is extracted through gradient accumulation and thus convergence follows that of a data parallel approach with the same batch size.
- num_pipe_buffers()[source]¶
Return the number of pipeline buffers required for this stage.
This is equivalent to the maximum number of in-flight forward passes, since we need to remember the activations of forward passes in order to run backpropagation. For synchronous 1F1B, this is equivalent to the index difference between this stage and the last stage.
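A sketch consistent with the description above, not necessarily the exact implementation; it illustrates that earlier stages keep more activations in flight than later ones:

# Illustrative only: in-flight forward passes grow with the distance to the
# last stage, bounded by the number of micro-batches in the batch.
def num_pipe_buffers_sketch(stages, stage_id, micro_batches):
    in_flight = stages - stage_id
    return max(2, min(in_flight, micro_batches))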
- class deepspeed.runtime.pipe.schedule.DataParallelSchedule(micro_batches, stages, stage_id)[source]¶
An example schedule that trains using traditional data parallelism with gradient accumulation.
- class deepspeed.runtime.pipe.schedule.PipeInstruction(**kwargs)[source]¶
Base class for all instructions to be executed by the pipeline engine.
All keyword arguments are stored as members similar to a namedtuple. These are then accessible to the PipeEngine during execution.
- Parameters
kwargs (optional) – keyword arguments to store as members
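A small sketch showing that keyword arguments become attributes on the instruction object, using the LoadMicroBatch instruction documented below:

from deepspeed.runtime.pipe.schedule import LoadMicroBatch

cmd = LoadMicroBatch(buffer_id=0)   # buffer_id is stored as a member
assert cmd.buffer_id == 0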
- class deepspeed.runtime.pipe.schedule.OptimizerStep(**kwargs)[source]¶
Performs one step with the optimizer and zeros gradients.
Note
Should be issued after ReduceGrads and ReduceTiedGrads.
Note
Can be a synchronization point among data-parallel ranks.
- class deepspeed.runtime.pipe.schedule.ReduceGrads(**kwargs)[source]¶
Reduce the computed gradients among data-parallel processes within the stage.
- class deepspeed.runtime.pipe.schedule.ReduceTiedGrads(**kwargs)[source]¶
Reduce the computed gradients of tied modules within a pipeline-parallel group.
Warning
The stages included in this synchronization point are not known until the model is partitioned among pipeline stages. In the worst case, it includes all pipeline stages. This instruction should be scheduled carefully to avoid deadlocks.
- class deepspeed.runtime.pipe.schedule.BufferOpInstruction(buffer_id, **kwargs)[source]¶
A pipeline instruction that operates on pipeline buffer(s).
- Parameters
buffer_id (int) – the index of the pipeline buffer(s) to modify.
- class deepspeed.runtime.pipe.schedule.LoadMicroBatch(buffer_id, **kwargs)[source]¶
Load a micro-batch into a buffer.
Roughly:
buffers['inputs'][buffer_id] = next(data_iter)
- class deepspeed.runtime.pipe.schedule.ForwardPass(buffer_id, **kwargs)[source]¶
Compute a forward pass.
Roughly:
buffers['outputs'][buffer_id] = forward(buffers['inputs'][buffer_id])
- class deepspeed.runtime.pipe.schedule.BackwardPass(buffer_id, **kwargs)[source]¶
Compute a backward pass and accumulate gradients.
Roughly:
outputs = buffers['outputs'][buffer_id]
gradients = buffers['gradients'][buffer_id]
torch.autograd.backward(tensors=outputs, grad_tensors=gradients)
- class deepspeed.runtime.pipe.schedule.SendActivation(buffer_id, **kwargs)[source]¶
Send activations to the next stage in the pipeline.
Roughly:
send(buffers['outputs'][buffer_id])
Note
The communication is blocking and must be paired with a
RecvActivation
on the next pipeline stage to avoid deadlock.
- class deepspeed.runtime.pipe.schedule.RecvActivation(buffer_id, **kwargs)[source]¶
Receive activations from the previous stage in the pipeline.
Roughly:
buffers['inputs'][buffer_id] = recv()
Note
The communication is blocking and must be paired with a
SendActivation
on the previous pipeline stage to avoid deadlock.
- class deepspeed.runtime.pipe.schedule.SendGrad(buffer_id, **kwargs)[source]¶
Send computed gradients, with respect to the received activations, to the previous pipeline stage.
Note
Only received tensors with
requires_grad==True
will produce gradients. Missing gradients will be replaced with None on the receiving stage.
Note
The communication is blocking and must be paired with a
RecvGrad
on the previous pipeline stage to avoid deadlock.
- class deepspeed.runtime.pipe.schedule.RecvGrad(buffer_id, **kwargs)[source]¶
Receive computed gradients from the next pipeline stage.
Note
Only activations with
requires_grad==True
will produce gradients. Missing gradients will be replaced with None.
Note
The communication is blocking and must be paired with a
SendGrad
on the next pipeline stage to avoid deadlock.