Merge branch 'master' into offload_experimental
blefaudeux committed Feb 5, 2021
2 parents 9ec3892 + 7fdd7ec commit d4e929d
Showing 24 changed files with 788 additions and 1,282 deletions.
17 changes: 16 additions & 1 deletion .circleci/config.yml
@@ -168,12 +168,18 @@ run_pipe_benchmark: &run_pipe_benchmark
command: |
python benchmarks/pipe.py
run_mp_pipe_benchmark: &run_mp_pipe_benchmark
- run:
name: Run Multiprocess Pipe Benchmark
command: |
python benchmarks/pipe.py --multiprocess --lazy-construction
run_oss_benchmark: &run_oss_benchmark
- run:
name: Run OSS Benchmark
command: |
python benchmarks/oss.py --world_size 4 --epochs 2
python benchmarks/oss.py --check_regression --world_size 4 --optim_type oss_sharded_ddp --reference_speed 660 --reference_memory 930 --reference_loss 0.023
python benchmarks/oss.py --check_regression --world_size 4 --optim_type oss_sharded_ddp
run_oss_gloo: &run_oss_gloo
- run:
@@ -188,6 +194,12 @@ run_oss_amp: &run_oss_amp
command: |
python benchmarks/oss.py --amp --epochs 3 --optim_type oss_sharded_ddp
run_oss_for_each: &run_oss_for_each
- run:
name: Run OSS with Torch AMP and ForEach optimizer
command: |
python benchmarks/oss.py --amp --epochs 3 --optim_type oss_sharded_ddp --multi_tensor_optim
run_doc_build: &run_doc_build
- run:
@@ -444,12 +456,15 @@ jobs:

- <<: *run_pipe_benchmark

- <<: *run_mp_pipe_benchmark

- <<: *run_oss_benchmark

- <<: *run_oss_gloo

- <<: *run_oss_amp

- <<: *run_oss_for_each



2 changes: 1 addition & 1 deletion .github/workflows/build_wheels.yml
@@ -30,7 +30,7 @@ jobs:
run: |
python -m cibuildwheel --output-dir dist
env:
CIBW_BUILD: "cp36-*64 cp37-*64 cp38-*64 cp39-*64"
CIBW_BUILD: "cp37-*64 cp38-*64 cp39-*64"
CIBW_MANYLINUX_X86_64_IMAGE: manylinux1
CIBW_BEFORE_BUILD: pip install .

1 change: 1 addition & 0 deletions .gitignore
@@ -25,3 +25,4 @@ venv/
ENV/
env.bak/
venv.bak/
.vscode/*
14 changes: 13 additions & 1 deletion CHANGELOG.md
@@ -5,9 +5,21 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [next rel] - TBD


## [0.1.5] - 2021-02-03
### Added
- Pytorch compatibility for OSS checkpoints (#310)
- Elastic checkpoints for OSS, world size can vary in between save and loads (#310)
- Tensor views for OSS bucketing, reduced CPU use (#300)
- Bucket calls in ShardedDDP, for faster inter node communications (#327)
- Tensor views for OSS bucketing, reduced CPU use
- FlattenParamWrapper, which flattens module parameters into a single tensor seamlessly (#317)
- AMPnet experimental support (#304)

### Fixed
- ShardedDDP properly handles device changes via `.to()` (#353)
- Add a new interface for AdaScale, AdaScaleWrapper, which makes it compatible with OSS (#347)


## [0.1.4] - 2021-01-07
### Fixed
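
As context for the OSS and ShardedDDP entries above (and for the updated benchmarks/oss.py later in this diff), here is a minimal sketch of the wrapping pattern those entries refer to. It assumes a torch.distributed process group has already been initialized in each worker and that a GPU is available; the toy nn.Linear model is illustrative only.

    import torch
    import torch.nn as nn

    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
    from fairscale.optim import OSS

    # Assumes dist.init_process_group(...) has already run in this process.
    model = nn.Linear(32, 32).to("cuda")

    # OSS shards the optimizer state across ranks; ShardedDDP then reduces each
    # gradient to the rank that owns the corresponding optimizer shard.
    optimizer = OSS(params=model.parameters(), optim=torch.optim.RMSprop, lr=1e-4, momentum=0.9)
    model = ShardedDDP(model, optimizer)

    # Per the fix above (#353), moving the wrapped model with .to() is now supported.
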
4 changes: 3 additions & 1 deletion benchmarks/datasets/wikitext2_data.py
@@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import io
import tempfile

import torch
from torch.utils.data import DataLoader
@@ -28,7 +29,8 @@ def get_real_dataloaders(args, benchmark_config):
"""Return real dataloaders for training, testing and validation."""

url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip"
test_filepath, valid_filepath, train_filepath = extract_archive(download_from_url(url, root="/tmp"))
tmpdir = tempfile.TemporaryDirectory()
test_filepath, valid_filepath, train_filepath = extract_archive(download_from_url(url, root=tmpdir.name))
tokenizer = get_tokenizer("basic_english")

def data_process(raw_text_iter):
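
As an aside on the change above: the benchmark keeps the TemporaryDirectory object alive so that the extracted files outlive the call. A context-manager variant of the same pattern is sketched below; it assumes download_from_url and extract_archive are the torchtext helpers this file already calls.

    import tempfile

    from torchtext.utils import download_from_url, extract_archive

    url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip"

    # The directory and the extracted files are deleted when the block exits,
    # so this form only works if everything is consumed inside the block.
    with tempfile.TemporaryDirectory() as tmpdir:
        test_filepath, valid_filepath, train_filepath = extract_archive(download_from_url(url, root=tmpdir))
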
21 changes: 14 additions & 7 deletions benchmarks/golden_configs/lm_wikitext2.py
@@ -20,17 +20,24 @@ def get_benchmark_config():
"scaler": GradScaler(),
"clip_value": 0.05,
"batch_size": 8,
"num_decoder_layers": 10,
"seq_len": 32,
}


def get_golden_real_stats():

return {
"avg_wps": 703.778,
"std_dev_wps": 5.732,
"peak_mem_usage": [2320996352, 1396742144, 1396742144, 2340010496],
}
def get_golden_real_stats(multiprocess=False):
if not multiprocess:
return {
"avg_wps": 703.778,
"std_dev_wps": 5.732,
"peak_mem_usage": [2320996352, 1396742144, 1396742144, 2340010496],
}
else:
return {
"avg_wps": 647.404,
"std_dev_wps": 14.51,
"peak_mem_usage": [3305007616, 2578692608, 3304524288, 2578692608],
}


def get_golden_synthetic_stats():
6 changes: 3 additions & 3 deletions benchmarks/golden_configs/oss_mnist.py
@@ -4,9 +4,9 @@
def get_golden_real_stats():

return {
"reference_speed": 1430,
"reference_memory": 1220,
"reference_loss": 0.006,
"reference_speed": 660,
"reference_memory": 1000,
"reference_loss": 0.026,
}


128 changes: 17 additions & 111 deletions benchmarks/oss.py
@@ -19,94 +19,31 @@
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.data import BatchSampler, DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Resize, ToTensor

from fairscale.nn.data_parallel import OffloadDataParallelExperimental as OffloadDDPExperimental
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler

OPTIM = torch.optim.RMSprop
TEMPDIR = tempfile.gettempdir()


class ReshapeModule(torch.nn.Module):
def forward(self, x):
x = x.view(x.size(0), -1)
return x


class ReshapeTokens(torch.nn.Module):
def __init__(self, vit):
super().__init__()
self.patch_embed = vit.patch_embed
self.cls_token = vit.cls_token

def forward(self, x):
B = x.shape[0]
x = self.patch_embed(x)

cls_token = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
return torch.cat((cls_token, x), dim=1)


class Norm(torch.nn.Module):
def __init__(self, vit):
super().__init__()
self.norm = vit.norm

def forward(self, x):
x = self.norm(x)
return x[:, 0]


class PosEmbed(torch.nn.Module):
def __init__(self, vit):
super().__init__()
self.pos_embed = vit.pos_embed
self.pos_drop = vit.pos_drop

def forward(self, x):
return self.pos_drop(x + self.pos_embed)


def dist_init(rank, world_size, backend):
logging.info(f"Using backend: {backend}")
dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)


def get_problem(rank, world_size, batch_size, device, model_name: str, unroll_model: bool = False):
def get_problem(rank, world_size, batch_size, device, model_name: str):
# Select the desired model on the fly
logging.info(f"Using {model_name} for benchmarking")

try:
model = getattr(importlib.import_module("torchvision.models"), model_name)(pretrained=False)
model = getattr(importlib.import_module("torchvision.models"), model_name)(pretrained=False).to(device)
except AttributeError:
model = getattr(importlib.import_module("timm.models"), model_name)(pretrained=False)

# Tentatively unroll the model
if unroll_model:
if "resnet" in model_name:
model = torch.nn.Sequential(
model.conv1,
model.bn1,
model.relu,
model.maxpool,
*model.layer1,
*model.layer2,
*model.layer3,
*model.layer4,
model.avgpool,
ReshapeModule(),
model.fc,
)
elif "vit" in model_name:
model = torch.nn.Sequential(ReshapeTokens(model), PosEmbed(model), *model.blocks, Norm(model), model.head)
else:
raise RuntimeError("This model cannot be unrolled")
model = getattr(importlib.import_module("timm.models"), model_name)(pretrained=False).to(device)

# Data setup, duplicate the grey channels to get pseudo color
def collate(inputs: List[Any]):
@@ -137,11 +74,10 @@ class OptimType(str, Enum):
vanilla = "pytorch"
oss_ddp = "oss_ddp"
oss_sharded_ddp = "oss_sharded_ddp"
oss_offload_ddp = "oss_offload_ddp"
everyone = "everyone"


def validate_benchmark(measurements, args, check_regression):
def validate_benchmark(measurements, final_loss, args, check_regression):
"""Validate the measurments against the golden benchmark config."""

golden_data = oss_mnist.get_golden_real_stats()
@@ -181,6 +117,10 @@ def train(
):
logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

use_multi_tensor = args.multi_tensor_optim and hasattr(torch.optim, "_multi_tensor")
OPTIM = torch.optim._multi_tensor.RMSprop if use_multi_tensor else torch.optim.RMSprop # type: ignore # attr is checked but mypy misses that
logging.info("Multi tensor optimizer: {}".format(use_multi_tensor))

# DDP
dist_init(rank=rank, world_size=args.world_size, backend=backend)

@@ -190,52 +130,30 @@
torch.cuda.manual_seed(0)
torch.manual_seed(0) # also sets the cuda seed
np.random.seed(0)
torch.cuda.device(rank)

if backend == "nccl":
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cpu") if args.cpu else torch.device(rank)
model, dataloader, loss_fn = get_problem(
rank,
args.world_size,
args.batch_size,
device,
args.model,
unroll_model=optim_type == OptimType.oss_offload_ddp,
)
model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.model)

# Shard the optimizer, test different methods
# Shard the optimizer
optimizer: Optional[torch.optim.Optimizer] = None
model = cast(nn.Module, model)
scaler = (TorchGradScaler() if args.optim_type == OptimType.vanilla else ShardedGradScaler()) if args.amp else None

if optim_type == OptimType.oss_sharded_ddp:
model = model.to(device)
optimizer = OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
model = ShardedDDP(model, optimizer)
elif optim_type == OptimType.oss_offload_ddp:
ddp_exp = OffloadDDPExperimental(
model_cpu=model,
optimizer=OPTIM,
optimizer_params={"lr": 1e-4, "momentum": 0.9},
world_size=args.world_size,
device=torch.device(torch.cuda.current_device()),
offload_device=torch.device("cpu"),
)
optimizer = ddp_exp.optimizer
model = ddp_exp
else:
model = model.to(device)
device_ids = None if args.cpu else [rank]
model = DDP(model, device_ids=device_ids, find_unused_parameters=False) # type: ignore
optimizer = (
OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
if optim_type == OptimType.oss_ddp
else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
)

optimizer = cast(torch.optim.Optimizer, optimizer)

# Reset the memory use counter
Expand All @@ -250,7 +168,6 @@ def train(

measurements = []
final_loss: Optional[float] = -1.0
optimizer = cast(Optimizer, optimizer)
need_profiling = args.profile

for epoch in range(args.epochs):
@@ -346,7 +263,7 @@ def run_closure(closure, scaler, optimizer):
img_per_sec = n_items / (training_stop - training_start) * args.epochs
logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint")

validate_benchmark(measurements, args, check_regression)
validate_benchmark(measurements, final_loss, args, check_regression)

dist.destroy_process_group() # type: ignore

@@ -359,9 +276,6 @@ def run_closure(closure, scaler, optimizer):
parser.add_argument("--epochs", action="store", default=10, type=int)
parser.add_argument("--batch_size", action="store", default=256, type=int)
parser.add_argument("--check_regression", action="store_true", default=False)
parser.add_argument("--reference_speed", action="store", default=1430, type=float)
parser.add_argument("--reference_memory", action="store", default=1220, type=float)
parser.add_argument("--reference_loss", action="store", default=0.006, type=float)
parser.add_argument(
"--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
)
@@ -371,7 +285,10 @@ def run_closure(closure, scaler, optimizer):
parser.add_argument("--model", type=str, help="Any torchvision or timm model name (str)", default="resnet101")
parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")
parser.add_argument("--amp", action="store_true", default=False, help="Activate torch AMP")
parser.add_argument("--fake_data", action="store_true", default=False, help="Use fake data")
parser.add_argument(
"--multi_tensor_optim", action="store_true", default=False, help="Use the faster multi-tensor optimizers"
)

args = parser.parse_args()

logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)
@@ -418,18 +335,7 @@ def run_closure(closure, scaler, optimizer):
logging.info("\n*** Benchmark OSS with ShardedDDP")
mp.spawn(
train, # type: ignore
args=(
args,
BACKEND,
OptimType.oss_sharded_ddp,
False,
), # FIXME: @lefaudeux - SDP should give the same results
args=(args, BACKEND, OptimType.oss_sharded_ddp, args.check_regression,),
nprocs=args.world_size,
join=True,
)

if args.optim_type == OptimType.oss_offload_ddp or args.optim_type == OptimType.everyone:
print("\nBenchmark OSS experimental")
mp.spawn(
train, args=(args, BACKEND, OptimType.oss_offload_ddp, False,), nprocs=args.world_size, join=True, # type: ignore
)
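
For reference, the new --multi_tensor_optim flag boils down to the optimizer selection shown in the train() hunk above; a standalone sketch of that fallback follows. torch.optim._multi_tensor is a private namespace, so the hasattr check guards torch builds that do not ship it.

    import torch

    # Prefer the multi-tensor ("foreach") RMSprop when the private namespace exists,
    # otherwise fall back to the regular implementation.
    use_multi_tensor = hasattr(torch.optim, "_multi_tensor")
    OPTIM = torch.optim._multi_tensor.RMSprop if use_multi_tensor else torch.optim.RMSprop

    params = [torch.nn.Parameter(torch.randn(4, 4))]
    optimizer = OPTIM(params, lr=1e-4, momentum=0.9)
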