diff --git a/benchmarks/offload.py b/benchmarks/offload.py index 2f2d6679b..5ebbc78b3 100755 --- a/benchmarks/offload.py +++ b/benchmarks/offload.py @@ -3,6 +3,7 @@ # Apply CPU offload and problem sharding to a big transformer model import argparse +import contextlib import logging import time @@ -18,6 +19,27 @@ from fairscale.nn.misc.offload import OffloadModel +def _get_fp16_context(use_fp16=False): + if use_fp16: + return torch.cuda.amp.autocast() + else: + return contextlib.nullcontext() + + +def _get_profiler_context(use_profiler=False): + if use_profiler: + return torch.autograd.profiler.profile(use_cuda=True, profile_memory=True) + else: + return contextlib.nullcontext() + + +def _get_profiler_record_context(record_name, use_profiler=False): + if use_profiler: + return torch.autograd.profiler.record_function(record_name) + else: + return contextlib.nullcontext() + + def train(args: argparse.Namespace): logging.basicConfig(level=logging.INFO) device = torch.device("cuda") @@ -57,22 +79,25 @@ def train_epoch(args): for batch_inputs, batch_outputs in dataloader: batch_inputs, batch_outputs = batch_inputs.to("cuda"), batch_outputs.to("cuda") start = time.time_ns() - with torch.autograd.profiler.profile(use_cuda=True, profile_memory=True) as prof: + with _get_profiler_context() as prof: optimizer.zero_grad() inputs = batch_inputs.reshape(-1, args.inputs * args.inputs) - with torch.autograd.profiler.record_function("model_training"): - output = model(inputs) - loss = criterion(output, target=batch_outputs) - loss.backward() + with _get_profiler_record_context("model_training"): + with _get_fp16_context(use_fp16=args.use_fp16): + output = model(inputs) + loss = criterion(output, target=batch_outputs) + loss.backward() optimizer.step() - prof.export_chrome_trace("/tmp/mpi_prof") - logging.info(f"Memory table {prof.key_averages().table()}") - logging.info("Memory stats are " + str(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] / 2 ** 30)) + logging.info( + "Memory stats are {:.2f}GB".format(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] / 2 ** 30) + ) logging.info( "Loss {:.2f} - throughput {:.2f}fps".format( loss.item(), args.batch_size / (time.time_ns() - start) * 10 ** 9 ) ) + if args.use_profiler: + prof.export_chrome_trace("/tmp/offload_prof") train_epoch(args) @@ -80,7 +105,7 @@ def train_epoch(args): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Test the CPU offload + sharding with a Transformer training") parser.add_argument("--epochs", action="store", default=1, type=int) - parser.add_argument("--batch_size", action="store", default=1, type=int) + parser.add_argument("--batch_size", action="store", default=20, type=int) parser.add_argument("--inputs", action="store", help="The dimension of the inputs", default=100, type=int) parser.add_argument("--hidden", action="store", help="The dimension of the hidden state", default=10000, type=int) parser.add_argument("--layers", action="store", help="he number of hidden layers", default=20, type=int) @@ -88,6 +113,8 @@ def train_epoch(args): parser.add_argument("--offload", action="store_true", default=False) parser.add_argument("--slices", action="store", default=3, type=int) + parser.add_argument("--use_fp16", action="store_true", default=False) + parser.add_argument("--use_profiler", action="store_true", default=False) args = parser.parse_args() diff --git a/fairscale/nn/misc/offload.py b/fairscale/nn/misc/offload.py index 22f61fb86..0328c1061 100644 --- a/fairscale/nn/misc/offload.py +++ b/fairscale/nn/misc/offload.py @@ -14,6 +14,7 @@ import torch from torch import nn +from torch.cuda.amp import custom_bwd, custom_fwd def _split(modules: nn.Sequential, number_splits: int) -> List[List[nn.Module]]: @@ -125,7 +126,8 @@ class ShardSyncLayer(torch.autograd.Function): """ @staticmethod - def forward(ctx: Any, inputs: Any, index: int, model_slices: Any, model_instance: Any) -> Any: # type: ignore + @custom_fwd + def forward(ctx: Any, inputs: Any, index: int, model_slices: Any, model_instance: Any) -> Any: drop_index = index load_index = index + 1 max_slices = len(model_slices) @@ -147,6 +149,7 @@ def forward(ctx: Any, inputs: Any, index: int, model_slices: Any, model_instance return inputs if isinstance(inputs, tuple) else (inputs,) @staticmethod + @custom_bwd def backward(ctx, *grad_outputs): # type: ignore load_index = ctx.index diff --git a/tests/nn/misc/test_offload.py b/tests/nn/misc/test_offload.py index 858cf7d4e..128202a86 100644 --- a/tests/nn/misc/test_offload.py +++ b/tests/nn/misc/test_offload.py @@ -7,6 +7,7 @@ Testing Offload Module """ +import contextlib import copy import numpy as np @@ -46,7 +47,7 @@ def _get_model(num_inputs=2, num_hidden=2, num_layers=1, num_outputs=2): torch.nn.Linear(num_inputs, num_hidden), *([torch.nn.Linear(num_hidden, num_hidden) for _ in range(num_layers)]), torch.nn.Linear(num_hidden, num_outputs), - ).cuda() + ) return model @@ -65,33 +66,56 @@ def _check_parity(rmodel, omodel, ropt, oopt, rloss, oloss): assert torch.allclose(o_buf, reg_buf, atol=1e-2), "Model buffers differ in between Offload and Vanilla." -def test_correctness(): - device, offload_device = _init() - model = _get_model().cuda() +def _get_fp16_context(use_fp16=False): + if use_fp16: + return torch.cuda.amp.autocast() + else: + return contextlib.nullcontext() + - def train(model, optimizer): +def _train(model, optimizer, use_fp16, device): - input = torch.ones(2, 2).to(device) - labels = torch.ones(2, 2).to(device) - model.train() + input = torch.ones(2, 2).to(device) + labels = torch.ones(2, 2).to(device) + loss_fn = torch.nn.MSELoss(reduction="sum") + model.train() + with _get_fp16_context(use_fp16): pred = model(input) - loss_fn = torch.nn.MSELoss(reduction="sum") loss = loss_fn(pred, labels) loss.backward() - optimizer.step() - return model, optimizer, loss - - def train_reg_model(): - reg_model = copy.deepcopy(model) - reg_optimizer = torch.optim.SGD(reg_model.parameters(), lr=0.001) - return train(reg_model, reg_optimizer) - - def train_offload_model(): - omodel = copy.deepcopy(model) - offload_model = OffloadModel(model_cpu=omodel, device=device, offload_device=offload_device, n_slices=2,) - offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001) - return train(offload_model, offload_optimizer) - - rmodel, ropt, rloss = train_reg_model() - omodel, oopt, oloss = train_offload_model() + optimizer.step() + return model, optimizer, loss + + +def _train_reg_model(model, device, offload_device, use_fp16): + reg_model = copy.deepcopy(model) + reg_model = reg_model.cuda() + reg_optimizer = torch.optim.SGD(reg_model.parameters(), lr=0.001) + return _train(reg_model, reg_optimizer, use_fp16, device) + + +def _train_offload_model(model, device, offload_device, use_fp16): + omodel = copy.deepcopy(model) + offload_model = OffloadModel(model_cpu=omodel, device=device, offload_device=offload_device, n_slices=2,) + offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001) + return _train(offload_model, offload_optimizer, use_fp16, device) + + +def test_correctness(): + device, offload_device = _init() + model = _get_model() + use_fp16 = False + rmodel, ropt, rloss = _train_reg_model(model, device, offload_device, use_fp16) + omodel, oopt, oloss = _train_offload_model(model, device, offload_device, use_fp16) + _check_parity(rmodel.cpu(), omodel.cpu(), ropt, oopt, rloss, oloss) + + +def test_correctness_fp16(): + if not hasattr(torch.cuda.amp, "autocast"): + return + device, offload_device = _init() + model = _get_model() + use_fp16 = True + rmodel, ropt, rloss = _train_reg_model(model, device, offload_device, use_fp16) + omodel, oopt, oloss = _train_offload_model(model, device, offload_device, use_fp16) _check_parity(rmodel.cpu(), omodel.cpu(), ropt, oopt, rloss, oloss)