The activation checkpointing APIs in DeepSpeed can be used to enable a range of memory optimizations related to activation checkpointing. These include activation partitioning across GPUs when using model parallelism, CPU checkpointing, contiguous memory optimizations, and more.
Please see the DeepSpeed JSON config documentation for the full set of options.
Here we present the activation checkpointing API. Please see the tutorial on enabling DeepSpeed for Megatron-LM for example usage.
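For reference, the sketch below writes the activation_checkpointing section of a DeepSpeed JSON config using Python's standard json module. The key names follow the DeepSpeed config documentation; the comments pair each key with the configure argument that appears to correspond to it, the values are placeholders rather than recommendations, and the filename ds_config.json is hypothetical. A real config file usually contains other sections as well, such as batch size settings.

```python
import json

# The "activation_checkpointing" section of a DeepSpeed config.
# Values are placeholders; comments name the configure() argument
# that overrides each key when passed explicitly.
ds_config = {
    "activation_checkpointing": {
        "partition_activations": True,             # partition_activations
        "cpu_checkpointing": False,                # checkpoint_in_cpu
        "contiguous_memory_optimization": False,   # contiguous_checkpointing
        "number_checkpoints": None,                # num_checkpoints (JSON null)
        "synchronize_checkpoint_boundary": False,  # synchronize
        "profile": False,                          # profile
    }
}

# Write the config to disk (hypothetical filename).
with open("ds_config.json", "w") as f:
    json.dump(ds_config, f, indent=2)
```

The resulting file path could then be supplied as deepspeed_config when configuring activation checkpointing.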
Configuring Activation Checkpointing
configure(mpu_, deepspeed_config=None, partition_activations=None, contiguous_checkpointing=None, num_checkpoints=None, checkpoint_in_cpu=None, synchronize=None, profile=None)
Configure DeepSpeed Activation Checkpointing.
- mpu_ – Optional: An object that implements the following methods: get_model_parallel_rank/group/world_size and get_data_parallel_rank/group/world_size.
- deepspeed_config – Optional: A DeepSpeed JSON config file; when provided, it is used to configure DeepSpeed activation checkpointing.
- partition_activations – Optional: When enabled, partitions activation checkpoints across model-parallel GPUs. Default: False. If provided, overrides the value in deepspeed_config.
- contiguous_checkpointing – Optional: Copies activation checkpoints to a contiguous memory buffer. Works only with homogeneous checkpoints when partition_activations is enabled, and requires num_checkpoints. Default: False. If provided, overrides the value in deepspeed_config.
- num_checkpoints – Optional: Number of activation checkpoints stored during the forward pass of the model. Used to calculate the buffer size for contiguous_checkpointing. If provided, overrides the value in deepspeed_config.
- checkpoint_in_cpu – Optional: Moves activation checkpoints to the CPU. Works only with partition_activations. Default: False. If provided, overrides the value in deepspeed_config.
- synchronize – Optional: Calls torch.cuda.synchronize() at the beginning and end of each call to deepspeed.checkpointing.checkpoint, for both the forward and backward pass. Default: False. If provided, overrides the value in deepspeed_config.
- profile – Optional: Logs the forward and backward time of each deepspeed.checkpointing.checkpoint invocation. If provided, overrides the value in deepspeed_config.
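The override rule that repeats through the parameter descriptions (an explicit argument replaces the matching deepspeed_config value) can be sketched in plain Python. resolve_checkpointing_settings, the argument-to-key table, and the False default for profile are assumptions for illustration, not DeepSpeed internals; the config key names follow the DeepSpeed JSON config documentation. The checks at the end only encode the constraints stated in the descriptions and treat them as hard errors, which the real library may not do.

```python
from typing import Any, Dict, Optional

# Hypothetical mapping from configure() argument names to JSON config keys.
_ARG_TO_CONFIG_KEY = {
    "partition_activations": "partition_activations",
    "contiguous_checkpointing": "contiguous_memory_optimization",
    "num_checkpoints": "number_checkpoints",
    "checkpoint_in_cpu": "cpu_checkpointing",
    "synchronize": "synchronize_checkpoint_boundary",
    "profile": "profile",
}

# Defaults as documented above; profile's default is an assumption.
_DEFAULTS: Dict[str, Any] = {
    "partition_activations": False,
    "contiguous_checkpointing": False,
    "num_checkpoints": None,
    "checkpoint_in_cpu": False,
    "synchronize": False,
    "profile": False,
}


def resolve_checkpointing_settings(
    config_section: Optional[Dict[str, Any]] = None, **overrides: Any
) -> Dict[str, Any]:
    """Hypothetical helper: merge defaults, config values, and explicit arguments."""
    settings = dict(_DEFAULTS)
    # 1) Values from the JSON config section, if one was supplied.
    if config_section is not None:
        for arg, key in _ARG_TO_CONFIG_KEY.items():
            if key in config_section:
                settings[arg] = config_section[key]
    # 2) Explicit keyword arguments win whenever they are not None.
    for arg, value in overrides.items():
        if arg not in settings:
            raise TypeError(f"unknown argument: {arg}")
        if value is not None:
            settings[arg] = value
    # 3) Constraints stated in the parameter descriptions.
    if settings["contiguous_checkpointing"]:
        if not settings["partition_activations"]:
            raise ValueError("contiguous_checkpointing requires partition_activations")
        if settings["num_checkpoints"] is None:
            raise ValueError("contiguous_checkpointing requires num_checkpoints")
    if settings["checkpoint_in_cpu"] and not settings["partition_activations"]:
        raise ValueError("checkpoint_in_cpu requires partition_activations")
    return settings


settings = resolve_checkpointing_settings(
    {"partition_activations": False, "cpu_checkpointing": False},
    partition_activations=True,  # explicit argument overrides the config value
)
```

Passing None for an argument leaves the config value in place, which matches the None defaults in the configure signature.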
Checks whether DeepSpeed activation checkpointing has been configured by calling deepspeed.checkpointing.configure.
Parameters: None. Returns: True if configured, else False.
Using Activation Checkpointing
Checkpoint a model or part of the model. This has been directly copied from torch.utils.checkpoint.
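As a conceptual aid only, the pure-Python sketch below shows the idea behind checkpointing a segment: the forward pass keeps just the segment's input, and the backward pass recomputes the intermediate activations from it before applying the chain rule. It works on scalars with hand-written derivatives and is not how DeepSpeed or torch.utils.checkpoint are implemented; checkpointed_forward and checkpointed_backward are hypothetical names.

```python
import math
from typing import Callable, List, Tuple

# A layer is a scalar function paired with its derivative.
Layer = Tuple[Callable[[float], float], Callable[[float], float]]


def checkpointed_forward(segment: List[Layer], x: float) -> Tuple[float, float]:
    """Run the segment, saving only its input (the checkpoint)."""
    saved_input = x
    for f, _ in segment:
        x = f(x)  # intermediate activations are not kept
    return x, saved_input


def checkpointed_backward(segment: List[Layer], saved_input: float, grad_out: float) -> float:
    """Recompute each layer's input from the checkpoint, then apply the chain rule."""
    activations = []
    x = saved_input
    for f, _ in segment:
        activations.append(x)  # recomputed input to this layer
        x = f(x)
    grad = grad_out
    for (_, df), a in zip(reversed(segment), reversed(activations)):
        grad *= df(a)
    return grad


# y = sin(x) ** 2, expressed as two layers.
segment: List[Layer] = [(math.sin, math.cos), (lambda v: v * v, lambda v: 2 * v)]
y, saved = checkpointed_forward(segment, 0.5)
dy_dx = checkpointed_backward(segment, saved, 1.0)
```

The trade-off is the usual one for activation checkpointing: less memory held between the forward and backward passes, at the cost of running the segment's forward computation a second time.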
Resets memory buffers related to contiguous memory optimizations. Call it during evaluation, when multiple forward passes are computed without the backward pass that would normally clear these buffers. Takes no parameters.
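To see why this reset matters, the sketch below is a deliberately simplified, hypothetical model of a contiguous checkpoint buffer: it has num_checkpoints slots, each checkpoint taken in a forward pass claims one, and only a backward pass (or a reset) makes them reusable. ContiguousCheckpointPool and its methods are illustrative names, not DeepSpeed internals, and sizing each slot as one activation partition (the tensor's element count divided by the model-parallel world size) is an assumption.

```python
class ContiguousCheckpointPool:
    """Hypothetical model of a fixed-size contiguous checkpoint buffer."""

    def __init__(self, num_checkpoints: int, slot_numel: int, bytes_per_element: int) -> None:
        self.num_checkpoints = num_checkpoints  # number of slots
        self.slot_numel = slot_numel            # elements per activation partition
        self.bytes_per_element = bytes_per_element
        self.offset = 0                         # index of the next free slot

    @property
    def reserved_bytes(self) -> int:
        return self.num_checkpoints * self.slot_numel * self.bytes_per_element

    def store(self) -> int:
        """Claim the next slot during a forward pass and return its index."""
        if self.offset >= self.num_checkpoints:
            raise RuntimeError("buffer exhausted: reset between forward-only passes")
        slot = self.offset
        self.offset += 1
        return slot

    def reset(self) -> None:
        """Make every slot reusable, as a backward pass normally would."""
        self.offset = 0


# Example: 8 x 1024 x 4096 activations split over 4 model-parallel GPUs,
# fp16 (2 bytes per element), 24 checkpointed layers.
pool = ContiguousCheckpointPool(num_checkpoints=24, slot_numel=8 * 1024 * 4096 // 4, bytes_per_element=2)

for _ in range(3):           # three forward-only evaluation passes
    for _layer in range(24):
        pool.store()
    pool.reset()             # without this, the second pass would raise
```

With these example numbers the pool reserves 24 × 8,388,608 × 2 = 402,653,184 bytes (384 MiB).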
Configuring and Checkpointing Random Seeds
Get the CUDA RNG tracker.
Initialize the model-parallel CUDA seed.
This function should be called after model parallelism is initialized. Do not call torch.cuda.manual_seed after this function; this function is a replacement for it. Two sets of RNG states are tracked:
- default state: This is for data parallelism and is the same among a set of model-parallel GPUs but different across different model-parallel groups. It is used, for example, for dropout in the non-model-parallel regions.
- model-parallel state: This state is different among a set of model-parallel GPUs but the same across data-parallel groups. It is used, for example, for dropout in model-parallel regions.
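The two properties of these states can be checked on a small grid of ranks. The sketch below is a hypothetical seed scheme, not DeepSpeed's implementation: derive_seeds, the offset of 2718, and deriving the default seed from the data-parallel rank are all assumptions, chosen so that the default seed is shared within a model-parallel group and the model-parallel seed is shared across data-parallel groups.

```python
from typing import Dict, Tuple

# Assumed offset; it keeps the two seed families disjoint as long as it is
# at least the data-parallel world size.
SEED_OFFSET = 2718


def derive_seeds(base_seed: int, dp_rank: int, mp_rank: int) -> Tuple[int, int]:
    """Hypothetical scheme returning (default_seed, model_parallel_seed)."""
    # Default state: shared by the GPUs of one model-parallel group (same
    # dp_rank), different across model-parallel groups (different dp_rank).
    default_seed = base_seed + dp_rank
    # Model-parallel state: differs by mp_rank, shared by GPUs holding the
    # same mp_rank in different data-parallel replicas.
    model_parallel_seed = base_seed + SEED_OFFSET + mp_rank
    return default_seed, model_parallel_seed


# 3 data-parallel replicas x 2 model-parallel GPUs each, keyed by (dp, mp).
seeds: Dict[Tuple[int, int], Tuple[int, int]] = {
    (dp, mp): derive_seeds(1234, dp, mp) for dp in range(3) for mp in range(2)
}
```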
Tracker for the CUDA RNG states.
Using the add method, a CUDA RNG state is initialized from the input seed and assigned to name. Later, by forking that RNG state, we can perform operations and then return to the starting CUDA RNG state.
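The add/fork pattern can be illustrated without a GPU. The sketch below is a hypothetical CPU analogue that uses Python's random module in place of the CUDA generator; RNGStatesTracker and its methods mirror the description above but are not DeepSpeed's class. add creates a named state from a seed without disturbing the default state, and fork swaps that named state in temporarily, saving how far it advanced before restoring the default.

```python
import contextlib
import random
from typing import Dict, Iterator


class RNGStatesTracker:
    """Hypothetical CPU analogue of a named RNG state tracker."""

    def __init__(self) -> None:
        self.states: Dict[str, object] = {}

    def add(self, name: str, seed: int) -> None:
        if name in self.states:
            raise ValueError(f"RNG state {name!r} already exists")
        saved = random.getstate()
        random.seed(seed)
        self.states[name] = random.getstate()
        random.setstate(saved)  # leave the default state untouched

    @contextlib.contextmanager
    def fork(self, name: str) -> Iterator[None]:
        saved = random.getstate()
        random.setstate(self.states[name])
        try:
            yield
        finally:
            # Remember how far the named stream advanced, then restore the default.
            self.states[name] = random.getstate()
            random.setstate(saved)


random.seed(0)
tracker = RNGStatesTracker()
tracker.add("model-parallel-rng", seed=42)

before = random.getstate()
with tracker.fork("model-parallel-rng"):
    a = random.random()  # drawn from the named stream, not the default one
```

Because the named state is saved on exit, successive forks continue the same stream instead of repeating its first values.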
This function is adapted from torch.utils.checkpoint with the following changes:
- torch.cuda.set_rng_state is replaced with _set_cuda_rng_state.
- The states in the model-parallel tracker are also properly tracked, set, and reset.
- Performance optimizations: activation partitioning and contiguous memory optimization.
- CPU checkpointing.
- Profiling of the forward and backward functions.
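Both RNG-related changes serve the same purpose as the RNG handling in torch.utils.checkpoint: the recomputation during the backward pass must reproduce the random draws of the forward pass, for example the dropout mask. The sketch below illustrates that save-and-replay idea with random.Random standing in for the CUDA generator; forward_with_checkpoint and recompute_for_backward are hypothetical names, and this is not DeepSpeed's code.

```python
import random
from typing import List, Tuple


def dropout(xs: List[float], p: float, rng: random.Random) -> List[float]:
    """Inverted dropout driven by an explicit RNG."""
    return [0.0 if rng.random() < p else x / (1 - p) for x in xs]


def forward_with_checkpoint(xs: List[float], p: float, rng: random.Random) -> Tuple[List[float], object]:
    # Save the RNG state before the checkpointed segment runs.
    saved_state = rng.getstate()
    return dropout(xs, p, rng), saved_state


def recompute_for_backward(xs: List[float], p: float, rng: random.Random, saved_state: object) -> List[float]:
    # Replay the segment from the saved RNG state so the dropout mask matches,
    # then restore the current state so later random draws are unaffected.
    current = rng.getstate()
    rng.setstate(saved_state)
    out = dropout(xs, p, rng)
    rng.setstate(current)
    return out


rng = random.Random(7)
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
out, state = forward_with_checkpoint(xs, 0.5, rng)
_ = rng.random()  # other work advances the RNG between forward and backward
replayed = recompute_for_backward(xs, 0.5, rng, state)
```

With model parallelism there are two such generators (the default state and the model-parallel state), which is why the tracker's states must be saved and restored alongside the default one.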