[offload] Add support for fp16 training #374

Merged: 17 commits, Feb 12, 2021
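This change threads torch.cuda.amp.autocast through the offload benchmark and tests behind a new --use_fp16 flag. For context, a minimal sketch of that training pattern (the model, shapes, and data are invented for illustration; note that the PR keeps loss.backward() inside the autocast block, which is also allowed, and does not add a GradScaler):

import torch

def train_step(model, optimizer, criterion, inputs, targets, use_fp16=True):
    # Run the forward pass and loss under autocast when fp16 is requested;
    # backward and the optimizer step run outside, in full precision.
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=use_fp16):
        output = model(inputs)
        loss = criterion(output, targets)
    loss.backward()
    optimizer.step()
    return loss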
45 changes: 36 additions & 9 deletions benchmarks/offload.py
@@ -3,6 +3,7 @@
# Apply CPU offload and problem sharding to a big transformer model

import argparse
import contextlib
import logging
import time

@@ -18,6 +19,27 @@
from fairscale.nn.misc.offload import OffloadModel


def _get_fp16_context(use_fp16=False):
if use_fp16:
return torch.cuda.amp.autocast()
else:
return contextlib.nullcontext()


def _get_profiler_context(use_profiler=False):
if use_profiler:
return torch.autograd.profiler.profile(use_cuda=True, profile_memory=True)
else:
return contextlib.nullcontext()


def _get_profiler_record_context(record_name, use_profiler=False):
if use_profiler:
return torch.autograd.profiler.record_function(record_name)
else:
return contextlib.nullcontext()


def train(args: argparse.Namespace):
logging.basicConfig(level=logging.INFO)
device = torch.device("cuda")
@@ -57,37 +79,42 @@ def train_epoch(args):
for batch_inputs, batch_outputs in dataloader:
batch_inputs, batch_outputs = batch_inputs.to("cuda"), batch_outputs.to("cuda")
start = time.time_ns()
with torch.autograd.profiler.profile(use_cuda=True, profile_memory=True) as prof:
with _get_profiler_context(args.use_profiler) as prof:
optimizer.zero_grad()
inputs = batch_inputs.reshape(-1, args.inputs * args.inputs)
with torch.autograd.profiler.record_function("model_training"):
output = model(inputs)
loss = criterion(output, target=batch_outputs)
loss.backward()
with _get_profiler_record_context("model_training"):
with _get_fp16_context(use_fp16=args.use_fp16):
output = model(inputs)
loss = criterion(output, target=batch_outputs)
loss.backward()
optimizer.step()
prof.export_chrome_trace("/tmp/mpi_prof")
logging.info(f"Memory table {prof.key_averages().table()}")
logging.info("Memory stats are " + str(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] / 2 ** 30))
logging.info(
"Memory stats are {:.2f}GB".format(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] / 2 ** 30)
Review comment (Contributor): dummy thought earlier in the day: it would be great to add some ballpark computation of the expected size at some point (given the batch size + model/shards), just for comparison
(A rough sketch of such an estimate appears after this file's diff.)

)
logging.info(
"Loss {:.2f} - throughput {:.2f}fps".format(
loss.item(), args.batch_size / (time.time_ns() - start) * 10 ** 9
)
)
if args.use_profiler:
prof.export_chrome_trace("/tmp/offload_prof")

train_epoch(args)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Test the CPU offload + sharding with a Transformer training")
parser.add_argument("--epochs", action="store", default=1, type=int)
parser.add_argument("--batch_size", action="store", default=1, type=int)
parser.add_argument("--batch_size", action="store", default=20, type=int)
parser.add_argument("--inputs", action="store", help="The dimension of the inputs", default=100, type=int)
parser.add_argument("--hidden", action="store", help="The dimension of the hidden state", default=10000, type=int)
parser.add_argument("--layers", action="store", help="he number of hidden layers", default=20, type=int)
parser.add_argument("--outputs", action="store", help="The number of predicted classes", default=5, type=int)

parser.add_argument("--offload", action="store_true", default=False)
parser.add_argument("--slices", action="store", default=3, type=int)
parser.add_argument("--use_fp16", action="store_true", default=False)
parser.add_argument("--use_profiler", action="store_true", default=False)

args = parser.parse_args()

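Following up on the reviewer's note above about a ballpark memory figure, here is a rough, illustrative estimate for this benchmark's stack of Linear layers (the layer layout is assumed to mirror _get_model in the tests below; biases, gradients, optimizer state, and allocator overhead are ignored):

def estimate_model_gb(batch_size=20, inputs=100, hidden=10000, layers=20, outputs=5, bytes_per_el=4):
    # Parameters: (inputs^2 -> hidden), `layers` x (hidden -> hidden), (hidden -> outputs).
    n_params = inputs * inputs * hidden + layers * hidden * hidden + hidden * outputs
    # Activations saved for backward: roughly one output tensor per layer.
    n_acts = batch_size * (hidden * (layers + 1) + outputs)
    return (n_params + n_acts) * bytes_per_el / 2 ** 30

With the defaults above the total is dominated by the roughly 2.1e9 parameters, i.e. on the order of 8 GB in fp32, which is what makes the CPU offload (and the smaller fp16 activations) worth comparing against the logged peak.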
5 changes: 4 additions & 1 deletion fairscale/nn/misc/offload.py
@@ -14,6 +14,7 @@

import torch
from torch import nn
from torch.cuda.amp import custom_bwd, custom_fwd


def _split(modules: nn.Sequential, number_splits: int) -> List[List[nn.Module]]:
@@ -125,7 +126,8 @@ class ShardSyncLayer(torch.autograd.Function):
"""

@staticmethod
def forward(ctx: Any, inputs: Any, index: int, model_slices: Any, model_instance: Any) -> Any: # type: ignore
@custom_fwd
def forward(ctx: Any, inputs: Any, index: int, model_slices: Any, model_instance: Any) -> Any:
drop_index = index
load_index = index + 1
max_slices = len(model_slices)
@@ -147,6 +149,7 @@ def forward(ctx: Any, inputs: Any, index: int, model_slices: Any, model_instance
return inputs if isinstance(inputs, tuple) else (inputs,)

@staticmethod
@custom_bwd
Review comment (@blefaudeux, Contributor, Feb 10, 2021): could you elaborate on that? it's new to me
edit: I looked it up, sorry for the noise, makes sense
(A minimal sketch of these decorators appears after this file's diff.)
def backward(ctx, *grad_outputs): # type: ignore

load_index = ctx.index
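On the custom_fwd / custom_bwd question above: these torch.cuda.amp decorators make a custom autograd.Function cooperate with autocast; custom_bwd in particular reruns backward with the same autocast state that forward saw. This PR applies them bare on ShardSyncLayer; a minimal usage sketch with an invented Function (assuming a CUDA device) looks like:

import torch
from torch.cuda.amp import custom_bwd, custom_fwd

class ScaledMatmul(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)  # optionally casts incoming tensors to fp16
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        return x @ weight

    @staticmethod
    @custom_bwd  # backward executes with the same autocast state as forward
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        return grad_output @ weight.t(), x.t() @ grad_output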
74 changes: 49 additions & 25 deletions tests/nn/misc/test_offload.py
@@ -7,6 +7,7 @@
Testing Offload Module
"""

import contextlib
import copy

import numpy as np
@@ -46,7 +47,7 @@ def _get_model(num_inputs=2, num_hidden=2, num_layers=1, num_outputs=2):
torch.nn.Linear(num_inputs, num_hidden),
*([torch.nn.Linear(num_hidden, num_hidden) for _ in range(num_layers)]),
torch.nn.Linear(num_hidden, num_outputs),
).cuda()
)
return model


@@ -65,33 +66,56 @@ def _check_parity(rmodel, omodel, ropt, oopt, rloss, oloss):
assert torch.allclose(o_buf, reg_buf, atol=1e-2), "Model buffers differ in between Offload and Vanilla."


def test_correctness():
device, offload_device = _init()
model = _get_model().cuda()
def _get_fp16_context(use_fp16=False):
if use_fp16:
return torch.cuda.amp.autocast()
else:
return contextlib.nullcontext()


def train(model, optimizer):
def _train(model, optimizer, use_fp16, device):

input = torch.ones(2, 2).to(device)
labels = torch.ones(2, 2).to(device)
model.train()
input = torch.ones(2, 2).to(device)
labels = torch.ones(2, 2).to(device)
loss_fn = torch.nn.MSELoss(reduction="sum")
model.train()
with _get_fp16_context(use_fp16):
pred = model(input)
loss_fn = torch.nn.MSELoss(reduction="sum")
loss = loss_fn(pred, labels)
loss.backward()
optimizer.step()
return model, optimizer, loss

def train_reg_model():
reg_model = copy.deepcopy(model)
reg_optimizer = torch.optim.SGD(reg_model.parameters(), lr=0.001)
return train(reg_model, reg_optimizer)

def train_offload_model():
omodel = copy.deepcopy(model)
offload_model = OffloadModel(model_cpu=omodel, device=device, offload_device=offload_device, n_slices=2,)
offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)
return train(offload_model, offload_optimizer)

rmodel, ropt, rloss = train_reg_model()
omodel, oopt, oloss = train_offload_model()
optimizer.step()
return model, optimizer, loss


def _train_reg_model(model, device, offload_device, use_fp16):
reg_model = copy.deepcopy(model)
reg_model = reg_model.cuda()
reg_optimizer = torch.optim.SGD(reg_model.parameters(), lr=0.001)
return _train(reg_model, reg_optimizer, use_fp16, device)


def _train_offload_model(model, device, offload_device, use_fp16):
omodel = copy.deepcopy(model)
offload_model = OffloadModel(model_cpu=omodel, device=device, offload_device=offload_device, n_slices=2,)
offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)
return _train(offload_model, offload_optimizer, use_fp16, device)


def test_correctness():
device, offload_device = _init()
model = _get_model()
use_fp16 = False
rmodel, ropt, rloss = _train_reg_model(model, device, offload_device, use_fp16)
omodel, oopt, oloss = _train_offload_model(model, device, offload_device, use_fp16)
_check_parity(rmodel.cpu(), omodel.cpu(), ropt, oopt, rloss, oloss)


def test_correctness_fp16():
if not hasattr(torch.cuda.amp, "autocast"):
return
device, offload_device = _init()
model = _get_model()
use_fp16 = True
rmodel, ropt, rloss = _train_reg_model(model, device, offload_device, use_fp16)
omodel, oopt, oloss = _train_offload_model(model, device, offload_device, use_fp16)
_check_parity(rmodel.cpu(), omodel.cpu(), ropt, oopt, rloss, oloss)
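As a usage sketch (flag names come from the argparse block above and the paths from the diff headers), the new code path could be exercised with python benchmarks/offload.py --offload --use_fp16 --use_profiler, and the fp16 parity test with python -m pytest tests/nn/misc/test_offload.py -k fp16 on a CUDA machine.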