diff --git a/requirements-test.txt b/requirements-test.txt
index 08d9fa8ae5..2a05c9f568 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -20,3 +20,6 @@ git+git://github.com/rwightman/pytorch-image-models@v0.4.5#egg=timm
 
 # Dependency for factory
 hydra-core >= 1.1
+
+# Dependency for Mixture of Experts
+fairscale >= 0.4.5
diff --git a/tests/test_block_factory.py b/tests/test_block_factory.py
index 24fa0bd7b8..38955eace0 100644
--- a/tests/test_block_factory.py
+++ b/tests/test_block_factory.py
@@ -17,7 +17,7 @@
     xFormerEncoderConfig,
 )
 
-BATCH = 20
+BATCH = 4
 SEQ = 128
 MODEL = 96
 DROPOUT = 0.5
@@ -77,6 +77,8 @@ def test_xformer_encoder_block(
         "dropout": DROPOUT,
         "activation": activation,
         "hidden_layer_multiplier": 4,
+        "number_of_experts": 4,
+        "gate_config": "top_2",
     }
 
     position_encoding_config = {
@@ -167,6 +169,8 @@ def test_xformer_decoder_block(
         "dropout": DROPOUT,
         "activation": activation,
         "hidden_layer_multiplier": 4,
+        "number_of_experts": 4,
+        "gate_config": "top_2",
     }
 
     position_encoding_config = {
diff --git a/tests/test_feedforward.py b/tests/test_feedforward.py
index 103bfe511f..9758d158b1 100644
--- a/tests/test_feedforward.py
+++ b/tests/test_feedforward.py
@@ -8,8 +8,9 @@
 
 from xformers.components import Activation
 from xformers.components.feedforward import FEEDFORWARD_REGISTRY, build_feedforward
+from xformers.components.feedforward.mixture_of_experts import GateConfig
 
-BATCH = 20
+BATCH = 4
 SEQ = 512
 EMBD = 16
 LATENT = 128
@@ -34,6 +35,8 @@ def test_feedforward(
         "dropout": DROPOUT,
         "activation": activation,
         "hidden_layer_multiplier": 4,
+        "number_of_experts": 4,  # MoE
+        "gate_config": "top_2",  # MoE
     }
 
     # dummy, just check construction and dimensions in the FW pass
@@ -47,3 +50,35 @@
     ffw = ffw.to(device)
 
     _ = ffw(inputs)
+
+
+def get_expert():
+    return torch.nn.Linear(LATENT, LATENT, bias=False)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="This test requires CUDA")
+@pytest.mark.parametrize("gate", [g.value for g in GateConfig])
+@pytest.mark.parametrize("number_of_local_experts", [None, 4])
+@pytest.mark.parametrize("expert_constructor", [None, get_expert])
+def test_moe(gate, number_of_local_experts, expert_constructor):
+    test_config = {
+        "name": "MixtureOfExperts",
+        "dim_model": LATENT,
+        "dropout": DROPOUT,
+        "activation": Activation.ReLU,
+        "hidden_layer_multiplier": 4,
+        "number_of_experts": 4,
+        "number_of_local_experts": number_of_local_experts,
+        "gate_config": gate,
+        "expert_constructor": expert_constructor,
+    }
+
+    # dummy, just check construction and dimensions in the FW pass
+    ffw = build_feedforward(test_config)
+
+    inputs = torch.rand(BATCH, SEQ, LATENT, device=torch.device("cuda"))
+    ffw = ffw.to(torch.device("cuda"))
+
+    outputs = ffw(inputs)
+    loss = torch.sum(outputs)
+    loss.backward()
diff --git a/tests/test_model_factory.py b/tests/test_model_factory.py
index 247c8ceb81..8f34d84215 100644
--- a/tests/test_model_factory.py
+++ b/tests/test_model_factory.py
@@ -49,6 +49,8 @@
         "activation": "relu",
         "hidden_layer_multiplier": 4,
         "dim_model": EMB,
+        "number_of_experts": 4,
+        "gate_config": "top_2",
     },
 }
 
diff --git a/xformers/components/feedforward/fused_mlp.py b/xformers/components/feedforward/fused_mlp.py
index 0eb41e2dd3..dcc35028b1 100644
--- a/xformers/components/feedforward/fused_mlp.py
+++ b/xformers/components/feedforward/fused_mlp.py
@@ -29,8 +29,6 @@ class FusedMlpConfig(FeedforwardConfig):
 class FusedMLP(Feedforward):
     """
     A MLP using fused linear layers.
-
-    .. warning: This is not currently competitive with PyTorch in terms of training speed
     """
 
     def __init__(
diff --git a/xformers/components/feedforward/mixture_of_experts.py b/xformers/components/feedforward/mixture_of_experts.py
new file mode 100644
index 0000000000..07e2e96488
--- /dev/null
+++ b/xformers/components/feedforward/mixture_of_experts.py
@@ -0,0 +1,163 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import logging
+import tempfile
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Callable, Optional
+
+import torch
+
+from xformers.components import Activation
+from xformers.components.feedforward import (
+    Feedforward,
+    FeedforwardConfig,
+    register_feedforward,
+)
+
+_is_fairscale_available = True
+
+try:
+    import torch.distributed as dist
+    from fairscale.nn import MOELayer, Top2Gate
+
+    from xformers.components.feedforward import MLP
+
+except ImportError:
+    logging.warning(
+        "Either FairScale or torch distributed is not available, MixtureOfExperts will not be exposed."
+        " Please install them if you would like to use MoE"
+    )
+    _is_fairscale_available = False
+
+
+if _is_fairscale_available:
+
+    # Credits: initially implemented in FairScale for sanity checking
+    class RoundRobinGate(torch.nn.Module):
+        def __init__(self, model_dim, num_experts):
+            super().__init__()
+            self.model_dim = model_dim
+            self.num_experts = num_experts
+
+        def forward(self, input):
+            s = input.shape[0]
+            assert s % self.num_experts == 0, f"{s} % {self.num_experts} != 0"
+            capacity = 2 * s // self.num_experts
+            output = torch.zeros(
+                s, self.num_experts, capacity, dtype=input.dtype, device=input.device
+            )
+            for i in range(s):
+                output[i, i % self.num_experts, i // self.num_experts] = 1.0
+            return 0.0, output, output.bool()
+
+    class GateConfig(str, Enum):
+        RoundRobin = "round_robin"
+        Top2 = "top_2"
+        # Other gating techniques could be exposed here
+
+    @dataclass
+    class MoEConfig(FeedforwardConfig):
+        number_of_experts: int
+        gate_config: GateConfig
+        number_of_local_experts: Optional[int] = None
+        expert_constructor: Optional[Any] = None
+        hidden_layer_multiplier: Optional[int] = None
+        group: Optional[Any] = None
+
+    @register_feedforward("MixtureOfExperts", MoEConfig)
+    class MixtureOfExperts(Feedforward):
+        """
+        An MLP variant which uses the "Mixture of Experts" paradigm, as described in Gshard_.
+        xFormers uses the FairScale_ implementation under the hood.
+
+        .. warning:: Please note that most of the benefits of MoE are present in a distributed training environment
+
+        .. _Gshard: https://arxiv.org/pdf/2006.16668.pdf
+        .. _FairScale: https://github.com/facebookresearch/fairscale/
+        """
+
+        def __init__(
+            self,
+            dim_model: int,
+            dropout: float,
+            activation: Activation,
+            number_of_experts: int,
+            gate_config: GateConfig,
+            number_of_local_experts: Optional[int] = None,
+            expert_constructor: Optional[Callable[[], torch.nn.Module]] = None,
+            hidden_layer_multiplier: Optional[int] = None,
+            group: Optional[Any] = None,
+            *_,
+            **__,
+        ):
+            super().__init__()
+
+            # Handle a possibly uninitialized process group
+            if group is None and not dist.is_initialized():
+                logging.warning(
+                    "Torch Distributed is not initialized, please do so before instantiating MoE"
+                )
+                logging.warning("Attempting fallback initialization")
+
+                init_url = "file://" + tempfile.mkstemp()[1]
+                backend = (
+                    dist.Backend.NCCL
+                    if torch.cuda.is_available()
+                    else dist.Backend.GLOO
+                )
+                dist.init_process_group(
+                    backend=backend,
+                    rank=0,
+                    world_size=1,
+                    init_method=init_url,
+                )
+
+            if number_of_local_experts is not None:
+                assert number_of_experts >= number_of_local_experts
+            else:
+                if dist.get_world_size() == 1:
+                    logging.warning("Local experts not specified but world size of 1")
+                    logging.warning("Assuming that all experts are local")
+                    number_of_local_experts = number_of_experts
+                else:
+                    number_of_local_experts = 1
+
+            # Programmatically handle the gating technique
+            gate_constructor = {
+                GateConfig.RoundRobin: RoundRobinGate,
+                GateConfig.Top2: Top2Gate,
+            }[gate_config]
+
+            self.gate = gate_constructor(dim_model, number_of_experts)
+
+            # Programmatically handle the experts
+            if expert_constructor is None:
+
+                multiplier = (
+                    hidden_layer_multiplier
+                    if hidden_layer_multiplier is not None
+                    else 4
+                )
+
+                def expert_constructor() -> torch.nn.Module:
+                    return MLP(dim_model, dropout, activation, multiplier)
+
+                assert expert_constructor is not None
+
+            local_experts = torch.nn.ModuleList(
+                [expert_constructor() for _ in range(number_of_local_experts)]
+            )
+
+            self.moe = MOELayer(gate=self.gate, experts=local_experts, group=group)
+
+            self.requires_cuda = True
+
+        def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+            # FairScale MoE assumes that the dimensions are [S, B, E]
+            # xFormers assumes [B, S, E]
+            return self.moe(inputs.movedim(0, 1)).movedim(0, 1)
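
Usage note (reviewer sketch, not part of the patch): the snippet below shows how the new feedforward can be built through the existing registry, mirroring the test_moe test added above. It assumes fairscale is installed, a CUDA device is available, and a single process; with no process group initialized, MixtureOfExperts falls back to a one-rank, file-based group as implemented in its __init__.

import torch

from xformers.components import Activation
from xformers.components.feedforward import build_feedforward

BATCH, SEQ, LATENT = 4, 512, 128

config = {
    "name": "MixtureOfExperts",
    "dim_model": LATENT,
    "dropout": 0.1,
    "activation": Activation.ReLU,
    "hidden_layer_multiplier": 4,
    "number_of_experts": 4,  # total number of experts across all ranks
    "number_of_local_experts": 4,  # with a world size of 1, all experts are local
    "gate_config": "top_2",  # or "round_robin"
}

# Build via the registry and move to GPU (the layer sets requires_cuda = True)
moe = build_feedforward(config)
moe = moe.to(torch.device("cuda"))

# xFormers layout is [B, S, E]; the layer handles the transpose to FairScale's [S, B, E]
x = torch.rand(BATCH, SEQ, LATENT, device=torch.device("cuda"))
y = moe(x)
y.sum().backward()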