Inference Setup

The entry point for inference with DeepSpeed is deepspeed.init_inference().

Example usage:

import deepspeed
engine = deepspeed.init_inference(model=net, config=config)

The DeepSpeedInferenceConfig is used to control all aspects of initializing the InferenceEngine. The config can be passed to init_inference as a dictionary (or as a path to a JSON file), and parameters can also be passed as keyword arguments.

class deepspeed.inference.config.DeepSpeedInferenceConfig

Sets parameters for the DeepSpeed Inference Engine.

replace_with_kernel_inject: bool = False (alias 'kernel_inject')

Set to True to inject inference kernels for models such as BERT, GPT-2, GPT-Neo, and GPT-J. Otherwise, the injection_dict provides the names of two linear layers as a tuple: (attention output projection, transformer output projection).

dtype: DtypeEnum = torch.float16

Desired model data type; the model will be converted to this type. Supported target types: torch.half, torch.int8, torch.float.
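Converting the model changes how much memory its weights occupy. The sketch below does that arithmetic for the three supported target types; the helper name weight_memory_gib and the 1.3-billion-parameter model size are hypothetical, and the estimate covers weights only.

```python
# Approximate bytes per element for the supported target dtypes.
BYTES_PER_ELEMENT = {
    "torch.half": 2,   # float16
    "torch.int8": 1,
    "torch.float": 4,  # float32
}


def weight_memory_gib(num_params: int, dtype: str) -> float:
    """Hypothetical helper: weight storage in GiB for num_params values of dtype.

    Ignores activations, the KV cache, and any quantization scales.
    """
    return num_params * BYTES_PER_ELEMENT[dtype] / 2**30


if __name__ == "__main__":
    num_params = 1_300_000_000  # assumed model size, for illustration only
    for name in BYTES_PER_ELEMENT:
        print(f"{name:12} {weight_memory_gib(num_params, name):6.2f} GiB")
```

Going from torch.float to torch.half halves the weight footprint, and torch.int8 halves it again.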

tensor_parallel: DeepSpeedTPConfig = {} (alias 'tp')

Configuration for tensor parallelism used to split the model across several GPUs. Expects a dictionary containing values for DeepSpeedTPConfig.
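Conceptually, tensor parallelism shards individual weight matrices so that each GPU holds, and multiplies by, only a slice. The pure-Python sketch below splits a weight matrix column-wise across a hypothetical tp_size and checks that concatenating the per-shard outputs reproduces the unsplit product. It illustrates the idea only; it does not show which layers DeepSpeed splits or how the GPUs communicate.

```python
from typing import List

Matrix = List[List[float]]


def split_columns(weight: Matrix, tp_size: int) -> List[Matrix]:
    """Split a (rows x cols) weight matrix column-wise into tp_size equal shards."""
    cols = len(weight[0])
    if cols % tp_size != 0:
        raise ValueError(f"{cols} columns are not divisible by tp_size={tp_size}")
    width = cols // tp_size
    return [[row[r * width:(r + 1) * width] for row in weight] for r in range(tp_size)]


def matvec(x: List[float], weight: Matrix) -> List[float]:
    """Compute x @ weight for a vector x with one entry per weight row."""
    return [sum(x[i] * weight[i][j] for i in range(len(x)))
            for j in range(len(weight[0]))]


if __name__ == "__main__":
    w = [[1.0, 2.0, 3.0, 4.0],
         [5.0, 6.0, 7.0, 8.0]]
    x = [1.0, 1.0]
    shards = split_columns(w, tp_size=2)
    # Each "device" computes its slice of the output; concatenating the
    # slices reproduces the unsplit result.
    partial = [matvec(x, s) for s in shards]
    assert sum(partial, []) == matvec(x, w)
    print(partial)  # [[6.0, 8.0], [10.0, 12.0]]
```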

enable_cuda_graph: bool = False

Set this flag to capture a CUDA graph of the inference ops so that they run faster via graph replay.

use_triton: bool = False

Set this flag to use Triton kernels for inference ops.

triton_autotune: bool = False

Set this flag to enable Triton autotuning. Enabling it improves performance but increases the time of the first run, which is spent autotuning.

zero: DeepSpeedZeroConfig = {}

ZeRO configuration to use with the Inference Engine. Expects a dictionary containing values for DeepSpeedZeroConfig.

triangular_masking: bool = True (alias 'tm')

Controls the type of masking applied to attention scores in the transformer layer. Note that the appropriate masking is application specific.
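A triangular (causal) mask is the usual choice for autoregressive decoders such as GPT-style models, whereas encoders such as BERT typically attend bidirectionally. The sketch below builds a lower-triangular mask, applies it to a matrix of attention scores, and shows that softmax then gives masked positions zero weight. It is a generic illustration; the helper names are hypothetical and are not DeepSpeed internals.

```python
import math
from typing import List


def triangular_mask(seq_len: int) -> List[List[bool]]:
    """True where query position i may attend to key position j (j <= i)."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


def apply_mask(scores: List[List[float]], mask: List[List[bool]]) -> List[List[float]]:
    """Replace disallowed attention scores with -inf."""
    return [[s if keep else -math.inf for s, keep in zip(row, mrow)]
            for row, mrow in zip(scores, mask)]


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax; exp(-inf) is 0.0, so masked entries get zero weight."""
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    for row in triangular_mask(3):
        print(row)
    # [True, False, False]
    # [True, True, False]
    # [True, True, True]
```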

moe: Union[bool, DeepSpeedMoEConfig] = {}

Specifies whether the Transformer is a Mixture-of-Experts (MoE) model. Expects a dictionary containing values for DeepSpeedMoEConfig.

quant: QuantizationConfig = {}

NOTE: only works with the int8 dtype. Quantization settings used for quantizing your model with MoQ. The setting can be a single value or a tuple. A single value is interpreted as the number of groups used in quantization. A tuple indicates extra grouping for the MLP part of a Transformer layer (e.g., (True, 8) means the model is quantized with 8 groups throughout the network, except for the MLP part, which uses 8 extra groups). Expects a dictionary containing values for QuantizationConfig.
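MoQ's exact scheme is not spelled out here, so the sketch below shows generic group-wise symmetric quantization, the idea behind q_groups and q_type='symmetric' in the QuantizationConfig defaults: the weights are split into groups, and each group gets its own scale, so a few large values do not destroy the precision of the small ones. The function names are hypothetical, and this is not DeepSpeed's implementation.

```python
from typing import List, Tuple


def quantize_symmetric(weights: List[float], q_groups: int,
                       num_bits: int = 8) -> Tuple[List[int], List[float]]:
    """Generic group-wise symmetric quantization (illustrative; not MoQ itself).

    The flat weight list is split into q_groups equal, contiguous groups. Each
    group gets one scale = max|w| / qmax, and each value is rounded to an
    integer in [-qmax, qmax].
    """
    if len(weights) % q_groups != 0:
        raise ValueError("len(weights) must be divisible by q_groups")
    qmax = 2 ** (num_bits - 1) - 1  # 127 for 8 bits
    size = len(weights) // q_groups
    q_values: List[int] = []
    scales: List[float] = []
    for g in range(q_groups):
        group = weights[g * size:(g + 1) * size]
        scale = max(abs(w) for w in group) / qmax or 1.0  # all-zero group: any scale works
        scales.append(scale)
        q_values.extend(max(-qmax, min(qmax, round(w / scale))) for w in group)
    return q_values, scales


def dequantize(q_values: List[int], scales: List[float]) -> List[float]:
    """Map integers back to floats using each value's group scale."""
    size = len(q_values) // len(scales)
    return [q * scales[i // size] for i, q in enumerate(q_values)]


if __name__ == "__main__":
    # Small-magnitude weights followed by large ones.
    w = [0.5, -1.0, 0.25, 0.1, 10.0, -5.0, 2.5, 1.0]
    for groups in (1, 2):
        q, s = quantize_symmetric(w, q_groups=groups)
        err = max(abs(a - b) for a, b in zip(w[:4], dequantize(q, s)[:4]))
        print(f"q_groups={groups}: max error on the first four weights = {err:.4f}")
```

With a single group, the scale is set by the value 10.0, so the small weights in the first half are quantized coarsely; with two groups they get their own, much finer scale.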

checkpoint: Union[str, Dict] = None

Path to a DeepSpeed-compatible checkpoint, or path to a JSON file with the load policy.

base_dir: str = ''

The root directory under which all the checkpoint files exist. This can also be passed through the JSON config.

set_empty_params: bool = False

Specifies whether the inference module is created with empty or real tensors.

save_mp_checkpoint_path: str = None

The path at which to save the loaded model as a checkpoint. This feature is used to adjust the parallelism degree, which helps reduce model loading overhead. No new checkpoint is saved if no path is passed.

checkpoint_config: InferenceCheckpointConfig = {} (alias 'ckpt_config')

TODO: Add docs. Expects a dictionary containing values for InferenceCheckpointConfig.

return_tuple: bool = True

Specifies whether the transformer layers return a tuple or a Tensor.

training_mp_size: int = 1

When loading a checkpoint, this is the model-parallel (mp) size it was trained with; it may differ from the mp size you want to use during inference.
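When the inference mp size differs from training_mp_size, the checkpoint's model-parallel shards have to be regrouped. The hypothetical helper below shows only the bookkeeping for the simple case in which one size divides the other: whole shards are merged when shrinking, and shards are sliced when growing. How DeepSpeed actually merges or splits the tensors is not shown, and the contiguous assignment is an assumption.

```python
from typing import List, Tuple

# (training shard index, slice index within that shard, number of slices)
Assignment = List[Tuple[int, int, int]]


def assign_shards(training_mp_size: int, inference_mp_size: int) -> List[Assignment]:
    """Hypothetical helper: map checkpoint shards onto inference ranks.

    Merging: each inference rank takes a contiguous run of whole training shards.
    Splitting: each training shard is cut into equal slices, one per inference rank.
    Assumes one size divides the other.
    """
    if training_mp_size >= inference_mp_size:
        if training_mp_size % inference_mp_size:
            raise ValueError("training_mp_size must be a multiple of the inference mp size")
        k = training_mp_size // inference_mp_size
        return [[(r * k + i, 0, 1) for i in range(k)] for r in range(inference_mp_size)]
    if inference_mp_size % training_mp_size:
        raise ValueError("inference mp size must be a multiple of training_mp_size")
    k = inference_mp_size // training_mp_size
    return [[(r // k, r % k, k)] for r in range(inference_mp_size)]


if __name__ == "__main__":
    # Trained with mp 8, served with mp 2: each rank merges four shards.
    print(assign_shards(8, 2))
```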

replace_method: str = 'auto'
injection_policy: Dict = None (alias 'injection_dict')

Dictionary mapping a client nn.Module to its corresponding injection policy, e.g., {BertLayer: deepspeed.inference.HFBertLayerPolicy}.

injection_policy_tuple: tuple = None

TODO: Add docs

config: Dict = None (alias 'args')
max_out_tokens: int = 1024 (alias 'max_tokens')

The maximum number of tokens the inference engine can work with, including both input and output tokens. Consider increasing it to the token length your use case requires.

min_out_tokens: int = 1 (alias 'min_tokens')

Tells the runtime the minimum number of tokens you expect to generate. If the runtime is unable to provide this many, it raises an error with context on the memory pressure rather than seg-faulting or producing corrupted output.
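The two token limits interact: max_out_tokens caps input plus output tokens, and min_out_tokens is the least you expect to generate. The hypothetical helper below does that arithmetic for a single prompt, under the simplifying assumption that max_out_tokens is the only constraint; the real runtime also has to fit the work into GPU memory.

```python
def new_token_budget(input_len: int, max_out_tokens: int = 1024,
                     min_out_tokens: int = 1) -> int:
    """Hypothetical helper: how many new tokens fit for a prompt of input_len tokens.

    Simplification: max_out_tokens is treated as the only limit.
    """
    available = max_out_tokens - input_len  # max_out_tokens counts input and output tokens
    if available < min_out_tokens:
        raise RuntimeError(
            f"prompt of {input_len} tokens leaves room for {available} new tokens, "
            f"fewer than min_out_tokens={min_out_tokens}; increase max_out_tokens")
    return available


if __name__ == "__main__":
    print(new_token_budget(1000))                       # 24 with the default of 1024
    print(new_token_budget(1000, max_out_tokens=2048))  # 1048
```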

transposed_mode: bool = False
mp_size: int = 1

Desired model parallel size; the default of 1 means no model parallelism. Deprecated: please use the tensor_parallel config to control model parallelism.
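Because mp_size is deprecated, an existing config that sets it can be moved to the tensor_parallel form. The helper below is a hypothetical migration sketch, not a DeepSpeed utility; it assumes that mp_size=N corresponds to tensor_parallel={"tp_size": N}, which matches the descriptions of the two fields.

```python
import copy
from typing import Any, Dict


def migrate_mp_size(config: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: move deprecated mp_size into tensor_parallel.tp_size.

    Returns a new dict and leaves the input unchanged. Only the full key
    "tensor_parallel" is handled, not its "tp" alias.
    """
    new = copy.deepcopy(config)
    if "mp_size" not in new:
        return new
    mp_size = new.pop("mp_size")
    tp = new.setdefault("tensor_parallel", {})
    if tp.get("tp_size", mp_size) != mp_size:
        raise ValueError(f"mp_size={mp_size} conflicts with tensor_parallel.tp_size={tp['tp_size']}")
    tp["tp_size"] = mp_size
    return new


if __name__ == "__main__":
    print(migrate_mp_size({"mp_size": 4, "dtype": "fp16"}))
    # {'dtype': 'fp16', 'tensor_parallel': {'tp_size': 4}}
```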

mpu: object = None
ep_size: int = 1
ep_group: object = None (alias 'expert_group')
ep_mp_group: object = None (alias 'expert_mp_group')
moe_experts: list = [1]
moe_type: MoETypeEnum = MoETypeEnum.standard

class deepspeed.inference.config.DeepSpeedTPConfig

Configure tensor parallelism settings

enabled: bool = True

Turn tensor parallelism on/off.

tp_size: int = 1

Number of devices to split the model across using tensor parallelism.

mpu: object = None

A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,world_size}().
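The mpu interface consists of six methods. The sketch below is a hypothetical implementation that only does the rank arithmetic, assuming a Megatron-style layout in which consecutive ranks share a model-parallel group. It returns None in place of real process groups, so it shows the shape of the interface rather than an object to pass to DeepSpeed as-is.

```python
class ArithmeticMPU:
    """Hypothetical mpu sketch for a Megatron-style layout.

    Assumes each block of mp_size consecutive ranks forms one model-parallel
    group, and ranks at the same position across blocks form a data-parallel
    group. Group objects are placeholders (None); a real mpu would return
    torch.distributed process groups.
    """

    def __init__(self, rank: int, world_size: int, mp_size: int) -> None:
        if world_size % mp_size != 0:
            raise ValueError("world_size must be divisible by mp_size")
        self.rank = rank
        self.world_size = world_size
        self.mp_size = mp_size

    def get_model_parallel_rank(self) -> int:
        return self.rank % self.mp_size

    def get_model_parallel_world_size(self) -> int:
        return self.mp_size

    def get_model_parallel_group(self):
        return None  # placeholder for a torch.distributed process group

    def get_data_parallel_rank(self) -> int:
        return self.rank // self.mp_size

    def get_data_parallel_world_size(self) -> int:
        return self.world_size // self.mp_size

    def get_data_parallel_group(self):
        return None  # placeholder for a torch.distributed process group


# The six method names required by the get_{model,data}_parallel_{rank,group,world_size}() interface.
REQUIRED_METHODS = [f"get_{kind}_parallel_{what}"
                    for kind in ("model", "data")
                    for what in ("rank", "group", "world_size")]
```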

tp_group: object = None

class deepspeed.inference.config.DeepSpeedMoEConfig

Sets parameters for MoE

enabled: bool = True
ep_size: int = 1

The expert-parallelism size which is used for partitioning the experts across the GPUs in the expert-parallel group.

moe_experts: list = [1] (alias 'num_experts')

The global number of experts used in an MoE layer.
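For one MoE layer, ep_size and the number of experts together determine how many experts each GPU in the expert-parallel group holds: num_experts / ep_size, assuming even division. The sketch below uses a hypothetical contiguous placement to make that arithmetic concrete; DeepSpeed's actual placement is not described here.

```python
from typing import List


def experts_per_rank(num_experts: int, ep_size: int) -> List[List[int]]:
    """Hypothetical contiguous placement: the expert ids held by each expert-parallel rank."""
    if num_experts % ep_size != 0:
        raise ValueError("num_experts must be divisible by ep_size in this sketch")
    local = num_experts // ep_size  # experts stored on each GPU
    return [list(range(r * local, (r + 1) * local)) for r in range(ep_size)]


def owner_rank(expert_id: int, num_experts: int, ep_size: int) -> int:
    """Expert-parallel rank that holds expert_id under the same contiguous placement."""
    return expert_id // (num_experts // ep_size)


if __name__ == "__main__":
    print(experts_per_rank(num_experts=8, ep_size=4))  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```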

type: MoETypeEnum = MoETypeEnum.standard

Specify the type of MoE layer. We have two types of MoE layer: ‘Standard’ and ‘Residual’.

ep_mp_group: object = None
ep_group: object = None (alias 'expert_group')

class deepspeed.inference.config.QuantizationConfig
enabled: bool = True
activation: ActivationQuantConfig = ActivationQuantConfig(q_type='symmetric', q_groups=1, enabled=True, num_bits=8)
weight: WeightQuantConfig = WeightQuantConfig(q_type='symmetric', q_groups=1, enabled=True, num_bits=8, quantized_initialization={}, post_init_quant={})
qkv: QKVQuantConfig = QKVQuantConfig(enabled=True)

class deepspeed.inference.config.InferenceCheckpointConfig
checkpoint_dir: str = None
save_mp_checkpoint_path: str = None
base_dir: str = None

Example config:

config = {
    "kernel_inject": True,
    "tensor_parallel": {"tp_size": 4},
    "dtype": "fp16",
    "enable_cuda_graph": False
}
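The example uses the alias kernel_inject rather than the field name replace_with_kernel_inject; the aliases listed for DeepSpeedInferenceConfig are interchangeable with the field names. The hypothetical helper below makes that mapping explicit by rewriting top-level aliases into their field names, which can help when comparing two configs. It is an illustration only; DeepSpeed resolves aliases itself when it parses the config.

```python
from typing import Any, Dict

# Alias -> field name, taken from the DeepSpeedInferenceConfig field list.
ALIASES = {
    "kernel_inject": "replace_with_kernel_inject",
    "tp": "tensor_parallel",
    "tm": "triangular_masking",
    "ckpt_config": "checkpoint_config",
    "injection_dict": "injection_policy",
    "args": "config",
    "max_tokens": "max_out_tokens",
    "min_tokens": "min_out_tokens",
    "expert_group": "ep_group",
    "expert_mp_group": "ep_mp_group",
}


def canonicalize(config: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: rewrite top-level alias keys to their field names."""
    out: Dict[str, Any] = {}
    for key, value in config.items():
        name = ALIASES.get(key, key)
        if name in out and out[name] != value:
            raise ValueError(f"{key!r} and {name!r} set different values")
        out[name] = value
    return out


if __name__ == "__main__":
    config = {
        "kernel_inject": True,
        "tensor_parallel": {"tp_size": 4},
        "dtype": "fp16",
        "enable_cuda_graph": False,
    }
    print(canonicalize(config))
```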

deepspeed.init_inference(model, config=None, **kwargs)

Initialize the DeepSpeed InferenceEngine.

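Description: all four of the following cases are valid and supported by the DeepSpeed init_inference() API. They are alternatives; apply one of them to a given model.

import os
import torch
import deepspeed
from transformers import pipeline

# Assumed setup, not part of the original cases: a Hugging Face text-generation
# pipeline ("gpt2" is an arbitrary choice) and the launcher's WORLD_SIZE/LOCAL_RANK.
local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
generator = pipeline("text-generation", model="gpt2", device=local_rank)

# Case 1: user provides no config and no kwargs. Default config will be used.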

generator.model = deepspeed.init_inference(generator.model)
string = generator("DeepSpeed is")
print(string)

# Case 2: user provides a config and no kwargs. User supplied config will be used.

generator.model = deepspeed.init_inference(generator.model, config=config)
string = generator("DeepSpeed is")
print(string)

# Case 3: user provides no config and uses keyword arguments (kwargs) only.

generator.model = deepspeed.init_inference(generator.model,
                                            tensor_parallel={"tp_size": world_size},
                                            dtype=torch.half,
                                            replace_with_kernel_inject=True)
string = generator("DeepSpeed is")
print(string)

# Case 4: user provides config and keyword arguments (kwargs). Both config and kwargs are merged and kwargs take precedence.

generator.model = deepspeed.init_inference(generator.model, config={"dtype": torch.half}, replace_with_kernel_inject=True)
string = generator("DeepSpeed is")
print(string)

Parameters
  • model – Required: the original nn.Module object without any wrappers

  • config – Optional: instead of keyword arguments, you can pass in a DeepSpeed inference config dict or a path to a JSON file

Returns

A deepspeed.InferenceEngine wrapped model.
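The four cases differ only in where the settings come from. The hypothetical function below sketches the documented handling: config may be omitted, given as a dict, or given as a path to a JSON file, and keyword arguments are merged on top so that they take precedence (Case 4). Treating the merge as shallow (top-level keys only) is an assumption of this sketch.

```python
import json
from typing import Any, Dict, Optional, Union


def merge_inference_config(config: Optional[Union[str, Dict[str, Any]]] = None,
                           **kwargs: Any) -> Dict[str, Any]:
    """Hypothetical sketch of the documented config/kwargs handling.

    config may be None, a dict, or a path to a JSON file. Keyword arguments
    are merged on top at the top level, so they take precedence.
    """
    if config is None:
        base: Dict[str, Any] = {}
    elif isinstance(config, str):
        with open(config) as f:
            base = json.load(f)
    elif isinstance(config, dict):
        base = dict(config)  # copy so the caller's dict is not mutated
    else:
        raise TypeError("config must be None, a dict, or a path to a JSON file")
    base.update(kwargs)
    return base


if __name__ == "__main__":
    # Mirrors Case 4: both sources are combined.
    print(merge_inference_config({"dtype": "fp16"}, replace_with_kernel_inject=True))
```

Copying the input dict keeps a config object reusable across several calls, for example when the same base config is combined with different keyword overrides.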