"""
Copyright 2021 The Microsoft DeepSpeed Team
DeepSpeed Communication Package: deepspeed.comm
deepspeed.comm
-- import and use deepspeeed.ops.comm
-- use torch.distributed directly if both this package and torch.distributed use the same NCCL version
-- use custom collectives
-- can either use torch.dist or ds.ops.comm?
Note: the old 1-bit compressed allreduce variants that resided in deepspeed.runtime.comm will be moved here as well.
deepspeed.comm API
-- must be kept fully compatible (same signatures) as torch.dist API to ensure backward/cross-framework compatibility.
-- e.g. if a client code used
from deepspeed import comm as dist
instead of
import torch.distributed as dist
The code should work without breaking any of the public torch.distributed functionality
Future:
-- deepspeed groups API should be brought into ds.comm
"""
from enum import Enum
import torch
import os
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, default_pg_timeout
from .constants import *
from deepspeed.accelerator import get_accelerator
class ReduceOp(Enum):
SUM = 0
PRODUCT = 1
MIN = 2
MAX = 3
BAND = 4
BOR = 5
BXOR = 6
AVG = 7
UNUSED = 8
from deepspeed.utils.comms_logging import CommsLogger
from deepspeed.utils import timer, get_caller_func
from deepspeed.comm.torch import TorchBackend
from deepspeed import utils
from datetime import timedelta
# Current deepspeed.comm backend (cdb) global object for simple access by client code
use_ds_backend = False
cdb = None
# Create global timer for ops
timers = timer.SynchronizedWallClockTimer()
timer_summary = {}
comms_logger = CommsLogger()
# Maintain objects of all initialized ds backends and assign them using the API functions in this file
nccl_backend = None
mpi_backend = None
# This should be set here so all rank/size information from the launcher can be propagated
from deepspeed.comm.utils import *
def _configure_using_config_file(config):
if config.comms_logger_enabled:
comms_logger.configure(config)
def configure(
deepspeed_config=None,
enabled=None,
prof_all=None,
prof_ops=None,
verbose=None,
debug=None,
):
if deepspeed_config is not None:
_configure_using_config_file(deepspeed_config.comms_config)
if enabled is not None:
comms_logger.enabled = enabled
if prof_all is not None:
comms_logger.prof_all = prof_all
if prof_ops is not None:
comms_logger.prof_ops = prof_ops
if verbose is not None:
comms_logger.verbose = verbose
if debug is not None:
comms_logger.debug = debug
# Logging wrapper for timing ops
def timed_op(func):
def log_wrapper(*args, **kwargs):
# Add enabled flag so that overhead to each comm op is two if conditions at most
if comms_logger.enabled:
if ('prof' in kwargs and kwargs['prof']) or comms_logger.prof_all or (
'log_name' in kwargs
and kwargs['log_name'] in comms_logger.prof_ops):
# Need func args for their defaults
func_args = get_default_args(func)
func_args.update(kwargs)
msg_size = get_msg_size_from_args(func, *args, **kwargs)
log_name = get_debug_log_name(func_args, comms_logger.debug)
timers(log_name).start()
# Return the op, then stop the op's timer
try:
return func(*args, **kwargs)
finally:
if comms_logger.enabled:
# Need to make op blocking for accurate logging
get_accelerator().synchronize()
# If we're using MPI, we can't simply sync the stream
if cdb.using_mpi:
cdb.barrier()
if ('prof' in kwargs and kwargs['prof']) or comms_logger.prof_all or (
'log_name' in kwargs
and kwargs['log_name'] in comms_logger.prof_ops):
log_name = get_debug_log_name(func_args, comms_logger.debug)
raw_name = func.__name__
timers(log_name).stop()
# need temp var since 'elapsed' resets events
time_elapsed = timers(log_name).elapsed(reset=False)
comms_logger.append(raw_name, log_name, time_elapsed, msg_size)
return log_wrapper
# For compatibility with torch distributed's init_process_group, we shall retain the signature from PyTorch code.
# DeepSpeed NCCL/MPI backend may not need all these params as we will have our own implementation.
# Please read full torch.distributed API docs from https://pytorch.org/docs/stable/distributed.html
# UNUSED: Future helper function to initialize DS backends
def init_deepspeed_backend(ds_backend):
global cdb
global nccl_backend
global mpi_backend
global use_ds_backend
if ds_backend == NCCL_BACKEND:
utils.logger.warn("NCCL backend in DeepSpeed not yet implemented")
elif ds_backend == MPI_BACKEND:
utils.logger.warn("MPI backend in DeepSpeed not yet implemented")
elif ds_backend == GLOO_BACKEND:
utils.logger.warn("Gloo backend in DeepSpeed not yet implemented")
else:
utils.logger.warn(f"DeepSpeed does not support {ds_backend} backend")
def is_initialized():
#assert cdb is not None, 'DeepSpeed backend not set, please initialize it using init_process_group()'
if cdb is None:
return False
else:
return cdb.is_initialized()
def destroy_process_group(group=None):
global cdb
return cdb.destroy_process_group(group=group)
def new_group(ranks):
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.new_group(ranks)
def is_available() -> bool:
# Returns ``True`` if the deepspeed comm package is available.
# TODO: load other ops. Clients including deepspeed itself should use deepspeed.comm to import
# any communication related primitives from this package.
# use hasattr(deepspeed.csrc.ops, "_comm") or something
return True
def set_backend(backend):
if not use_ds_backend:
utils.logger.error(
"DeepSpeed communication backend is required. Please use deepspeed.comm.init_distributed(backend, use_deepspeed=True) to use this functionality"
)
raise RuntimeError(
'Error: Custom DeepSpeed backend called without initializing DeepSpeed distributed.'
)
global cdb
global nccl_backend
global mpi_backend
try:
if backend_name == NCCL_BACKEND:
if nccl_backend is not None and nccl_backend.is_initialized():
cdb = nccl_backend
elif backend_name == MPI_BACKEND:
if mpi_backend is not None and mpi_backend.is_initialized():
cdb = mpi_backend
except Exception as inst:
print(inst)
@timed_op
def broadcast(tensor,
src,
group=None,
async_op=False,
prof=False,
log_name='broadcast',
debug=get_caller_func()):
global cdb
return cdb.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)
@timed_op
def all_gather(tensor_list,
tensor,
group=None,
async_op=False,
prof=False,
log_name='all_gather',
debug=get_caller_func()):
global cdb
return cdb.all_gather(tensor_list=tensor_list,
tensor=tensor,
group=group,
async_op=async_op)
def has_reduce_scatter_base():
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
assert cdb.has_reduce_scatter_base is not None, 'has_reduce_scatter_base is not yet defined'
return cdb.has_reduce_scatter_base
def reduce_scatter_fn(output_tensor,
tensor,
op=ReduceOp.SUM,
group=None,
async_op=False,
prof=False,
debug=get_caller_func()):
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
if cdb.has_reduce_scatter_base:
return reduce_scatter_base(output_tensor,
tensor,
op=op,
group=group,
async_op=async_op,
prof=prof,
debug=debug)
else:
utils.logger.warning_once(
"unable to find torch.distributed._reduce_scatter_base. will fall back to "
"torch.distributed.all_gather which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")
input_tensor_lst = list(torch.chunk(tensor, cdb.get_world_size(group)))
return reduce_scatter(output_tensor,
input_tensor_lst,
op=op,
group=group,
async_op=async_op,
prof=prof,
debug=debug)
@timed_op
def reduce_scatter_base(output_tensor,
tensor,
op=ReduceOp.SUM,
group=None,
async_op=False,
prof=False,
log_name='reduce_scatter_base',
debug=get_caller_func()):
global cdb
return cdb.reduce_scatter_base(output_tensor=output_tensor,
input_tensor=tensor,
op=op,
group=group,
async_op=async_op)
@timed_op
def all_gather_base(output_tensor,
tensor,
group=None,
async_op=False,
prof=False,
log_name='all_gather_base',
debug=get_caller_func()):
global cdb
return cdb.all_gather_base(output_tensor=output_tensor,
input_tensor=tensor,
group=group,
async_op=async_op)
def has_allgather_base():
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
assert cdb.has_allgather_base is not None, 'has_allgather_base is not yet defined'
return cdb.has_allgather_base
def allgather_fn(output_tensor,
input_tensor,
group=None,
async_op=False,
debug=get_caller_func()):
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
if cdb.has_allgather_base:
return all_gather_base(output_tensor,
input_tensor,
group=group,
async_op=async_op,
debug=debug)
else:
if get_rank() == 0:
utils.logger.warning_once(
"unable to find torch.distributed._all_gather_base. will fall back to "
"torch.distributed.all_gather which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")
output_tensors = list(torch.chunk(output_tensor, cdb.get_world_size(group)))
return all_gather(output_tensors,
input_tensor,
group=group,
async_op=async_op,
debug=debug)
@timed_op
def all_to_all_single(output,
tensor,
output_split_sizes=None,
input_split_sizes=None,
group=None,
async_op=False,
prof=False,
log_name='all_to_all_single',
debug=get_caller_func()):
global cdb
return cdb.all_to_all_single(output=output,
input=tensor,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
async_op=async_op)
@timed_op
def send(tensor,
dst,
group=None,
tag=0,
prof=False,
log_name='send',
debug=get_caller_func()):
global cdb
return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag)
@timed_op
def recv(tensor,
src=None,
group=None,
tag=0,
prof=False,
log_name='recv',
debug=get_caller_func()):
global cdb
return cdb.recv(tensor=tensor, src=src, group=group, tag=tag)
@timed_op
def isend(tensor,
dst,
group=None,
tag=0,
prof=False,
log_name='isend',
debug=get_caller_func()):
global cdb
return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag)
@timed_op
def irecv(tensor,
src=None,
group=None,
tag=0,
prof=False,
log_name='irecv',
debug=get_caller_func()):
global cdb
return cdb.recv(tensor=tensor, src=src, group=group, tag=tag)
@timed_op
def gather(tensor,
gather_list=None,
dst=0,
group=None,
async_op=False,
prof=False,
log_name='gather',
debug=get_caller_func()):
global cdb
return cdb.gather(tensor=tensor,
gather_list=gather_list,
dst=dst,
group=group,
async_op=async_op)
@timed_op
def scatter(tensor,
scatter_list=None,
src=0,
group=None,
async_op=False,
prof=False,
log_name='scatter',
debug=get_caller_func()):
global cdb
return cdb.scatter(tensor=tensor,
scatter_list=scatter_list,
src=src,
group=group,
async_op=async_op)
@timed_op
def barrier(group=None,
async_op=False,
device_ids=None,
prof=False,
log_name='barrier',
debug=get_caller_func()):
global cdb
return cdb.barrier(group=group, async_op=async_op, device_ids=device_ids)
@timed_op
def monitored_barrier(group=None,
timeout=None,
wait_all_ranks=False,
prof=False,
log_name='monitored_barrier',
debug=get_caller_func()):
global cdb
return cdb.barrier(group=group, timeout=timeout, wait_all_ranks=wait_all_ranks)
def log_summary():
global cdb
barrier(log_name='log_summary_barrier')
if cdb.get_rank() == 0:
comms_logger.log_all()
barrier(log_name='log_summary_barrier')
@timed_op
def reduce(tensor,
dst,
op=ReduceOp.SUM,
group=None,
async_op=False,
prof=False,
log_name='reduce',
debug=get_caller_func()):
global cdb
return cdb.reduce(tensor=tensor, dst=dst, op=op, group=group, async_op=async_op)
@timed_op
def reduce_scatter(output,
input_list,
op=ReduceOp.SUM,
group=None,
async_op=False,
prof=False,
log_name='reduce_scatter',
debug=get_caller_func()):
global cdb
return cdb.reduce_scatter(output=output,
input_list=input_list,
op=op,
group=group,
async_op=async_op)
@timed_op
def all_reduce(tensor,
op=ReduceOp.SUM,
group=None,
async_op=False,
prof=False,
log_name='all_reduce',
debug=get_caller_func()):
#if profile_comm:
# context of the timers?
# timers.start()
# TensorBoard logging for comm calls.?
global cdb
#print(f'op = {op}, cdb= {cdb.name}')
return cdb.all_reduce(tensor, op, group, async_op)
def get_world_group():
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.get_world_group()
def get_world_size(group=None) -> int:
"""
Returns the number of processes in the current process group
Args:
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
Returns:
The world size of the process group
-1, if not part of the group
"""
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.get_world_size(group)
def get_rank(group=None):
"""
Returns the rank of the current process in the provided ``group`` or the
default group if none was provided.
Rank is a unique identifier assigned to each process within a distributed
process group. They are always consecutive integers ranging from 0 to
``world_size``.
Args:
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
Returns:
The rank of the process group
-1, if not part of the group
"""
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.get_rank(group)
def get_local_rank():
"""
Helper function to get local rank after a backend has been set and initialized
Args:
None
Returns:
local rank (= GPU device ID)
"""
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return get_local_rank_from_launcher()
def get_global_rank(group=None, group_rank=0):
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.get_global_rank(group, group_rank)
# Main DeepSpeed Comms. public API.
[docs]def init_distributed(dist_backend=None,
auto_mpi_discovery=True,
distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT,
verbose=True,
timeout=default_pg_timeout,
init_method=None,
dist_init_required=None,
config=None,
rank=-1,
world_size=-1):
''' Initialize dist backend, potentially performing MPI discovery if needed
Arguments:
dist_backend: Optional (str). torch distributed backend, e.g., nccl, mpi, gloo
auto_mpi_discovery Optional (bool). if distributed environment variables are not set, attempt to discover them from MPI
distributed_port: Optional (int). torch distributed backend port
verbose: Optional (bool). verbose logging
timeout: Optional (timedelta). Timeout for operations executed against the process group. Default value equals 30 minutes.
init_method: Optional (string). Torch distributed, URL specifying how to initialize the process group. Default is “env://” if no init_method or store is specified.
config: Optional (dict). DeepSpeed configuration for setting up comms options (e.g. Comms profiling)
rank: Optional (int). The current manually specified rank. Some init_method like “tcp://” need the rank and world_size as well (see: https://pytorch.org/docs/stable/distributed.html#tcp-initialization)
world_size: Optional (int). Desired world_size for the TCP or Shared file-system initialization.
'''
global cdb
configure(deepspeed_config=config)
if dist_init_required is None:
dist_init_required = cdb is None or not cdb.is_initialized()
if cdb is None and torch.distributed.is_initialized():
# The user initialized torch.dist themselves, create cdb and short-circuit
cdb = TorchBackend(dist_backend, timeout, init_method)
return
if dist_init_required is False:
assert (
cdb is not None and cdb.is_initialized() is True
), "Distributed backend is not initialized. Please set dist_init_required to True or initialize before calling deepspeed.initialize()"
else:
# Initialize torch distributed if needed
required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)):
if verbose:
utils.logger.info(
"Not using the DeepSpeed or dist launchers, attempting to detect MPI environment..."
)
if in_aml() and not in_dlts():
patch_aml_env_for_torch_nccl_backend(verbose=verbose)
elif in_aws_sm():
patch_aws_sm_env_for_torch_nccl_backend(verbose=verbose)
else:
mpi_discovery(distributed_port=distributed_port, verbose=verbose)
if cdb is not None and cdb.is_initialized():
if int(os.getenv('RANK', '0')) == 0:
utils.logger.info('Distributed backend already initialized')
else:
assert isinstance(timeout, timedelta)
if dist_backend == None:
dist_backend = get_accelerator().communication_backend_name()
if int(os.getenv('RANK', '0')) == 0:
utils.logger.info(
'Initializing TorchBackend in DeepSpeed with backend {}'.format(
dist_backend))
# Create a torch backend object, initialize torch distributed, and assign to cdb
cdb = TorchBackend(dist_backend, timeout, init_method, rank, world_size)
def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True):
'''
Discovery MPI environment via mpi4py and map to relevant dist state
'''
from mpi4py import MPI
import subprocess
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
world_size = comm.Get_size()
master_addr = None
if rank == 0:
hostname_cmd = ["hostname -I"]
result = subprocess.check_output(hostname_cmd, shell=True)
master_addr = result.decode('utf-8').split()[0]
master_addr = comm.bcast(master_addr, root=0)
# Determine local rank by assuming hostnames are unique
proc_name = MPI.Get_processor_name()
all_procs = comm.allgather(proc_name)
local_rank = sum([i == proc_name for i in all_procs[:rank]])
os.environ['RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['LOCAL_RANK'] = str(local_rank)
os.environ['MASTER_ADDR'] = master_addr
os.environ['MASTER_PORT'] = str(distributed_port)
if verbose:
utils.logger.info(
"Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.format(os.environ['RANK'],
os.environ['LOCAL_RANK'],
os.environ['WORLD_SIZE'],
os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))
if cdb is not None and cdb.is_initialized():
assert cdb.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(
rank, cdb.get_rank())
assert cdb.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
world_size, cdb.get_world_size())
def in_aml():
# Are we running inside an Azure Machine Learning (AML) environment?
return 'AZUREML_EXPERIMENT_ID' in os.environ
def in_aws_sm():
# Are we running inside an AWS SageMaker environment?
return 'SM_TRAINING_ENV' in os.environ
def in_dlts():
# Are we running on a DLTS cluster?
return 'DLTS_JOB_ID' in os.environ
def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True):
"""Helper routine to get and set environment variables.
This is adapted from Azure ML's documentation available from:
https://azure.github.io/azureml-web/docs/cheatsheet/distributed-training/#environment-variables-from-openmpi
"""
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
single_node = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"]) == int(
os.environ["WORLD_SIZE"])
if not single_node:
master_node_params = os.environ["AZ_BATCH_MASTER_NODE"].split(":")
os.environ["MASTER_ADDR"] = master_node_params[0]
# Do not overwrite master port with that defined in AZ_BATCH_MASTER_NODE
if "MASTER_PORT" not in os.environ:
os.environ["MASTER_PORT"] = str(master_port)
else:
os.environ["MASTER_ADDR"] = os.environ["AZ_BATCHAI_MPI_MASTER_NODE"]
os.environ["MASTER_PORT"] = DEFAULT_AML_MASTER_PORT
if verbose:
utils.logger.info("NCCL_SOCKET_IFNAME original value = {}".format(
os.environ["NCCL_SOCKET_IFNAME"]))
os.environ["NCCL_SOCKET_IFNAME"] = DEFAULT_AML_NCCL_SOCKET_IFNAME
os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
if verbose:
utils.logger.info(
"Discovered AzureML settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.format(os.environ['RANK'],
os.environ['LOCAL_RANK'],
os.environ['WORLD_SIZE'],
os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))
def patch_aws_sm_env_for_torch_nccl_backend(verbose=True):
"""Helper routine to get and set environment variables when running inside an AWS SageMaker environment.
"""
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
if verbose:
utils.logger.info(
"Discovered AWS SageMaker settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.format(os.environ['RANK'],
os.environ['LOCAL_RANK'],
os.environ['WORLD_SIZE'],
os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))