Mixture of Experts (MoE)

Layer specification

class deepspeed.moe.layer.MoE(hidden_size: int, expert: Module, num_experts: int = 1, ep_size: int = 1, k: int = 1, capacity_factor: float = 1.0, eval_capacity_factor: float = 1.0, min_capacity: int = 4, use_residual: bool = False, noisy_gate_policy: Optional[str] = None, drop_tokens: bool = True, use_rts: bool = True, use_tutel: bool = False, enable_expert_tensor_parallelism: bool = False, top2_2nd_expert_sampling: bool = True)[source]

Initialize an MoE layer.

Parameters
  • hidden_size (int) – the hidden dimension of the model; note that this is also the input and output dimension of the layer.

  • expert (nn.Module) – the torch module that defines the expert (e.g., an MLP or torch.nn.Linear).

  • num_experts (int, optional) – default=1, the total number of experts per layer.

  • ep_size (int, optional) – default=1, number of ranks in the expert parallel world or group.

  • k (int, optional) – default=1, top-k gating value, only supports k=1 or k=2.

  • capacity_factor (float, optional) – default=1.0, the capacity factor for each expert at training time (scales how many tokens an expert can process before tokens are dropped).

  • eval_capacity_factor (float, optional) – default=1.0, the capacity factor for each expert at eval time.

  • min_capacity (int, optional) – default=4, the minimum capacity per expert regardless of the capacity_factor.

  • use_residual (bool, optional) – default=False, make this MoE layer a Residual MoE (https://arxiv.org/abs/2201.05596) layer.

  • noisy_gate_policy (str, optional) – default=None, noisy gate policy, valid options are ‘Jitter’, ‘RSample’ or ‘None’.

  • drop_tokens (bool, optional) – default=True, whether to drop tokens that exceed expert capacity (setting this to False is equivalent to infinite capacity).

  • use_rts (bool, optional) – default=True, whether to use Random Token Selection.

  • use_tutel (bool, optional) – default=False, whether to use Tutel optimizations (if installed).

  • enable_expert_tensor_parallelism (bool, optional) – default=False, whether to use tensor parallelism for the experts.

  • top2_2nd_expert_sampling (bool, optional) – default=True, whether to perform sampling for the 2nd expert when k=2.
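
A minimal construction sketch (assumptions: the expert shown here is an illustrative two-layer MLP, the hyperparameter values are arbitrary, and the distributed backend must already be initialized, e.g. via deepspeed.init_distributed(), because the layer creates expert-parallel process groups):

    import torch
    import deepspeed
    from deepspeed.moe.layer import MoE

    deepspeed.init_distributed()  # expert-parallel groups require torch.distributed

    hidden_size = 1024

    # Illustrative expert: a two-layer MLP whose input and output dimensions
    # both equal hidden_size, as required by the MoE layer.
    expert = torch.nn.Sequential(
        torch.nn.Linear(hidden_size, 4 * hidden_size),
        torch.nn.GELU(),
        torch.nn.Linear(4 * hidden_size, hidden_size),
    )

    moe_layer = MoE(
        hidden_size=hidden_size,
        expert=expert,
        num_experts=8,        # total experts in this layer (illustrative)
        ep_size=1,            # size of the expert-parallel group
        k=1,                  # top-1 gating
        capacity_factor=1.0,
        min_capacity=4,
    )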

forward(hidden_states: Tensor, used_token: Optional[Tensor] = None) → Tuple[Tensor, Tensor, Tensor][source]

MoE forward pass.

Parameters
  • hidden_states (Tensor) – input to the layer

  • used_token (Tensor, optional) – default=None, mask only used tokens

Returns

A tuple of (output, l_aux, exp_counts):

  • output (Tensor): output of the MoE layer, with the same shape as hidden_states

  • l_aux (Tensor): gate (load-balancing) loss value

  • exp_counts (Tensor): number of tokens routed to each expert
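
A sketch of a forward pass and of using the returned values, continuing the construction sketch above (shapes and the loss-scaling coefficient are illustrative):

    batch_size, seq_len = 4, 128
    hidden_states = torch.randn(batch_size, seq_len, hidden_size)

    output, l_aux, exp_counts = moe_layer(hidden_states)

    # output has the same shape as hidden_states.
    # l_aux is the gate (load-balancing) loss; it is typically added to the
    # task loss with a small coefficient so the router learns to balance load.
    # exp_counts reports how many tokens were routed to each expert.
    task_loss = output.float().pow(2).mean()   # placeholder task loss
    loss = task_loss + 0.01 * l_aux            # 0.01 is an illustrative weight
    loss.backward()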