The activation checkpointing APIs in DeepSpeed can be used to enable a range of memory optimizations related to activation checkpointing. These include partitioning activations across GPUs when using model parallelism, CPU checkpointing, contiguous memory optimizations, and more.
Please see the DeepSpeed JSON config documentation for the full set.
Here we present the activation checkpointing API. Please see the tutorial on enabling DeepSpeed for Megatron-LM for example usage.
Configuring Activation Checkpointing
- deepspeed.checkpointing.configure(mpu_, deepspeed_config=None, partition_activations=None, contiguous_checkpointing=None, num_checkpoints=None, checkpoint_in_cpu=None, synchronize=None, profile=None)
Configure DeepSpeed Activation Checkpointing.
mpu_ – Optional: An object that implements the methods get_model_parallel_rank/group/world_size and get_data_parallel_rank/group/world_size.
deepspeed_config – Optional: A DeepSpeed config JSON file; when provided, it is used to configure DeepSpeed activation checkpointing.
partition_activations – Optional: When enabled, partitions activation checkpoints across model-parallel GPUs. Default: False. If provided, overrides the value in deepspeed_config.
contiguous_checkpointing – Optional: Copies activation checkpoints to a contiguous memory buffer. Works only with homogeneous checkpoints and only when partition_activations is enabled; num_checkpoints must also be provided. Default: False. If provided, overrides the value in deepspeed_config.
num_checkpoints – Optional: Number of activation checkpoints stored during the forward pass of the model. Used to calculate the buffer size for contiguous_checkpointing. If provided, overrides the value in deepspeed_config.
checkpoint_in_cpu – Optional: Moves the activation checkpoints to CPU. Works only when partition_activations is enabled. Default: False. If provided, overrides the value in deepspeed_config.
synchronize – Optional: Calls torch.cuda.synchronize() at the beginning and end of each call to deepspeed.checkpointing.checkpoint, in both the forward and backward passes. Default: False. If provided, overrides the value in deepspeed_config.
profile – Optional: Logs the forward and backward time for each deepspeed.checkpointing.checkpoint invocation. If provided, overrides the value in deepspeed_config.
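The precedence rule above (an explicitly passed argument overrides deepspeed_config) and the stated constraints can be illustrated with a small pure-Python sketch. The helper resolve_checkpointing_config and the DEFAULTS table are hypothetical, not part of DeepSpeed, and the JSON key names are assumed to match the activation_checkpointing section of the DeepSpeed config; check them against the config documentation before relying on them.

```python
# Hypothetical sketch: how keyword arguments could override the
# "activation_checkpointing" section of a DeepSpeed config dict.
# The JSON key names below are assumptions based on the DeepSpeed config docs.
KWARG_TO_JSON_KEY = {
    "partition_activations": "partition_activations",
    "contiguous_checkpointing": "contiguous_memory_optimization",
    "num_checkpoints": "number_checkpoints",
    "checkpoint_in_cpu": "cpu_checkpointing",
    "synchronize": "synchronize_checkpoint_boundary",
    "profile": "profile",
}

# Defaults used when neither a keyword argument nor the config sets a value.
DEFAULTS = {
    "partition_activations": False,
    "contiguous_checkpointing": False,
    "num_checkpoints": None,
    "checkpoint_in_cpu": False,
    "synchronize": False,
    "profile": False,
}


def resolve_checkpointing_config(ds_config=None, **overrides):
    """Merge config values with keyword overrides.

    Precedence: keyword argument that is not None > config value > default.
    """
    unknown = set(overrides) - set(KWARG_TO_JSON_KEY)
    if unknown:
        raise TypeError(f"unknown options: {sorted(unknown)}")

    section = (ds_config or {}).get("activation_checkpointing", {})
    resolved = {}
    for kwarg, json_key in KWARG_TO_JSON_KEY.items():
        value = overrides.get(kwarg)
        if value is None:  # None means "not provided", so fall back to the config
            value = section.get(json_key, DEFAULTS[kwarg])
        resolved[kwarg] = value

    # Constraints taken from the parameter descriptions.
    if resolved["contiguous_checkpointing"]:
        if not resolved["partition_activations"]:
            raise ValueError("contiguous_checkpointing requires partition_activations")
        if resolved["num_checkpoints"] is None:
            raise ValueError("contiguous_checkpointing requires num_checkpoints")
    if resolved["checkpoint_in_cpu"] and not resolved["partition_activations"]:
        raise ValueError("checkpoint_in_cpu requires partition_activations")
    return resolved


# Usage: the config enables profiling, but the explicit argument turns it off.
ds_config = {"activation_checkpointing": {"partition_activations": True, "profile": True}}
settings = resolve_checkpointing_config(ds_config, profile=False)
```

Treating None as "not provided" mirrors the `=None` defaults in the signature, so passing False explicitly still wins over a True in the config.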
- deepspeed.checkpointing.is_configured()
Returns True if DeepSpeed activation checkpointing has been configured by calling deepspeed.checkpointing.configure, else False.
Using Activation Checkpointing
- deepspeed.checkpointing.checkpoint(function, *args)
Checkpoint a model or part of the model. This has been directly copied from torch.utils.checkpoint.
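Checkpointing trades compute for memory: the wrapped function's intermediate activations are discarded during the forward pass, only its inputs are kept, and the forward computation is re-run during the backward pass. The toy below shows that idea on scalars in plain Python, without PyTorch or DeepSpeed; Layer, checkpointed_forward, and checkpointed_backward are hypothetical names. In a real model one would instead wrap a call, e.g. `deepspeed.checkpointing.checkpoint(layer_fn, hidden_states)` in place of `layer_fn(hidden_states)`, following the signature above (layer_fn is a hypothetical name).

```python
import math


class Layer:
    """Hypothetical toy scalar layer: y = tanh(w * x)."""

    def __init__(self, w):
        self.w = w
        self.forward_calls = 0  # counts how often forward() runs

    def forward(self, x):
        self.forward_calls += 1
        return math.tanh(self.w * x)

    def backward(self, y, grad_out):
        # dy/dx = w * (1 - tanh(w*x)**2) = w * (1 - y**2)
        return grad_out * self.w * (1.0 - y * y)


def checkpointed_forward(layers, x):
    """Run the segment but keep only its input, not the intermediate activations."""
    saved_input = x
    for layer in layers:
        x = layer.forward(x)
    return x, saved_input


def checkpointed_backward(layers, saved_input, grad_out):
    """Recompute the segment's activations from the saved input, then backpropagate."""
    outputs = []
    x = saved_input
    for layer in layers:  # recomputation: a second forward pass
        x = layer.forward(x)
        outputs.append(x)
    grad = grad_out
    for layer, y in zip(reversed(layers), reversed(outputs)):
        grad = layer.backward(y, grad)
    return grad


layers = [Layer(0.5), Layer(1.5), Layer(-0.8)]
out, saved = checkpointed_forward(layers, 0.3)
grad = checkpointed_backward(layers, saved, grad_out=1.0)
```

Each layer's forward runs twice (once in the forward pass, once during recomputation), which is the extra compute paid for not storing the intermediate outputs.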
- deepspeed.checkpointing.reset()
Resets the memory buffers used by the contiguous memory optimizations. It should be called during evaluation, when multiple forward passes are computed without the backward pass that would normally clear these buffers. Takes no arguments.
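Why the reset is needed can be seen with a toy buffer that, like the contiguous checkpoint buffer, has room for a fixed number of checkpoints (num_checkpoints) and is normally emptied by the backward pass. ContiguousCheckpointBuffer and forward are hypothetical stand-ins written for illustration, not DeepSpeed internals.

```python
class ContiguousCheckpointBuffer:
    """Hypothetical fixed-capacity buffer with one slot per expected checkpoint.

    In training, the backward pass would consume the checkpoints and free the
    slots; in evaluation there is no backward pass, so reset() must do it.
    """

    def __init__(self, num_checkpoints):
        self.capacity = num_checkpoints
        self.slots = []

    def store(self, activation):
        if len(self.slots) >= self.capacity:
            raise RuntimeError("buffer full: no backward pass or reset() since last forward")
        self.slots.append(activation)

    def reset(self):
        self.slots.clear()


def forward(buffer, num_layers, x):
    """Toy forward pass that stores one checkpoint per layer."""
    for _ in range(num_layers):
        buffer.store(x)
        x = x + 1.0
    return x


# Evaluation loop: several forward passes with no backward pass in between.
buffer = ContiguousCheckpointBuffer(num_checkpoints=4)
eval_outputs = []
for batch in [0.0, 10.0, 20.0]:
    eval_outputs.append(forward(buffer, num_layers=4, x=batch))
    buffer.reset()  # without this, the second forward would overflow the buffer
```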
Configuring and Checkpointing Random Seeds
- deepspeed.checkpointing.get_cuda_rng_tracker()
Get the CUDA RNG states tracker.
- deepspeed.checkpointing.model_parallel_cuda_manual_seed(seed)
Initialize the model-parallel CUDA seed.
This function should be called after model parallelism is initialized, and torch.cuda.manual_seed should not be called after it; this function is a replacement for torch.cuda.manual_seed. Two sets of RNG states are tracked:
- default state: used for data parallelism. It is the same across the GPUs of a model-parallel group but differs between model-parallel groups. It is used, for example, for dropout in the non-model-parallel regions.
- model-parallel state: differs across the GPUs of a model-parallel group but is the same across data-parallel groups. It is used, for example, for dropout in model-parallel regions.
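One seeding scheme that satisfies both properties described above is sketched below. derive_seeds and the offset constant are hypothetical; the sketch only computes seed values and touches no GPU state, and it assumes the default seed is made to vary with the data-parallel rank.

```python
# Hypothetical seed scheme. The offset is an arbitrary assumption that only needs
# to keep the two seed families apart for the world sizes in use.
MODEL_PARALLEL_SEED_OFFSET = 1000


def derive_seeds(base_seed, mp_rank, dp_rank):
    """Return (default_seed, model_parallel_seed) for one GPU.

    default_seed varies with dp_rank only: identical on every GPU of a
    model-parallel group, different between groups.
    model_parallel_seed varies with mp_rank only: different inside a group,
    identical across data-parallel replicas.
    """
    default_seed = base_seed + dp_rank
    model_parallel_seed = base_seed + MODEL_PARALLEL_SEED_OFFSET + mp_rank
    return default_seed, model_parallel_seed


# Example: 2-way model parallelism x 3-way data parallelism.
MP_SIZE, DP_SIZE = 2, 3
seeds = {
    (mp, dp): derive_seeds(1234, mp, dp)
    for mp in range(MP_SIZE)
    for dp in range(DP_SIZE)
}
```

On each GPU, default_seed would seed the ordinary RNG stream, while model_parallel_seed would be registered as a separate named state, for example in a tracker such as CudaRNGStatesTracker.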
- class deepspeed.checkpointing.CudaRNGStatesTracker
Tracker for the CUDA RNG states.
Using the add method, a CUDA RNG state is initialized from the input seed and assigned to name. Later, by forking that RNG state, we can perform operations with it and then return to our starting CUDA state.
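The add/fork pattern can be sketched on the CPU by letting Python's global random module stand in for the CUDA RNG state. RNGStatesTracker below is a hypothetical, simplified analogue written for illustration; it is not the DeepSpeed class and does not touch CUDA.

```python
import contextlib
import random


class RNGStatesTracker:
    """Hypothetical CPU analogue: named RNG states using Python's global `random`."""

    def __init__(self):
        self.states = {}

    def add(self, name, seed):
        if name in self.states:
            raise ValueError(f"RNG state {name!r} already exists")
        saved = random.getstate()
        random.seed(seed)
        self.states[name] = random.getstate()
        random.setstate(saved)  # leave the caller's RNG stream untouched

    @contextlib.contextmanager
    def fork(self, name):
        saved = random.getstate()
        random.setstate(self.states[name])
        try:
            yield
        finally:
            # Remember where the named stream stopped, then restore the original stream.
            self.states[name] = random.getstate()
            random.setstate(saved)


tracker = RNGStatesTracker()
tracker.add("model-parallel-rng", seed=1234)

random.seed(0)
with tracker.fork("model-parallel-rng"):
    mp_sample = random.random()   # drawn from the named stream
default_sample = random.random()  # default stream continues as if fork never happened
```

Saving the named state on exit means the next fork continues that stream rather than replaying it, while the default stream is unaffected by anything drawn inside the fork.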
- class deepspeed.checkpointing.CheckpointFunction(*args, **kwargs)
This autograd function is adapted from torch.utils.checkpoint with the following changes:
- torch.cuda.set_rng_state is replaced with _set_cuda_rng_state.
- The states in the model-parallel tracker are also properly tracked, set, and reset.
- Performance optimizations: activation partitioning and contiguous memory optimization.
- Forward and backward functions are profiled.
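The RNG-related changes matter because recomputation replays random operations such as dropout: unless the RNG state is rewound to its value at the original forward pass, the replayed forward draws a different mask and the gradients no longer correspond to the forward that produced the loss. The sketch below uses random.Random in place of the CUDA and model-parallel tracker states; dropout, checkpointed_dropout_forward, and recompute_in_backward are hypothetical names.

```python
import random


def dropout(values, p, rng):
    """Toy dropout: zero each value with probability p, scale survivors by 1/(1-p)."""
    return [0.0 if rng.random() < p else v / (1.0 - p) for v in values]


def checkpointed_dropout_forward(values, p, rng):
    # Save the RNG state *before* the forward pass so recomputation can replay it.
    saved_state = rng.getstate()
    out = dropout(values, p, rng)
    return out, saved_state


def recompute_in_backward(values, p, rng, saved_state):
    # Replay the forward with the saved RNG state, then restore the current state
    # so later random draws are unaffected by the recomputation.
    current_state = rng.getstate()
    rng.setstate(saved_state)
    out = dropout(values, p, rng)
    rng.setstate(current_state)
    return out


rng = random.Random(42)
x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
out, state = checkpointed_dropout_forward(x, 0.5, rng)
rng.random()  # unrelated draws between forward and backward advance the RNG
recomputed = recompute_in_backward(x, 0.5, rng, state)
```

Restoring current_state after the replay is the counterpart of resetting the tracker states: the recomputation reproduces the original mask without perturbing the random stream seen by the rest of training.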