Commit da276f4 (1 parent: 04bb6c1)
Running microGPT example! Needs some proper testing
Showing 8 changed files with 212 additions and 4 deletions.
@@ -0,0 +1,163 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import logging
import tempfile
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Optional

import torch

from xformers.components import Activation
from xformers.components.feedforward import (
    Feedforward,
    FeedforwardConfig,
    register_feedforward,
)

_is_fairscale_available = True

try:
    import torch.distributed as dist
    from fairscale.nn import MOELayer, Top2Gate

    from xformers.components.feedforward import MLP

except ImportError:
    logging.warning(
        "Either FairScale or torch distributed is not available, MixtureOfExperts will not be exposed."
        " Please install them if you would like to use MoE"
    )
    _is_fairscale_available = False


if _is_fairscale_available:

    # Credits: initially implemented in FairScale for sanity checking
    class RoundRobinGate(torch.nn.Module):
        def __init__(self, model_dim, num_experts):
            super().__init__()
            self.model_dim = model_dim
            self.num_experts = num_experts

        def forward(self, input):
            s = input.shape[0]
            assert s % self.num_experts == 0, f"{s} % {self.num_experts} != 0"
            capacity = 2 * s // self.num_experts
            output = torch.zeros(
                s, self.num_experts, capacity, dtype=input.dtype, device=input.device
            )
            for i in range(s):
                output[i, i % self.num_experts, i // self.num_experts] = 1.0
            return 0.0, output, output.bool()
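        # Example: with num_experts=2 and an input of 4 token vectors, capacity is
        # 2 * 4 // 2 = 4, token i is routed to expert i % 2 at capacity slot i // 2,
        # and the gate returns the (aux_loss, combine_weights, dispatch_mask) triple
        # consumed by MOELayer, the two tensors having shape [4, 2, 4].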

    class GateConfig(str, Enum):
        RoundRobin = "round_robin"
        Top2 = "top_2"
        # Other gating techniques could be exposed here

    @dataclass
    class MoEConfig(FeedforwardConfig):
        number_of_experts: int
        gate_config: GateConfig
        number_of_local_experts: Optional[int] = None
        expert_constructor: Optional[Any] = None
        hidden_layer_multiplier: Optional[int] = None
        group: Optional[Any] = None

    @register_feedforward("MixtureOfExperts", MoEConfig)
    class MixtureOfExperts(Feedforward):
        """
        An MLP variant which uses the "Mixture of Experts" paradigm, as described in Gshard_.
        xFormers uses the FairScale_ implementation under the hood.

        .. warning: Please note that most of the benefits of MoE are present in a distributed training environment

        .. _Gshard: https://arxiv.org/pdf/2006.16668.pdf
        .. _FairScale: https://github.com/facebookresearch/fairscale/
        """

        def __init__(
            self,
            dim_model: int,
            dropout: float,
            activation: Activation,
            number_of_experts: int,
            gate_config: GateConfig,
            number_of_local_experts: Optional[int] = None,
            expert_constructor: Optional[Callable[[], torch.nn.Module]] = None,
            hidden_layer_multiplier: Optional[int] = None,
            group: Optional[Any] = None,
            *_,
            **__,
        ):
            super().__init__()

            # Handle a possibly uninitialized process group
            if group is None and not dist.is_initialized():
                logging.warning(
                    "Torch Distributed is not initialized, please do so before instantiating MoE"
                )
                logging.warning("Attempting fallback initialization")

                init_url = "file://" + tempfile.mkstemp()[1]
                backend = (
                    dist.Backend.NCCL
                    if torch.cuda.is_available()
                    else dist.Backend.GLOO
                )
                dist.init_process_group(
                    backend=backend,
                    rank=0,
                    world_size=1,
                    init_method=init_url,
                )

            if number_of_local_experts is not None:
                assert number_of_experts >= number_of_local_experts
            else:
                if dist.get_world_size() == 1:
                    logging.warning("Local experts not specified, but world size is 1")
                    logging.warning("Assuming that all experts are local")
                    number_of_local_experts = number_of_experts
                else:
                    number_of_local_experts = 1

            # Programmatically handle the gating technique
            gate_constructor = {
                GateConfig.RoundRobin: RoundRobinGate,
                GateConfig.Top2: Top2Gate,
            }[gate_config]

            self.gate = gate_constructor(dim_model, number_of_experts)

            # Programmatically handle the experts
            if expert_constructor is None:

                multiplier = (
                    hidden_layer_multiplier
                    if hidden_layer_multiplier is not None
                    else 4
                )

                def expert_constructor() -> torch.nn.Module:
                    return MLP(dim_model, dropout, activation, multiplier)

            assert expert_constructor is not None

            local_experts = torch.nn.ModuleList(
                [expert_constructor() for _ in range(number_of_local_experts)]
            )

            self.moe = MOELayer(gate=self.gate, experts=local_experts, group=group)

            self.requires_cuda = True

        def forward(self, inputs: torch.Tensor) -> torch.Tensor:
            # FairScale MoE assumes that the dimensions are [S, B, E]
            # xFormers assumes [B, S, E]
            return self.moe(inputs.movedim(0, 1)).movedim(0, 1)
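
A minimal usage sketch, assuming the names defined above (MixtureOfExperts, GateConfig, Activation) are in scope, FairScale is installed, and a CUDA device is available (the layer flags requires_cuda = True, and the fallback process-group initialization picks the NCCL backend when CUDA is present). Inputs follow the xFormers [batch, sequence, embedding] convention.

import torch

# Hypothetical smoke test for the new block; all sizes are illustrative.
moe = MixtureOfExperts(
    dim_model=16,
    dropout=0.1,
    activation=Activation.ReLU,          # assumed member of the xFormers Activation enum
    number_of_experts=4,
    gate_config=GateConfig.RoundRobin,   # GateConfig.Top2 would use FairScale's Top2Gate instead
).cuda()

x = torch.randn(2, 16, 16, device="cuda")  # [B, S, E]; B * S must be divisible by number_of_experts for RoundRobinGate
y = moe(x)
assert y.shape == x.shape                  # the [B, S, E] layout is preserved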