From 89451bfd4c6b0340bbb1997f59eb71f0201914a0 Mon Sep 17 00:00:00 2001 From: Anjali Sridhar Date: Fri, 5 Feb 2021 14:01:55 -0800 Subject: [PATCH 01/16] initial fwd/bwd commit --- benchmarks/offload.py | 2 +- fairscale/nn/misc/offload.py | 96 ++++++++++++++++++++++++++++++------ 2 files changed, 82 insertions(+), 16 deletions(-) diff --git a/benchmarks/offload.py b/benchmarks/offload.py index 1a02a6018..d64ea9c7e 100755 --- a/benchmarks/offload.py +++ b/benchmarks/offload.py @@ -53,7 +53,7 @@ def train(args: argparse.Namespace): def train_epoch(): model.train() for batch_inputs, batch_outputs in dataloader: - batch_inputs, batch_outputs = batch_inputs.to("cuda"), batch_outputs.to("cuda") + # batch_inputs, batch_outputs = batch_inputs.to("cuda"), batch_outputs.to("cuda") start = time.time_ns() optimizer.zero_grad() diff --git a/fairscale/nn/misc/offload.py b/fairscale/nn/misc/offload.py index 8d8fe8f78..51fd1999c 100644 --- a/fairscale/nn/misc/offload.py +++ b/fairscale/nn/misc/offload.py @@ -16,6 +16,70 @@ from torch import nn +# Custom Autograd function to carry out the backward pass +class OffloadBackwardFunction(torch.autograd.Function): + """ + We can implement our own custom autograd Functions by subclassing + torch.autograd.Function and implementing the forward and backward passes + which operate on Tensors. + """ + + @staticmethod + def forward(ctx, inputs, model_instance): + """ + In the forward pass we receive a Tensor containing the input and return + a Tensor containing the output. ctx is a context object that can be used + to stash information for backward computation. You can cache arbitrary + objects for use in the backward pass using the ctx.save_for_backward method. + """ + ctx.save_for_backward(inputs, model_instance) + # List of input activations starting with the given input + model_instance._activations = [inputs] + # Enumerate through layer shards and apply activations from the previous shard + for index, layer_shard in enumerate(model_instance.model_slices): + # Bring in the current activations onto the device + current_input = model_instance._activations[index].to("cuda") + # Bring in the current layer shard onto the device + layer_shard.to("cuda") + # Apply the FP and store the activations on the CPU. + model_instance._activations.append(layer_shard(current_input).to("cpu")) + # Move the layer shard back to the CPU + layer_shard.to("cpu") + return model_instance._activations[-1] + + @staticmethod + def backward(ctx, grad_output): + """ + In the backward pass we receive a Tensor containing the gradient of the loss + with respect to the output, and we need to compute the gradient of the loss + with respect to the input. + """ + input, model_instance = ctx.save_for_backward + final_grads = grad_output.clone() + all_grads = [final_grads] + # reverse the model shards and iterate through them + # calculate the gradients as you go along + logging.info("model_instance._activations ", model_instance._activations) + for model_shard, activation in zip(reverse(model_instance.model_slices), reverse(model_instance._activations[:-1])): + # move the activation to the device + activation.to("cuda") + # move the model shard to the device + model_shard.to("cuda") + # calculate the output of the last shard wrt to the stored activation at the slice boundary. 
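+            # Re-running the shard's forward on the saved boundary activation rebuilds a
+            # local autograd graph, so that the output.backward(...) call below can
+            # produce this shard's gradients (a recompute-in-backward scheme similar to
+            # gradient checkpointing).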
+ output = model_shard(activation) + # Get the last gradient calculation + final_grads = all_grads[-1] + # calculate the gradient wrt to the output on the CPU + output.backward(final_grads.to("cpu")) + # Move activation back to the GPU + activation.to("cpu") + # Append the list of grads to the all_grads list and this should be on the CPU + all_grads.append(activation.grad.to("cpu")) + # move the shard back to the cpu + model_shard.to("cpu") + + + def _split(modules: nn.Sequential, number_splits: int) -> List[List[nn.Module]]: number_splits = min(len(modules), number_splits) splits: List[List[nn.Module]] = [[] for _ in range(number_splits)] @@ -206,22 +270,24 @@ def __init__( # Expose a unified view of the slices self.model = torch.nn.Sequential(*self.model_slices) + def forward(self, *inputs: Any, **_: Any) -> Any: # Slice per slice FW, sync in between syncRanks = ShardSyncLayer.apply - # TODO: Rewrite this and make it more flexible, this is ugly - for (p2, p1, n1, n2) in zip( - [None, None, *self.model_slices], - [None, *self.model_slices], - [*self.model_slices, None], - [*self.model_slices, None, None], - ): - - # Per shard FW - inputs = p1(*inputs) if p1 else inputs - - # Call the custom autograd hooks (discard/load slices FW and BW) - inputs = syncRanks((p2, p1, n1, n2), inputs) - - return inputs[0] if len(inputs) == 1 else inputs + # # TODO: Rewrite this and make it more flexible, this is ugly + # for (p2, p1, n1, n2) in zip( + # [None, None, *self.model_slices], + # [None, *self.model_slices], + # [*self.model_slices, None], + # [*self.model_slices, None, None], + # ): + + # # Per shard FW + # inputs = p1(*inputs) if p1 else inputs + + # # Call the custom autograd hooks (discard/load slices FW and BW) + # inputs = syncRanks((p2, p1, n1, n2), inputs) + # return inputs[0] if len(inputs) == 1 else inputs + + return OffloadBackwardFunction.apply(*inputs, self) From c569679114fcd0286e9c6ef210242f7402788aac Mon Sep 17 00:00:00 2001 From: Anjali Sridhar Date: Sat, 6 Feb 2021 00:28:08 -0800 Subject: [PATCH 02/16] checkpoint work --- benchmarks/offload.py | 29 +++++++------ fairscale/nn/misc/offload.py | 79 +++++++++++++++++++----------------- 2 files changed, 56 insertions(+), 52 deletions(-) diff --git a/benchmarks/offload.py b/benchmarks/offload.py index d64ea9c7e..61fd88d24 100755 --- a/benchmarks/offload.py +++ b/benchmarks/offload.py @@ -14,7 +14,6 @@ OPTIM = torch.optim.SGD LR = 1e-3 -BATCH = 32 from fairscale.nn.misc.offload import OffloadWrapperExperimental @@ -25,9 +24,9 @@ def train(args: argparse.Namespace): # Setup the problem model = torch.nn.Sequential( - torch.nn.Linear(args.inputs * args.inputs, args.hidden), + torch.nn.Linear(args.inputs * args.inputs, args.hidden, bias=False), *([torch.nn.Linear(args.hidden, args.hidden) for _ in range(args.layers)]), - torch.nn.Linear(args.hidden, args.outputs) + torch.nn.Linear(args.hidden, args.outputs, bias=False) ).cpu() # Optim loop @@ -47,10 +46,10 @@ def train(args: argparse.Namespace): transform = ToTensor() dataloader = DataLoader( FakeData(image_size=(1, args.inputs, args.inputs), num_classes=args.outputs, transform=transform), - batch_size=BATCH, + batch_size=args.batch_size, ) - def train_epoch(): + def train_epoch(args): model.train() for batch_inputs, batch_outputs in dataloader: # batch_inputs, batch_outputs = batch_inputs.to("cuda"), batch_outputs.to("cuda") @@ -60,25 +59,25 @@ def train_epoch(): inputs = batch_inputs.reshape(-1, args.inputs * args.inputs) output = model(inputs) loss = criterion(output, 
target=batch_outputs) + print(f"loss {loss.item()}") loss.backward() optimizer.step() - print("Loss {:.2f} - throughput {:.2f}fps".format(loss.item(), BATCH / (time.time_ns() - start) * 10 ** 9)) - - train_epoch() + print("Loss {:.2f} - throughput {:.2f}fps".format(loss.item(), args.batch_size / (time.time_ns() - start) * 10 ** 9)) + train_epoch(args) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Test the CPU offload + sharding with a Transformer training") - parser.add_argument("--epochs", action="store", default=10, type=int) - parser.add_argument("--batch_size", action="store", default=20, type=int) - parser.add_argument("--inputs", action="store", help="The dimension of the inputs", default=100, type=int) - parser.add_argument("--hidden", action="store", help="The dimension of the hidden state", default=10000, type=int) - parser.add_argument("--layers", action="store", help="he number of hidden layers", default=20, type=int) - parser.add_argument("--outputs", action="store", help="The number of predicted classes", default=5, type=int) + parser.add_argument("--epochs", action="store", default=1, type=int) + parser.add_argument("--batch_size", action="store", default=1, type=int) + parser.add_argument("--inputs", action="store", help="The dimension of the inputs", default=2, type=int) + parser.add_argument("--hidden", action="store", help="The dimension of the hidden state", default=2, type=int) + parser.add_argument("--layers", action="store", help="he number of hidden layers", default=1, type=int) + parser.add_argument("--outputs", action="store", help="The number of predicted classes", default=2, type=int) parser.add_argument("--offload", action="store_true", default=False) - parser.add_argument("--slices", action="store", default=3, type=int) + parser.add_argument("--slices", action="store", default=2, type=int) args = parser.parse_args() diff --git a/fairscale/nn/misc/offload.py b/fairscale/nn/misc/offload.py index 51fd1999c..8975016bc 100644 --- a/fairscale/nn/misc/offload.py +++ b/fairscale/nn/misc/offload.py @@ -17,7 +17,7 @@ # Custom Autograd function to carry out the backward pass -class OffloadBackwardFunction(torch.autograd.Function): +class OffloadFunction(torch.autograd.Function): """ We can implement our own custom autograd Functions by subclassing torch.autograd.Function and implementing the forward and backward passes @@ -38,13 +38,18 @@ def forward(ctx, inputs, model_instance): # Enumerate through layer shards and apply activations from the previous shard for index, layer_shard in enumerate(model_instance.model_slices): # Bring in the current activations onto the device - current_input = model_instance._activations[index].to("cuda") + model_instance._activations[index] = model_instance._activations[index].to( + "cuda") # Bring in the current layer shard onto the device - layer_shard.to("cuda") + layer_shard = layer_shard.to("cuda") + # Run the FP on the layer shard. + inter_activations = layer_shard(model_instance._activations[index]) + # Move the activations to the CPU + inter_activations = inter_activations.to("cpu") # Apply the FP and store the activations on the CPU. 
- model_instance._activations.append(layer_shard(current_input).to("cpu")) + model_instance._activations.append(inter_activations) # Move the layer shard back to the CPU - layer_shard.to("cpu") + layer_shard = layer_shard.to("cpu") return model_instance._activations[-1] @staticmethod @@ -54,31 +59,30 @@ def backward(ctx, grad_output): with respect to the output, and we need to compute the gradient of the loss with respect to the input. """ - input, model_instance = ctx.save_for_backward - final_grads = grad_output.clone() + _, model_instance = ctx.save_for_backward + final_grads = grad_output.to("cuda") all_grads = [final_grads] - # reverse the model shards and iterate through them - # calculate the gradients as you go along - logging.info("model_instance._activations ", model_instance._activations) + # reverse the model shards and iterate through them. for model_shard, activation in zip(reverse(model_instance.model_slices), reverse(model_instance._activations[:-1])): # move the activation to the device - activation.to("cuda") + activation = activation.to("cuda") # move the model shard to the device - model_shard.to("cuda") + model_shard = model_shard.to("cuda") # calculate the output of the last shard wrt to the stored activation at the slice boundary. output = model_shard(activation) # Get the last gradient calculation final_grads = all_grads[-1] # calculate the gradient wrt to the output on the CPU - output.backward(final_grads.to("cpu")) + output.backward(final_grads) + # Move activation back to the GPU - activation.to("cpu") - # Append the list of grads to the all_grads list and this should be on the CPU - all_grads.append(activation.grad.to("cpu")) + activation = activation.to("cpu") # move the shard back to the cpu - model_shard.to("cpu") - + model_shard = model_shard.to("cpu") + # Append the list of grads to the all_grads list and this should be on the CPU + all_grads.append(activation) + return all_grads[-1] def _split(modules: nn.Sequential, number_splits: int) -> List[List[nn.Module]]: number_splits = min(len(modules), number_splits) @@ -110,7 +114,8 @@ def _split(modules: nn.Sequential, number_splits: int) -> List[List[nn.Module]]: for i, split in enumerate(splits): current_shard_params = sum(p.numel() for sm in split for p in sm.parameters()) - logging.info(f"Shard {i} holds {current_shard_params/1e6:.2f}M parameters") + # logging.info(f"Shard {i} holds {current_shard_params/1e6:.2f}M parameters") + logging.info(f"Shard {i} holds {current_shard_params:.2f} parameters") return splits @@ -270,24 +275,24 @@ def __init__( # Expose a unified view of the slices self.model = torch.nn.Sequential(*self.model_slices) - - def forward(self, *inputs: Any, **_: Any) -> Any: + def _alternate_forward(self, *inputs: Any, **_: Any) -> Any: # Slice per slice FW, sync in between syncRanks = ShardSyncLayer.apply - # # TODO: Rewrite this and make it more flexible, this is ugly - # for (p2, p1, n1, n2) in zip( - # [None, None, *self.model_slices], - # [None, *self.model_slices], - # [*self.model_slices, None], - # [*self.model_slices, None, None], - # ): - - # # Per shard FW - # inputs = p1(*inputs) if p1 else inputs - - # # Call the custom autograd hooks (discard/load slices FW and BW) - # inputs = syncRanks((p2, p1, n1, n2), inputs) - # return inputs[0] if len(inputs) == 1 else inputs - - return OffloadBackwardFunction.apply(*inputs, self) + # TODO: Rewrite this and make it more flexible, this is ugly + for (p2, p1, n1, n2) in zip( + [None, None, *self.model_slices], + [None, *self.model_slices], 
+ [*self.model_slices, None], + [*self.model_slices, None, None], + ): + # Per shard FW + inputs = p1(*inputs) if p1 else inputs + + # Call the custom autograd hooks (discard/load slices FW and BW) + inputs = syncRanks((p2, p1, n1, n2), inputs) + return inputs[0] if len(inputs) == 1 else inputs + + def forward(self, *inputs: Any, **_: Any) -> Any: + # Defer to the overriden autograd function. + return OffloadFunction.apply(*inputs, self) From 4e357cdcf55f9b148e018599c005ec34780585a6 Mon Sep 17 00:00:00 2001 From: Anjali Sridhar Date: Sat, 6 Feb 2021 18:01:03 -0800 Subject: [PATCH 03/16] modify shard loop --- benchmarks/offload.py | 12 ++- fairscale/nn/misc/offload.py | 156 ++++++++++++++++++----------------- 2 files changed, 88 insertions(+), 80 deletions(-) diff --git a/benchmarks/offload.py b/benchmarks/offload.py index 61fd88d24..bcd76315b 100755 --- a/benchmarks/offload.py +++ b/benchmarks/offload.py @@ -11,6 +11,7 @@ from torch.utils.data.dataloader import DataLoader from torchvision.datasets import FakeData from torchvision.transforms import ToTensor +from torchviz import make_dot OPTIM = torch.optim.SGD LR = 1e-3 @@ -21,11 +22,13 @@ def train(args: argparse.Namespace): logging.basicConfig(level=logging.INFO) device = torch.device("cuda") + torch.cuda.set_device(0) + torch.manual_seed(5) # Setup the problem model = torch.nn.Sequential( torch.nn.Linear(args.inputs * args.inputs, args.hidden, bias=False), - *([torch.nn.Linear(args.hidden, args.hidden) for _ in range(args.layers)]), + # *([torch.nn.Linear(args.hidden, args.hidden) for _ in range(args.layers)]), torch.nn.Linear(args.hidden, args.outputs, bias=False) ).cpu() @@ -52,17 +55,18 @@ def train(args: argparse.Namespace): def train_epoch(args): model.train() for batch_inputs, batch_outputs in dataloader: - # batch_inputs, batch_outputs = batch_inputs.to("cuda"), batch_outputs.to("cuda") + batch_inputs, batch_outputs = batch_inputs.to("cuda"), batch_outputs.to("cuda") start = time.time_ns() optimizer.zero_grad() inputs = batch_inputs.reshape(-1, args.inputs * args.inputs) output = model(inputs) + # make_dot(output, dict(model.named_parameters())).render("attached_mine", format="png") loss = criterion(output, target=batch_outputs) - print(f"loss {loss.item()}") + # print(f"loss {loss.item()}") loss.backward() optimizer.step() - + # break print("Loss {:.2f} - throughput {:.2f}fps".format(loss.item(), args.batch_size / (time.time_ns() - start) * 10 ** 9)) train_epoch(args) diff --git a/fairscale/nn/misc/offload.py b/fairscale/nn/misc/offload.py index 8975016bc..aa6f4e5bf 100644 --- a/fairscale/nn/misc/offload.py +++ b/fairscale/nn/misc/offload.py @@ -14,10 +14,11 @@ import torch from torch import nn +from torchviz import make_dot # Custom Autograd function to carry out the backward pass -class OffloadFunction(torch.autograd.Function): +class OffloadBackwardFunction(torch.autograd.Function): """ We can implement our own custom autograd Functions by subclassing torch.autograd.Function and implementing the forward and backward passes @@ -33,25 +34,9 @@ def forward(ctx, inputs, model_instance): objects for use in the backward pass using the ctx.save_for_backward method. 
""" ctx.save_for_backward(inputs, model_instance) - # List of input activations starting with the given input - model_instance._activations = [inputs] - # Enumerate through layer shards and apply activations from the previous shard - for index, layer_shard in enumerate(model_instance.model_slices): - # Bring in the current activations onto the device - model_instance._activations[index] = model_instance._activations[index].to( - "cuda") - # Bring in the current layer shard onto the device - layer_shard = layer_shard.to("cuda") - # Run the FP on the layer shard. - inter_activations = layer_shard(model_instance._activations[index]) - # Move the activations to the CPU - inter_activations = inter_activations.to("cpu") - # Apply the FP and store the activations on the CPU. - model_instance._activations.append(inter_activations) - # Move the layer shard back to the CPU - layer_shard = layer_shard.to("cpu") - return model_instance._activations[-1] - + inputs = inputs[0].clone() * 2 + return inputs + @staticmethod def backward(ctx, grad_output): """ @@ -59,30 +44,33 @@ def backward(ctx, grad_output): with respect to the output, and we need to compute the gradient of the loss with respect to the input. """ + print("BACKWARD....") _, model_instance = ctx.save_for_backward - final_grads = grad_output.to("cuda") - all_grads = [final_grads] - # reverse the model shards and iterate through them. - for model_shard, activation in zip(reverse(model_instance.model_slices), reverse(model_instance._activations[:-1])): - # move the activation to the device - activation = activation.to("cuda") - # move the model shard to the device - model_shard = model_shard.to("cuda") - # calculate the output of the last shard wrt to the stored activation at the slice boundary. - output = model_shard(activation) - # Get the last gradient calculation - final_grads = all_grads[-1] - # calculate the gradient wrt to the output on the CPU - output.backward(final_grads) - - # Move activation back to the GPU - activation = activation.to("cpu") - # move the shard back to the cpu - model_shard = model_shard.to("cpu") - # Append the list of grads to the all_grads list and this should be on the CPU - all_grads.append(activation) - - return all_grads[-1] + with torch.enable_grad(): + final_grads = grad_output.to("cuda") + all_grads = [final_grads] + # reverse the model shards and iterate through them + # calculate the gradients as you go along + for model_shard, activation in zip(reverse(model_instance.model_slices), reverse(model_instance._activations[:-1])): + # move the activation to the device + activation = activation.to("cuda") + # move the model shard to the device + model_shard = model_shard.to("cuda") + # calculate the output of the last shard wrt to the stored activation at the slice boundary. 
+ output = model_shard(activation) + # Get the last gradient calculation + final_grads = all_grads[-1] + # calculate the gradient wrt to the output on the CPU + output.backward(final_grads, ) + # Move activation back to the GPU + activation = activation.to("cpu") + # Append the list of grads to the all_grads list and this should be on the CPU + all_grads.append(activation.grad.to("cpu")) + # move the shard back to the cpu + model_shard = model_shard.to("cpu") + print("all_grads ", all_grads) + return all_grads[-1] + def _split(modules: nn.Sequential, number_splits: int) -> List[List[nn.Module]]: number_splits = min(len(modules), number_splits) @@ -129,9 +117,11 @@ class ModelShard(nn.Module): def __init__( self, cpu_model_shard: nn.Module, device: torch.device, offload_device: torch.device, + index: int, ): super().__init__() self.model_shard = cpu_model_shard + self.index = index # Save all the parameter sizes to be able to restore them self.device = device @@ -194,33 +184,50 @@ class ShardSyncLayer(torch.autograd.Function): """ @staticmethod - def forward(ctx: Any, layer_window: Tuple[ModelShard, ModelShard, ModelShard, ModelShard], inputs: Any) -> Any: # type: ignore - # Drop the shard we just went through, except if this is the last one in line - if layer_window[1] and layer_window[2]: - layer_window[1].forward_drop(non_blocking=True) + def forward(ctx: Any, inputs: Any, index: int, model_slices: Any) -> Any: # type: ignore + drop_index = index + load_index = index + 1 + max_slices = len(model_slices) + + if drop_index >= 0: + # Move shard from device to offload device. + # logging.info(f"Dropping shard {drop_index}") + model_slices[drop_index].forward_drop() - # Start the load of the next shard in line, opportunistically look ahead - if layer_window[3]: - layer_window[3].forward_load(non_blocking=True) + if load_index < max_slices: + # Load shard from offload device to device. + # logging.info(f"Loading shard{load_index}") + model_slices[load_index].forward_load() - ctx.layer_window = layer_window + ctx.inputs = inputs + ctx.index = index + ctx.model_slices = model_slices return inputs if isinstance(inputs, tuple) else (inputs,) @staticmethod def backward(ctx, *grad_outputs): # type: ignore - if ctx.layer_window[3]: - ctx.layer_window[3].backward_drop(non_blocking=True) - # Opportunistically pre-load ahead of the compute wavefront - if ctx.layer_window[1]: - ctx.layer_window[1].backward_load(non_blocking=True) + load_index = ctx.index + drop_index = load_index + 1 + model_slices = ctx.model_slices + if drop_index < len(model_slices): + # Move shard from device to offload device. + # logging.info(f"Backward Dropping shard {drop_index}") + model_slices[drop_index].backward_drop() + + if load_index >= 0: + # Load shard from offload device to device. + # logging.info(f"Backward Loading shard{load_index}") + model_slices[load_index].backward_load() + # The returned variables need to mirror the forward inputs + # TODO(anj-s): Why do we need to do this? 
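+        # torch.autograd.Function.backward must return one value per argument passed
+        # to forward() (here: inputs, index, model_slices); only `inputs` carries a
+        # gradient, so the remaining positions are filled with None.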
if isinstance(grad_outputs, tuple): - return None, grad_outputs[0] + return grad_outputs[0], None, None - return None, grad_outputs + return grad_outputs, None, None class OffloadWrapperExperimental(nn.Module): @@ -266,33 +273,30 @@ def __init__( # Each rank either owns the slice, or temporarily helps processing it in a data parallel fashion self.model_slices: List[nn.Module] = [] - for split in splits: + for i, split in enumerate(splits): # Add one model handling this slice self.model_slices.append( - ModelShard(cpu_model_shard=nn.Sequential(*split), device=device, offload_device=offload_device,) + ModelShard(cpu_model_shard=nn.Sequential(*split), device=device, offload_device=offload_device,index=i,) ) # Expose a unified view of the slices self.model = torch.nn.Sequential(*self.model_slices) - def _alternate_forward(self, *inputs: Any, **_: Any) -> Any: + # intermediate actiavtions + self._activations = [] + + + def forward(self, *inputs: Any, **_: Any) -> Any: # Slice per slice FW, sync in between syncRanks = ShardSyncLayer.apply - # TODO: Rewrite this and make it more flexible, this is ugly - for (p2, p1, n1, n2) in zip( - [None, None, *self.model_slices], - [None, *self.model_slices], - [*self.model_slices, None], - [*self.model_slices, None, None], - ): - # Per shard FW - inputs = p1(*inputs) if p1 else inputs - + self._activations.append(inputs) + for index in range(-1, len(self.model_slices)): + if index >= 0: + inputs = self.model_slices[index](*inputs) + else: + inputs = inputs # Call the custom autograd hooks (discard/load slices FW and BW) - inputs = syncRanks((p2, p1, n1, n2), inputs) + inputs = syncRanks(inputs, index, self.model_slices) return inputs[0] if len(inputs) == 1 else inputs - - def forward(self, *inputs: Any, **_: Any) -> Any: - # Defer to the overriden autograd function. 
- return OffloadFunction.apply(*inputs, self) + From 510de2db7f7dab70ba953404beb55578e8f85fd4 Mon Sep 17 00:00:00 2001 From: Anjali Sridhar Date: Sun, 7 Feb 2021 01:07:00 -0800 Subject: [PATCH 04/16] activation offloading and test to start with --- benchmarks/offload.py | 30 +++++---- fairscale/nn/misc/offload.py | 57 ++++++++++++----- tests/nn/misc/test_offload.py | 114 ++++++++++++++++++++++++++++++++++ 3 files changed, 175 insertions(+), 26 deletions(-) create mode 100644 tests/nn/misc/test_offload.py diff --git a/benchmarks/offload.py b/benchmarks/offload.py index bcd76315b..5e925f8f7 100755 --- a/benchmarks/offload.py +++ b/benchmarks/offload.py @@ -52,24 +52,32 @@ def train(args: argparse.Namespace): batch_size=args.batch_size, ) + def train_epoch(args): model.train() + iter_count = 2 for batch_inputs, batch_outputs in dataloader: batch_inputs, batch_outputs = batch_inputs.to("cuda"), batch_outputs.to("cuda") - + iter_count -= 1 start = time.time_ns() - optimizer.zero_grad() - inputs = batch_inputs.reshape(-1, args.inputs * args.inputs) - output = model(inputs) - # make_dot(output, dict(model.named_parameters())).render("attached_mine", format="png") - loss = criterion(output, target=batch_outputs) - # print(f"loss {loss.item()}") - loss.backward() - optimizer.step() - # break + with torch.autograd.profiler.profile(use_cuda=True, profile_memory=True) as prof: + optimizer.zero_grad() + inputs = batch_inputs.reshape(-1, args.inputs * args.inputs) + output = model(inputs) + # make_dot(output, dict(model.named_parameters())).render("attached_mine", format="png") + loss = criterion(output, target=batch_outputs) + # print(f"loss {loss.item()}") + loss.backward() + optimizer.step() + prof.export_chrome_trace("/tmp/mpi_prof") + print(prof.key_averages().table()) + + print(f"current model parameters {[p for p in model.parameters()]}") + # if iter_count == 0: + break print("Loss {:.2f} - throughput {:.2f}fps".format(loss.item(), args.batch_size / (time.time_ns() - start) * 10 ** 9)) train_epoch(args) - + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Test the CPU offload + sharding with a Transformer training") diff --git a/fairscale/nn/misc/offload.py b/fairscale/nn/misc/offload.py index aa6f4e5bf..4854d6285 100644 --- a/fairscale/nn/misc/offload.py +++ b/fairscale/nn/misc/offload.py @@ -184,24 +184,25 @@ class ShardSyncLayer(torch.autograd.Function): """ @staticmethod - def forward(ctx: Any, inputs: Any, index: int, model_slices: Any) -> Any: # type: ignore + def forward(ctx: Any, inputs: Any, index: int, model_slices: Any, model_instance: Any) -> Any: # type: ignore drop_index = index load_index = index + 1 max_slices = len(model_slices) if drop_index >= 0: # Move shard from device to offload device. - # logging.info(f"Dropping shard {drop_index}") + logging.info(f"Dropping shard {drop_index}") model_slices[drop_index].forward_drop() if load_index < max_slices: # Load shard from offload device to device. 
- # logging.info(f"Loading shard{load_index}") + logging.info(f"Loading shard{load_index}") model_slices[load_index].forward_load() ctx.inputs = inputs ctx.index = index ctx.model_slices = model_slices + ctx.model_instance = model_instance return inputs if isinstance(inputs, tuple) else (inputs,) @@ -211,23 +212,37 @@ def backward(ctx, *grad_outputs): # type: ignore load_index = ctx.index drop_index = load_index + 1 model_slices = ctx.model_slices + model_instance = ctx.model_instance + + logging.info(f"{model_instance._activations} are the current activations") + + # TODO(anj-s): Are these redundant in the backward pass? + if drop_index == len(model_slices): + # Drop the last activation since it is still on the CPU + # after the loss.backward() call. + model_instance._activations[-1] = \ + tuple([a.cuda() for a in list(model_instance._activations[-1])]) if drop_index < len(model_slices): # Move shard from device to offload device. - # logging.info(f"Backward Dropping shard {drop_index}") + logging.info(f"Backward Dropping shard {drop_index}") model_slices[drop_index].backward_drop() + model_instance._activations[drop_index] = \ + tuple([a.cpu() for a in list(model_instance._activations[drop_index])]) if load_index >= 0: # Load shard from offload device to device. - # logging.info(f"Backward Loading shard{load_index}") + logging.info(f"Backward Loading shard{load_index}") model_slices[load_index].backward_load() - + model_instance._activations[load_index] = \ + tuple([a.cuda() for a in list(model_instance._activations[load_index])]) + # The returned variables need to mirror the forward inputs # TODO(anj-s): Why do we need to do this? if isinstance(grad_outputs, tuple): - return grad_outputs[0], None, None + return grad_outputs[0], None, None, None - return grad_outputs, None, None + return grad_outputs, None, None, None class OffloadWrapperExperimental(nn.Module): @@ -289,14 +304,26 @@ def __init__( def forward(self, *inputs: Any, **_: Any) -> Any: # Slice per slice FW, sync in between syncRanks = ShardSyncLayer.apply - - self._activations.append(inputs) + self._activations = [] + # self._activations.append(inputs) for index in range(-1, len(self.model_slices)): + # print("self._activations ", len(self._activations)) if index >= 0: + # TODO(anj-s): This might be a redundant call since we have the previous + # activation on the device already. + self._activations[index] = tuple([a.cuda() for a in list(self._activations[index])]) + inputs = self._activations[index] inputs = self.model_slices[index](*inputs) - else: - inputs = inputs # Call the custom autograd hooks (discard/load slices FW and BW) - inputs = syncRanks(inputs, index, self.model_slices) - return inputs[0] if len(inputs) == 1 else inputs - + inputs = syncRanks(inputs, index, self.model_slices, self) + self._activations.append(inputs) + if index >= 0: + self._activations[index] = tuple([a.cpu() for a in list(self._activations[index])]) + + # We don't move the last activation/output since the target is present + # on the device. + # TODO(anj-s): It is now a requirement that the target tensors be placed on the + # device. + result = self._activations[-1] + return result[0] if len(result) == 1 else result + \ No newline at end of file diff --git a/tests/nn/misc/test_offload.py b/tests/nn/misc/test_offload.py new file mode 100644 index 000000000..1613ebbaf --- /dev/null +++ b/tests/nn/misc/test_offload.py @@ -0,0 +1,114 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +""" +Testing Offload Module +""" + +from contextlib import suppress +import copy +import tempfile +from typing import List + +import numpy as np +import torch +from torch.cuda.amp import GradScaler as TorchGradScaler +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.nn import Linear, Sequential +from torch.nn.parallel import DistributedDataParallel as DDP + +from fairscale.utils.testing import GPT2, skip_if_no_cuda, skip_if_py38, skip_if_single_gpu +from fairscale.nn.misc.offload import OffloadWrapperExperimental + + +def _init(): + torch.cuda.set_device(0) + torch.manual_seed(0) + np.random.seed(0) + device = torch.device("cuda") + offload_device = torch.device("cpu") + return device, offload_device + +def test_single_run(): + device, offload_device = _init() + model = _get_model() + + offload_model = OffloadWrapperExperimental( + model_cpu=model, device=device, offload_device=offload_device, n_slices=2, + ) + offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001) + + input = torch.ones(2, 2).to(device) + labels = torch.ones(2, 2).to(device) + offload_model.train() + pred = offload_model(input) + loss_fn = torch.nn.MSELoss(reduction="sum") + loss = loss_fn(pred, labels) + loss.backward() + offload_optimizer.step() + + +def _get_model(num_inputs=2, num_hidden=2, num_layers=1, num_outputs=2): + model = torch.nn.Sequential( + torch.nn.Linear(num_inputs, num_hidden), + *([torch.nn.Linear(num_hidden, num_hidden) for _ in range(num_layers)]), + torch.nn.Linear(num_hidden, num_outputs) + ).cuda() + return model + +def _check_parity(rmodel, omodel, ropt, oopt, rloss, oloss): + + for oparams, rparams in zip(omodel.parameters(), rmodel.parameters()): + assert torch.allclose( + oparams, rparams, atol=1e-2 + ), f"Model params are different {oparams} {rparams}" + + for o_pg, reg_pg in zip(oopt.param_groups, ropt.param_groups): + for o_pg, reg_pg in zip(o_pg["params"], reg_pg["params"]): + assert torch.allclose( + o_pg, reg_pg, atol=1e-2 + ), f"Model parameters differ in between Offlad and Vanilla {[o_pg]} {reg_pg}" + + for o_buf, reg_buf in zip(omodel.buffers(), rmodel.buffers()): + assert torch.allclose( + o_buf, reg_buf, atol=1e-2 + ), "Model buffers differ in between Offload and Vanilla." 
+ + + +def test_correctness(): + device, offload_device = _init() + model = _get_model().cuda() + + def train(model, optimizer): + + input = torch.ones(2, 2).to(device) + labels = torch.ones(2, 2).to(device) + model.train() + pred = model(input) + loss_fn = torch.nn.MSELoss(reduction="sum") + loss = loss_fn(pred, labels) + loss.backward() + optimizer.step() + return model, optimizer, loss + + def train_reg_model(): + reg_model = copy.deepcopy(model) + reg_optimizer = torch.optim.SGD(reg_model.parameters(), lr=0.001) + return train(reg_model, reg_optimizer) + + def train_offload_model(): + omodel = copy.deepcopy(model) + offload_model = OffloadWrapperExperimental( + model_cpu=omodel, device=device, offload_device=offload_device, n_slices=2, + ) + offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001) + return train(offload_model, offload_optimizer) + + rmodel, ropt, rloss = train_reg_model() + omodel, oopt, oloss = train_offload_model() + _check_parity(rmodel.cpu(), omodel.cpu(), ropt, oopt, rloss, oloss) + From c8df2a664452e1e09456f706bfeb487365100d1f Mon Sep 17 00:00:00 2001 From: Anjali Sridhar Date: Sun, 7 Feb 2021 01:10:07 -0800 Subject: [PATCH 05/16] fix lint errors --- benchmarks/offload.py | 13 +++++++----- fairscale/nn/misc/offload.py | 37 ++++++++++++++++++----------------- tests/nn/misc/test_offload.py | 29 ++++++++------------------- 3 files changed, 35 insertions(+), 44 deletions(-) diff --git a/benchmarks/offload.py b/benchmarks/offload.py index 5e925f8f7..4ebb1f8c3 100755 --- a/benchmarks/offload.py +++ b/benchmarks/offload.py @@ -11,7 +11,6 @@ from torch.utils.data.dataloader import DataLoader from torchvision.datasets import FakeData from torchvision.transforms import ToTensor -from torchviz import make_dot OPTIM = torch.optim.SGD LR = 1e-3 @@ -29,7 +28,7 @@ def train(args: argparse.Namespace): model = torch.nn.Sequential( torch.nn.Linear(args.inputs * args.inputs, args.hidden, bias=False), # *([torch.nn.Linear(args.hidden, args.hidden) for _ in range(args.layers)]), - torch.nn.Linear(args.hidden, args.outputs, bias=False) + torch.nn.Linear(args.hidden, args.outputs, bias=False), ).cpu() # Optim loop @@ -52,7 +51,6 @@ def train(args: argparse.Namespace): batch_size=args.batch_size, ) - def train_epoch(args): model.train() iter_count = 2 @@ -75,9 +73,14 @@ def train_epoch(args): print(f"current model parameters {[p for p in model.parameters()]}") # if iter_count == 0: break - print("Loss {:.2f} - throughput {:.2f}fps".format(loss.item(), args.batch_size / (time.time_ns() - start) * 10 ** 9)) + print( + "Loss {:.2f} - throughput {:.2f}fps".format( + loss.item(), args.batch_size / (time.time_ns() - start) * 10 ** 9 + ) + ) + train_epoch(args) - + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Test the CPU offload + sharding with a Transformer training") diff --git a/fairscale/nn/misc/offload.py b/fairscale/nn/misc/offload.py index 4854d6285..4ebfb633c 100644 --- a/fairscale/nn/misc/offload.py +++ b/fairscale/nn/misc/offload.py @@ -10,11 +10,10 @@ from builtins import isinstance import logging -from typing import Any, List, Tuple +from typing import Any, List import torch from torch import nn -from torchviz import make_dot # Custom Autograd function to carry out the backward pass @@ -36,7 +35,7 @@ def forward(ctx, inputs, model_instance): ctx.save_for_backward(inputs, model_instance) inputs = inputs[0].clone() * 2 return inputs - + @staticmethod def backward(ctx, grad_output): """ @@ -51,7 +50,9 @@ def backward(ctx, grad_output): 
all_grads = [final_grads] # reverse the model shards and iterate through them # calculate the gradients as you go along - for model_shard, activation in zip(reverse(model_instance.model_slices), reverse(model_instance._activations[:-1])): + for model_shard, activation in zip( + reverse(model_instance.model_slices), reverse(model_instance._activations[:-1]) + ): # move the activation to the device activation = activation.to("cuda") # move the model shard to the device @@ -61,7 +62,7 @@ def backward(ctx, grad_output): # Get the last gradient calculation final_grads = all_grads[-1] # calculate the gradient wrt to the output on the CPU - output.backward(final_grads, ) + output.backward(final_grads,) # Move activation back to the GPU activation = activation.to("cpu") # Append the list of grads to the all_grads list and this should be on the CPU @@ -116,8 +117,7 @@ class ModelShard(nn.Module): """ def __init__( - self, cpu_model_shard: nn.Module, device: torch.device, offload_device: torch.device, - index: int, + self, cpu_model_shard: nn.Module, device: torch.device, offload_device: torch.device, index: int, ): super().__init__() self.model_shard = cpu_model_shard @@ -214,28 +214,29 @@ def backward(ctx, *grad_outputs): # type: ignore model_slices = ctx.model_slices model_instance = ctx.model_instance - logging.info(f"{model_instance._activations} are the current activations") + logging.info(f"{model_instance._activations} are the current activations") # TODO(anj-s): Are these redundant in the backward pass? if drop_index == len(model_slices): # Drop the last activation since it is still on the CPU # after the loss.backward() call. - model_instance._activations[-1] = \ - tuple([a.cuda() for a in list(model_instance._activations[-1])]) + model_instance._activations[-1] = tuple([a.cuda() for a in list(model_instance._activations[-1])]) if drop_index < len(model_slices): # Move shard from device to offload device. logging.info(f"Backward Dropping shard {drop_index}") model_slices[drop_index].backward_drop() - model_instance._activations[drop_index] = \ - tuple([a.cpu() for a in list(model_instance._activations[drop_index])]) + model_instance._activations[drop_index] = tuple( + [a.cpu() for a in list(model_instance._activations[drop_index])] + ) if load_index >= 0: # Load shard from offload device to device. logging.info(f"Backward Loading shard{load_index}") model_slices[load_index].backward_load() - model_instance._activations[load_index] = \ - tuple([a.cuda() for a in list(model_instance._activations[load_index])]) + model_instance._activations[load_index] = tuple( + [a.cuda() for a in list(model_instance._activations[load_index])] + ) # The returned variables need to mirror the forward inputs # TODO(anj-s): Why do we need to do this? @@ -291,7 +292,9 @@ def __init__( for i, split in enumerate(splits): # Add one model handling this slice self.model_slices.append( - ModelShard(cpu_model_shard=nn.Sequential(*split), device=device, offload_device=offload_device,index=i,) + ModelShard( + cpu_model_shard=nn.Sequential(*split), device=device, offload_device=offload_device, index=i, + ) ) # Expose a unified view of the slices @@ -300,7 +303,6 @@ def __init__( # intermediate actiavtions self._activations = [] - def forward(self, *inputs: Any, **_: Any) -> Any: # Slice per slice FW, sync in between syncRanks = ShardSyncLayer.apply @@ -325,5 +327,4 @@ def forward(self, *inputs: Any, **_: Any) -> Any: # TODO(anj-s): It is now a requirement that the target tensors be placed on the # device. 
result = self._activations[-1] - return result[0] if len(result) == 1 else result - \ No newline at end of file + return result[0] if len(result) == 1 else result diff --git a/tests/nn/misc/test_offload.py b/tests/nn/misc/test_offload.py index 1613ebbaf..5d1cf1642 100644 --- a/tests/nn/misc/test_offload.py +++ b/tests/nn/misc/test_offload.py @@ -7,20 +7,11 @@ Testing Offload Module """ -from contextlib import suppress import copy -import tempfile -from typing import List import numpy as np import torch -from torch.cuda.amp import GradScaler as TorchGradScaler -import torch.distributed as dist -import torch.multiprocessing as mp -from torch.nn import Linear, Sequential -from torch.nn.parallel import DistributedDataParallel as DDP -from fairscale.utils.testing import GPT2, skip_if_no_cuda, skip_if_py38, skip_if_single_gpu from fairscale.nn.misc.offload import OffloadWrapperExperimental @@ -32,10 +23,11 @@ def _init(): offload_device = torch.device("cpu") return device, offload_device + def test_single_run(): device, offload_device = _init() model = _get_model() - + offload_model = OffloadWrapperExperimental( model_cpu=model, device=device, offload_device=offload_device, n_slices=2, ) @@ -55,16 +47,15 @@ def _get_model(num_inputs=2, num_hidden=2, num_layers=1, num_outputs=2): model = torch.nn.Sequential( torch.nn.Linear(num_inputs, num_hidden), *([torch.nn.Linear(num_hidden, num_hidden) for _ in range(num_layers)]), - torch.nn.Linear(num_hidden, num_outputs) + torch.nn.Linear(num_hidden, num_outputs), ).cuda() return model + def _check_parity(rmodel, omodel, ropt, oopt, rloss, oloss): for oparams, rparams in zip(omodel.parameters(), rmodel.parameters()): - assert torch.allclose( - oparams, rparams, atol=1e-2 - ), f"Model params are different {oparams} {rparams}" + assert torch.allclose(oparams, rparams, atol=1e-2), f"Model params are different {oparams} {rparams}" for o_pg, reg_pg in zip(oopt.param_groups, ropt.param_groups): for o_pg, reg_pg in zip(o_pg["params"], reg_pg["params"]): @@ -73,16 +64,13 @@ def _check_parity(rmodel, omodel, ropt, oopt, rloss, oloss): ), f"Model parameters differ in between Offlad and Vanilla {[o_pg]} {reg_pg}" for o_buf, reg_buf in zip(omodel.buffers(), rmodel.buffers()): - assert torch.allclose( - o_buf, reg_buf, atol=1e-2 - ), "Model buffers differ in between Offload and Vanilla." + assert torch.allclose(o_buf, reg_buf, atol=1e-2), "Model buffers differ in between Offload and Vanilla." 
- def test_correctness(): device, offload_device = _init() model = _get_model().cuda() - + def train(model, optimizer): input = torch.ones(2, 2).to(device) @@ -107,8 +95,7 @@ def train_offload_model(): ) offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001) return train(offload_model, offload_optimizer) - + rmodel, ropt, rloss = train_reg_model() omodel, oopt, oloss = train_offload_model() _check_parity(rmodel.cpu(), omodel.cpu(), ropt, oopt, rloss, oloss) - From 6d53152460dff162cfce93653179446e6e9256d4 Mon Sep 17 00:00:00 2001 From: Anjali Sridhar Date: Sun, 7 Feb 2021 01:20:25 -0800 Subject: [PATCH 06/16] update comments --- benchmarks/offload.py | 15 +++---- fairscale/nn/misc/offload.py | 82 +++++------------------------------- 2 files changed, 16 insertions(+), 81 deletions(-) diff --git a/benchmarks/offload.py b/benchmarks/offload.py index 4ebb1f8c3..dbff68599 100755 --- a/benchmarks/offload.py +++ b/benchmarks/offload.py @@ -27,7 +27,7 @@ def train(args: argparse.Namespace): # Setup the problem model = torch.nn.Sequential( torch.nn.Linear(args.inputs * args.inputs, args.hidden, bias=False), - # *([torch.nn.Linear(args.hidden, args.hidden) for _ in range(args.layers)]), + *([torch.nn.Linear(args.hidden, args.hidden) for _ in range(args.layers)]), torch.nn.Linear(args.hidden, args.outputs, bias=False), ).cpu() @@ -70,9 +70,6 @@ def train_epoch(args): prof.export_chrome_trace("/tmp/mpi_prof") print(prof.key_averages().table()) - print(f"current model parameters {[p for p in model.parameters()]}") - # if iter_count == 0: - break print( "Loss {:.2f} - throughput {:.2f}fps".format( loss.item(), args.batch_size / (time.time_ns() - start) * 10 ** 9 @@ -86,13 +83,13 @@ def train_epoch(args): parser = argparse.ArgumentParser(description="Test the CPU offload + sharding with a Transformer training") parser.add_argument("--epochs", action="store", default=1, type=int) parser.add_argument("--batch_size", action="store", default=1, type=int) - parser.add_argument("--inputs", action="store", help="The dimension of the inputs", default=2, type=int) - parser.add_argument("--hidden", action="store", help="The dimension of the hidden state", default=2, type=int) - parser.add_argument("--layers", action="store", help="he number of hidden layers", default=1, type=int) - parser.add_argument("--outputs", action="store", help="The number of predicted classes", default=2, type=int) + parser.add_argument("--inputs", action="store", help="The dimension of the inputs", default=100, type=int) + parser.add_argument("--hidden", action="store", help="The dimension of the hidden state", default=10000, type=int) + parser.add_argument("--layers", action="store", help="he number of hidden layers", default=20, type=int) + parser.add_argument("--outputs", action="store", help="The number of predicted classes", default=5, type=int) parser.add_argument("--offload", action="store_true", default=False) - parser.add_argument("--slices", action="store", default=2, type=int) + parser.add_argument("--slices", action="store", default=3, type=int) args = parser.parse_args() diff --git a/fairscale/nn/misc/offload.py b/fairscale/nn/misc/offload.py index 4ebfb633c..5a41f3f2c 100644 --- a/fairscale/nn/misc/offload.py +++ b/fairscale/nn/misc/offload.py @@ -16,63 +16,6 @@ from torch import nn -# Custom Autograd function to carry out the backward pass -class OffloadBackwardFunction(torch.autograd.Function): - """ - We can implement our own custom autograd Functions by subclassing - torch.autograd.Function and 
implementing the forward and backward passes - which operate on Tensors. - """ - - @staticmethod - def forward(ctx, inputs, model_instance): - """ - In the forward pass we receive a Tensor containing the input and return - a Tensor containing the output. ctx is a context object that can be used - to stash information for backward computation. You can cache arbitrary - objects for use in the backward pass using the ctx.save_for_backward method. - """ - ctx.save_for_backward(inputs, model_instance) - inputs = inputs[0].clone() * 2 - return inputs - - @staticmethod - def backward(ctx, grad_output): - """ - In the backward pass we receive a Tensor containing the gradient of the loss - with respect to the output, and we need to compute the gradient of the loss - with respect to the input. - """ - print("BACKWARD....") - _, model_instance = ctx.save_for_backward - with torch.enable_grad(): - final_grads = grad_output.to("cuda") - all_grads = [final_grads] - # reverse the model shards and iterate through them - # calculate the gradients as you go along - for model_shard, activation in zip( - reverse(model_instance.model_slices), reverse(model_instance._activations[:-1]) - ): - # move the activation to the device - activation = activation.to("cuda") - # move the model shard to the device - model_shard = model_shard.to("cuda") - # calculate the output of the last shard wrt to the stored activation at the slice boundary. - output = model_shard(activation) - # Get the last gradient calculation - final_grads = all_grads[-1] - # calculate the gradient wrt to the output on the CPU - output.backward(final_grads,) - # Move activation back to the GPU - activation = activation.to("cpu") - # Append the list of grads to the all_grads list and this should be on the CPU - all_grads.append(activation.grad.to("cpu")) - # move the shard back to the cpu - model_shard = model_shard.to("cpu") - print("all_grads ", all_grads) - return all_grads[-1] - - def _split(modules: nn.Sequential, number_splits: int) -> List[List[nn.Module]]: number_splits = min(len(modules), number_splits) splits: List[List[nn.Module]] = [[] for _ in range(number_splits)] @@ -111,9 +54,8 @@ def _split(modules: nn.Sequential, number_splits: int) -> List[List[nn.Module]]: class ModelShard(nn.Module): """ - Wrap one shard of the model, make it possible to load parameters on the fly for the FW pass and gather gradients. - Depending on whether this rank is or is not the `owner_rank`, this ModelShard either only handles - a shard of the compute and is stateless or also owns the up to date state. + Wrap one shard of the model, make it possible to load parameters on the + fly for the FW and BW pass on the given device. """ def __init__( @@ -214,8 +156,6 @@ def backward(ctx, *grad_outputs): # type: ignore model_slices = ctx.model_slices model_instance = ctx.model_instance - logging.info(f"{model_instance._activations} are the current activations") - # TODO(anj-s): Are these redundant in the backward pass? if drop_index == len(model_slices): # Drop the last activation since it is still on the CPU @@ -253,8 +193,7 @@ class OffloadWrapperExperimental(nn.Module): The model is sharded, then the normal distributed data parallel algorithm can be used on a per-model shard basis. Each shard is offloaded and loaded following a compute wavefront, during the forward and backward pass. - All the gradients are centralized on a given rank (which is model-shard dependent, so that the gradients - redundancy can be removed). 
Each model shard can be updated by a normal pytorch optimizer. + Each model shard can be updated by a normal pytorch optimizer. Args: module (~torch.nn.Sequential): module to be parallelized @@ -283,17 +222,19 @@ def __init__( self.device = device self.offload_device = offload_device - # Slice the model into roughly equivalent sequential shards + # Slice the model into roughly equivalent sequential shards. splits = _split(model_cpu, n_slices) - # Each rank either owns the slice, or temporarily helps processing it in a data parallel fashion + # List of model shards that will be placed on/off the device. self.model_slices: List[nn.Module] = [] for i, split in enumerate(splits): # Add one model handling this slice self.model_slices.append( ModelShard( - cpu_model_shard=nn.Sequential(*split), device=device, offload_device=offload_device, index=i, + cpu_model_shard=nn.Sequential(*split), device=device, + offload_device=offload_device, + index=i, ) ) @@ -304,12 +245,9 @@ def __init__( self._activations = [] def forward(self, *inputs: Any, **_: Any) -> Any: - # Slice per slice FW, sync in between - syncRanks = ShardSyncLayer.apply + shardSync = ShardSyncLayer.apply self._activations = [] - # self._activations.append(inputs) for index in range(-1, len(self.model_slices)): - # print("self._activations ", len(self._activations)) if index >= 0: # TODO(anj-s): This might be a redundant call since we have the previous # activation on the device already. @@ -317,7 +255,7 @@ def forward(self, *inputs: Any, **_: Any) -> Any: inputs = self._activations[index] inputs = self.model_slices[index](*inputs) # Call the custom autograd hooks (discard/load slices FW and BW) - inputs = syncRanks(inputs, index, self.model_slices, self) + inputs = shardSync(inputs, index, self.model_slices, self) self._activations.append(inputs) if index >= 0: self._activations[index] = tuple([a.cpu() for a in list(self._activations[index])]) From 8bee0cc8871b417694ac0e2cf473ec951057bab2 Mon Sep 17 00:00:00 2001 From: Anjali Sridhar Date: Sun, 7 Feb 2021 01:20:52 -0800 Subject: [PATCH 07/16] fix lint --- fairscale/nn/misc/offload.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/fairscale/nn/misc/offload.py b/fairscale/nn/misc/offload.py index 5a41f3f2c..0c5104a47 100644 --- a/fairscale/nn/misc/offload.py +++ b/fairscale/nn/misc/offload.py @@ -232,9 +232,7 @@ def __init__( # Add one model handling this slice self.model_slices.append( ModelShard( - cpu_model_shard=nn.Sequential(*split), device=device, - offload_device=offload_device, - index=i, + cpu_model_shard=nn.Sequential(*split), device=device, offload_device=offload_device, index=i, ) ) From a825e96bda4e31c2341440cfee2a5e994ef0f527 Mon Sep 17 00:00:00 2001 From: Anjali Sridhar Date: Sun, 7 Feb 2021 01:22:12 -0800 Subject: [PATCH 08/16] remove unused var --- benchmarks/offload.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/benchmarks/offload.py b/benchmarks/offload.py index dbff68599..ce265371e 100755 --- a/benchmarks/offload.py +++ b/benchmarks/offload.py @@ -26,9 +26,9 @@ def train(args: argparse.Namespace): # Setup the problem model = torch.nn.Sequential( - torch.nn.Linear(args.inputs * args.inputs, args.hidden, bias=False), + torch.nn.Linear(args.inputs * args.inputs, args.hidden), *([torch.nn.Linear(args.hidden, args.hidden) for _ in range(args.layers)]), - torch.nn.Linear(args.hidden, args.outputs, bias=False), + torch.nn.Linear(args.hidden, args.outputs), ).cpu() # Optim loop @@ -56,15 +56,12 @@ def train_epoch(args): 
iter_count = 2 for batch_inputs, batch_outputs in dataloader: batch_inputs, batch_outputs = batch_inputs.to("cuda"), batch_outputs.to("cuda") - iter_count -= 1 start = time.time_ns() with torch.autograd.profiler.profile(use_cuda=True, profile_memory=True) as prof: optimizer.zero_grad() inputs = batch_inputs.reshape(-1, args.inputs * args.inputs) output = model(inputs) - # make_dot(output, dict(model.named_parameters())).render("attached_mine", format="png") loss = criterion(output, target=batch_outputs) - # print(f"loss {loss.item()}") loss.backward() optimizer.step() prof.export_chrome_trace("/tmp/mpi_prof") From 29c8c3413383c45a09a2b5e7b4e0a0c27f0dfcda Mon Sep 17 00:00:00 2001 From: Anjali Sridhar Date: Sun, 7 Feb 2021 01:23:21 -0800 Subject: [PATCH 09/16] remove commented out lines --- fairscale/nn/misc/offload.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fairscale/nn/misc/offload.py b/fairscale/nn/misc/offload.py index 0c5104a47..2fa35cfc8 100644 --- a/fairscale/nn/misc/offload.py +++ b/fairscale/nn/misc/offload.py @@ -46,8 +46,7 @@ def _split(modules: nn.Sequential, number_splits: int) -> List[List[nn.Module]]: for i, split in enumerate(splits): current_shard_params = sum(p.numel() for sm in split for p in sm.parameters()) - # logging.info(f"Shard {i} holds {current_shard_params/1e6:.2f}M parameters") - logging.info(f"Shard {i} holds {current_shard_params:.2f} parameters") + logging.info(f"Shard {i} holds {current_shard_params/1e6:.2f}M parameters") return splits From a5d1c882a44d9efb68dfe02974732ccc26d03e9c Mon Sep 17 00:00:00 2001 From: Anjali Sridhar Date: Mon, 8 Feb 2021 12:31:33 -0800 Subject: [PATCH 10/16] modify name --- fairscale/nn/__init__.py | 2 +- fairscale/nn/misc/__init__.py | 2 +- fairscale/nn/misc/offload.py | 16 +++++++++------- tests/nn/misc/test_offload.py | 6 +++--- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/fairscale/nn/__init__.py b/fairscale/nn/__init__.py index 63e23d363..5e7c23fbe 100644 --- a/fairscale/nn/__init__.py +++ b/fairscale/nn/__init__.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. from .data_parallel import ShardedDataParallel -from .misc import FlattenParamsWrapper, OffloadWrapperExperimental +from .misc import FlattenParamsWrapper, OffloadModel from .moe import MOELayer, Top2Gate from .pipe import Pipe, PipeRPCWrapper diff --git a/fairscale/nn/misc/__init__.py b/fairscale/nn/misc/__init__.py index 2c4ed672a..45584a4af 100644 --- a/fairscale/nn/misc/__init__.py +++ b/fairscale/nn/misc/__init__.py @@ -4,4 +4,4 @@ # LICENSE file in the root directory of this source tree. from .flatten_params_wrapper import FlattenParamsWrapper -from .offload import OffloadWrapperExperimental +from .offload import OffloadModel diff --git a/fairscale/nn/misc/offload.py b/fairscale/nn/misc/offload.py index 2fa35cfc8..212c56b42 100644 --- a/fairscale/nn/misc/offload.py +++ b/fairscale/nn/misc/offload.py @@ -165,17 +165,19 @@ def backward(ctx, *grad_outputs): # type: ignore # Move shard from device to offload device. logging.info(f"Backward Dropping shard {drop_index}") model_slices[drop_index].backward_drop() - model_instance._activations[drop_index] = tuple( - [a.cpu() for a in list(model_instance._activations[drop_index])] - ) + with torch.autograd.profiler.record_function("backward:single_slice_to_cpu"): + model_instance._activations[drop_index] = tuple( + [a.cpu() for a in list(model_instance._activations[drop_index])] + ) if load_index >= 0: # Load shard from offload device to device. 
             logging.info(f"Backward Loading shard{load_index}")
             model_slices[load_index].backward_load()
-            model_instance._activations[load_index] = tuple(
-                [a.cuda() for a in list(model_instance._activations[load_index])]
-            )
+            with torch.autograd.profiler.record_function("single_slice_to_cuda"):
+                model_instance._activations[load_index] = tuple(
+                    [a.cuda() for a in list(model_instance._activations[load_index])]
+                )

         # The returned variables need to mirror the forward inputs
         # TODO(anj-s): Why do we need to do this?
@@ -185,7 +187,7 @@ def backward(ctx, *grad_outputs):  # type: ignore
         return grad_outputs, None, None, None

-class OffloadWrapperExperimental(nn.Module):
+class OffloadModel(nn.Module):
     """Implements training with optimizer state sharding and model sharding.

     This experiments with a different way to get to the full zero suite

diff --git a/tests/nn/misc/test_offload.py b/tests/nn/misc/test_offload.py
index 5d1cf1642..5f21fa1a7 100644
--- a/tests/nn/misc/test_offload.py
+++ b/tests/nn/misc/test_offload.py
@@ -12,7 +12,7 @@
 import numpy as np
 import torch

-from fairscale.nn.misc.offload import OffloadWrapperExperimental
+from fairscale.nn.misc.offload import OffloadModel

 def _init():
@@ -28,7 +28,7 @@ def test_single_run():
     device, offload_device = _init()
     model = _get_model()

-    offload_model = OffloadWrapperExperimental(
+    offload_model = OffloadModel(
         model_cpu=model, device=device, offload_device=offload_device, n_slices=2,
     )
     offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)
@@ -90,7 +90,7 @@ def train_reg_model():

     def train_offload_model():
         omodel = copy.deepcopy(model)
-        offload_model = OffloadWrapperExperimental(
+        offload_model = OffloadModel(
             model_cpu=omodel, device=device, offload_device=offload_device, n_slices=2,
         )
         offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)
         return train(offload_model, offload_optimizer)

From 48e3e250b24d3677ea9af4650fdc5a482766d41f Mon Sep 17 00:00:00 2001
From: Anjali Sridhar
Date: Mon, 8 Feb 2021 12:33:20 -0800
Subject: [PATCH 11/16] remove break

---
 benchmarks/offload.py         | 19 ++++++++++---------
 tests/nn/misc/test_offload.py |  8 ++------
 2 files changed, 12 insertions(+), 15 deletions(-)

diff --git a/benchmarks/offload.py b/benchmarks/offload.py
index ce265371e..2f2d6679b 100755
--- a/benchmarks/offload.py
+++ b/benchmarks/offload.py
@@ -15,7 +15,7 @@
 OPTIM = torch.optim.SGD
 LR = 1e-3

-from fairscale.nn.misc.offload import OffloadWrapperExperimental
+from fairscale.nn.misc.offload import OffloadModel

 def train(args: argparse.Namespace):
@@ -35,7 +35,7 @@ def train(args: argparse.Namespace):
     criterion = nn.CrossEntropyLoss()
     if args.offload:
         logging.info("Using sharded offloading for training")
-        model = OffloadWrapperExperimental(
+        model = OffloadModel(
             model_cpu=model, device=device, offload_device=torch.device("cpu"), n_slices=args.slices,
         )  # type: ignore
@@ -60,14 +60,15 @@ def train_epoch(args):
             with torch.autograd.profiler.profile(use_cuda=True, profile_memory=True) as prof:
                 optimizer.zero_grad()
                 inputs = batch_inputs.reshape(-1, args.inputs * args.inputs)
-                output = model(inputs)
-                loss = criterion(output, target=batch_outputs)
-                loss.backward()
-                optimizer.step()
+                with torch.autograd.profiler.record_function("model_training"):
+                    output = model(inputs)
+                    loss = criterion(output, target=batch_outputs)
+                    loss.backward()
+                    optimizer.step()
             prof.export_chrome_trace("/tmp/mpi_prof")
-            print(prof.key_averages().table())
-
-            print(
+            logging.info(f"Memory table {prof.key_averages().table()}")
+            logging.info("Memory stats are " + str(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] / 2 ** 30))
+            logging.info(
                 "Loss {:.2f} - throughput {:.2f}fps".format(
                     loss.item(), args.batch_size / (time.time_ns() - start) * 10 ** 9
                 )
             )

diff --git a/tests/nn/misc/test_offload.py b/tests/nn/misc/test_offload.py
index 5f21fa1a7..858cf7d4e 100644
--- a/tests/nn/misc/test_offload.py
+++ b/tests/nn/misc/test_offload.py
@@ -28,9 +28,7 @@ def test_single_run():
     device, offload_device = _init()
     model = _get_model()

-    offload_model = OffloadModel(
-        model_cpu=model, device=device, offload_device=offload_device, n_slices=2,
-    )
+    offload_model = OffloadModel(model_cpu=model, device=device, offload_device=offload_device, n_slices=2,)
     offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)

     input = torch.ones(2, 2).to(device)
@@ -90,9 +90,7 @@ def train_reg_model():

     def train_offload_model():
         omodel = copy.deepcopy(model)
-        offload_model = OffloadModel(
-            model_cpu=omodel, device=device, offload_device=offload_device, n_slices=2,
-        )
+        offload_model = OffloadModel(model_cpu=omodel, device=device, offload_device=offload_device, n_slices=2,)
         offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)
         return train(offload_model, offload_optimizer)

From 3befc5c1e2f094e5420212176bb71555ce58f3e1 Mon Sep 17 00:00:00 2001
From: Anjali Sridhar
Date: Mon, 8 Feb 2021 13:20:46 -0800
Subject: [PATCH 12/16] remove profiler comments

---
 fairscale/nn/misc/offload.py | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/fairscale/nn/misc/offload.py b/fairscale/nn/misc/offload.py
index 212c56b42..faef4990c 100644
--- a/fairscale/nn/misc/offload.py
+++ b/fairscale/nn/misc/offload.py
@@ -165,19 +165,17 @@ def backward(ctx, *grad_outputs):  # type: ignore
             # Move shard from device to offload device.
             logging.info(f"Backward Dropping shard {drop_index}")
             model_slices[drop_index].backward_drop()
-            with torch.autograd.profiler.record_function("backward:single_slice_to_cpu"):
-                model_instance._activations[drop_index] = tuple(
-                    [a.cpu() for a in list(model_instance._activations[drop_index])]
-                )
+            model_instance._activations[drop_index] = tuple(
+                [a.cpu() for a in list(model_instance._activations[drop_index])]
+            )

         if load_index >= 0:
             # Load shard from offload device to device.
             logging.info(f"Backward Loading shard{load_index}")
             model_slices[load_index].backward_load()
-            with torch.autograd.profiler.record_function("single_slice_to_cuda"):
-                model_instance._activations[load_index] = tuple(
-                    [a.cuda() for a in list(model_instance._activations[load_index])]
-                )
+            model_instance._activations[load_index] = tuple(
+                [a.cuda() for a in list(model_instance._activations[load_index])]
+            )

         # The returned variables need to mirror the forward inputs
         # TODO(anj-s): Why do we need to do this?
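Note on usage: at this point in the series the wrapper is built from a CPU-resident nn.Sequential and moves one shard at a time onto the GPU during the forward and backward passes. The sketch below mirrors how the unit tests in tests/nn/misc/test_offload.py exercise it; the layer sizes, tensor shapes and learning rate are illustrative only, and a CUDA device is assumed.

    import torch
    import torch.nn as nn

    from fairscale.nn.misc.offload import OffloadModel

    # The model stays on the CPU; OffloadModel splits it into n_slices shards
    # and swaps a single shard onto the GPU at a time during forward/backward.
    model = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2), nn.Linear(2, 2))
    offload_model = OffloadModel(
        model_cpu=model,
        device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        n_slices=2,
    )

    optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)
    loss_fn = torch.nn.MSELoss(reduction="sum")

    inputs = torch.ones(2, 2).to("cuda")
    labels = torch.ones(2, 2).to("cuda")

    offload_model.train()
    optimizer.zero_grad()
    loss = loss_fn(offload_model(inputs), labels)
    loss.backward()
    optimizer.step()
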
From e10a874a535d7a06651c6294753c8e4d8d2638b0 Mon Sep 17 00:00:00 2001
From: Anjali Sridhar
Date: Tue, 9 Feb 2021 12:39:02 -0800
Subject: [PATCH 13/16] add support for fp16

---
 benchmarks/offload.py        | 53 +++++++++++++++++++++++++++++-------
 fairscale/nn/misc/offload.py |  3 ++
 2 files changed, 46 insertions(+), 10 deletions(-)

diff --git a/benchmarks/offload.py b/benchmarks/offload.py
index 2f2d6679b..dd3b6e14d 100755
--- a/benchmarks/offload.py
+++ b/benchmarks/offload.py
@@ -5,12 +5,13 @@
 import argparse
 import logging
 import time
-
+import contextlib
 import torch
 import torch.nn as nn
 from torch.utils.data.dataloader import DataLoader
 from torchvision.datasets import FakeData
 from torchvision.transforms import ToTensor
+from torch.cuda.amp import autocast

 OPTIM = torch.optim.SGD
 LR = 1e-3
@@ -18,6 +19,34 @@
 from fairscale.nn.misc.offload import OffloadModel

+def _get_fp16_context(use_fp16=False):
+    if use_fp16:
+        return torch.cuda.amp.autocast()
+    else:
+        return contextlib.nullcontext()
+
+
+def _get_fp16_context(use_fp16=False):
+    if use_fp16:
+        return torch.cuda.amp.autocast()
+    else:
+        return contextlib.nullcontext()
+
+
+def _get_profiler_context(use_profiler=False):
+    if use_profiler:
+        return torch.autograd.profiler.profile(use_cuda=True, profile_memory=True)
+    else:
+        return contextlib.nullcontext()
+
+
+def _get_profiler_record_context(record_name, use_profiler=False):
+    if use_profiler:
+        return torch.autograd.profiler.record_function(record_name)
+    else:
+        return contextlib.nullcontext()
+
+
 def train(args: argparse.Namespace):
     logging.basicConfig(level=logging.INFO)
     device = torch.device("cuda")
@@ -57,22 +86,24 @@ def train_epoch(args):
         for batch_inputs, batch_outputs in dataloader:
             batch_inputs, batch_outputs = batch_inputs.to("cuda"), batch_outputs.to("cuda")
             start = time.time_ns()
-            with torch.autograd.profiler.profile(use_cuda=True, profile_memory=True) as prof:
+            with _get_profiler_context() as prof:
                 optimizer.zero_grad()
                 inputs = batch_inputs.reshape(-1, args.inputs * args.inputs)
-                with torch.autograd.profiler.record_function("model_training"):
-                    output = model(inputs)
-                    loss = criterion(output, target=batch_outputs)
-                    loss.backward()
+                with _get_profiler_record_context("model_training"):
+                    with _get_fp16_context(use_fp16=args.use_fp16):
+                        output = model(inputs)
+                        loss = criterion(output, target=batch_outputs)
+                        loss.backward()
                     optimizer.step()
-            prof.export_chrome_trace("/tmp/mpi_prof")
-            logging.info(f"Memory table {prof.key_averages().table()}")
-            logging.info("Memory stats are " + str(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] / 2 ** 30))
+            logging.info("Memory stats are {:.2f}GB".format(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] / 2 ** 30))
             logging.info(
                 "Loss {:.2f} - throughput {:.2f}fps".format(
                     loss.item(), args.batch_size / (time.time_ns() - start) * 10 ** 9
                 )
             )
+            break
+        if args.use_profiler:
+            prof.export_chrome_trace("/tmp/offload_prof")

     train_epoch(args)
@@ -80,7 +111,7 @@
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Test the CPU offload + sharding with a Transformer training")
     parser.add_argument("--epochs", action="store", default=1, type=int)
-    parser.add_argument("--batch_size", action="store", default=1, type=int)
+    parser.add_argument("--batch_size", action="store", default=20, type=int)
     parser.add_argument("--inputs", action="store", help="The dimension of the inputs", default=100, type=int)
     parser.add_argument("--hidden", action="store", help="The dimension of the hidden state", default=10000, type=int)
     parser.add_argument("--layers", action="store", help="he number of hidden layers", default=20, type=int)
@@ -88,6 +119,8 @@ def train_epoch(args):
     parser.add_argument("--offload", action="store_true", default=False)
     parser.add_argument("--slices", action="store", default=3, type=int)
+    parser.add_argument("--use_fp16", action="store_true", default=False)
+    parser.add_argument("--use_profiler", action="store_true", default=False)

     args = parser.parse_args()

diff --git a/fairscale/nn/misc/offload.py b/fairscale/nn/misc/offload.py
index faef4990c..49dabbe30 100644
--- a/fairscale/nn/misc/offload.py
+++ b/fairscale/nn/misc/offload.py
@@ -14,6 +14,7 @@
 import torch
 from torch import nn
+from torch.cuda.amp import custom_fwd, custom_bwd

 def _split(modules: nn.Sequential, number_splits: int) -> List[List[nn.Module]]:
@@ -125,6 +126,7 @@ class ShardSyncLayer(torch.autograd.Function):
     """

     @staticmethod
+    @custom_fwd
     def forward(ctx: Any, inputs: Any, index: int, model_slices: Any, model_instance: Any) -> Any:  # type: ignore
         drop_index = index
         load_index = index + 1
@@ -148,6 +150,7 @@ def forward(ctx: Any, inputs: Any, index: int, model_slices: Any, model_instance
         return inputs if isinstance(inputs, tuple) else (inputs,)

     @staticmethod
+    @custom_bwd
     def backward(ctx, *grad_outputs):  # type: ignore
         load_index = ctx.index
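With ShardSyncLayer's forward and backward now wrapped in torch.cuda.amp's custom_fwd/custom_bwd, the offloaded model can run under the autocast context. A minimal mixed-precision sketch of that path (mirroring the fp16 helpers and tests introduced in this part of the series; sizes and hyper-parameters are illustrative, and a CUDA device plus a PyTorch build that ships torch.cuda.amp.autocast are assumed):

    import torch
    import torch.nn as nn

    from fairscale.nn.misc.offload import OffloadModel

    model = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
    offload_model = OffloadModel(
        model_cpu=model, device=torch.device("cuda"), offload_device=torch.device("cpu"), n_slices=2,
    )
    optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)

    inputs = torch.ones(2, 2).to("cuda")
    labels = torch.ones(2, 2).to("cuda")

    # custom_fwd/custom_bwd keep the shard-sync autograd hooks consistent with
    # the autocast state, so only the forward and loss run under autocast here.
    with torch.cuda.amp.autocast():
        loss = torch.nn.MSELoss(reduction="sum")(offload_model(inputs), labels)
    loss.backward()
    optimizer.step()
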
From 3c4e71dae704305a2b6af3732da445dd8012e67f Mon Sep 17 00:00:00 2001
From: Anjali Sridhar
Date: Tue, 9 Feb 2021 12:55:06 -0800
Subject: [PATCH 14/16] add unit tests

---
 benchmarks/offload.py         |  8 ----
 tests/nn/misc/test_offload.py | 73 ++++++++++++++++++++++-------------
 2 files changed, 47 insertions(+), 34 deletions(-)

diff --git a/benchmarks/offload.py b/benchmarks/offload.py
index dd3b6e14d..856a48def 100755
--- a/benchmarks/offload.py
+++ b/benchmarks/offload.py
@@ -19,13 +19,6 @@
 from fairscale.nn.misc.offload import OffloadModel

-def _get_fp16_context(use_fp16=False):
-    if use_fp16:
-        return torch.cuda.amp.autocast()
-    else:
-        return contextlib.nullcontext()
-
-
 def _get_fp16_context(use_fp16=False):
     if use_fp16:
         return torch.cuda.amp.autocast()
@@ -101,7 +94,6 @@ def train_epoch(args):
                     loss.item(), args.batch_size / (time.time_ns() - start) * 10 ** 9
                 )
             )
-            break
         if args.use_profiler:
             prof.export_chrome_trace("/tmp/offload_prof")

diff --git a/tests/nn/misc/test_offload.py b/tests/nn/misc/test_offload.py
index 858cf7d4e..9aaae4305 100644
--- a/tests/nn/misc/test_offload.py
+++ b/tests/nn/misc/test_offload.py
@@ -8,11 +8,12 @@
 """

 import copy
-
+import contextlib
 import numpy as np
 import torch

 from fairscale.nn.misc.offload import OffloadModel
+from torch.cuda.amp import autocast

 def _init():
@@ -46,7 +47,7 @@ def _get_model(num_inputs=2, num_hidden=2, num_layers=1, num_outputs=2):
         torch.nn.Linear(num_inputs, num_hidden),
         *([torch.nn.Linear(num_hidden, num_hidden) for _ in range(num_layers)]),
         torch.nn.Linear(num_hidden, num_outputs),
-    ).cuda()
+    )
     return model
@@ -64,34 +65,54 @@ def _check_parity(rmodel, omodel, ropt, oopt, rloss, oloss):
     for o_buf, reg_buf in zip(omodel.buffers(), rmodel.buffers()):
         assert torch.allclose(o_buf, reg_buf, atol=1e-2), "Model buffers differ in between Offload and Vanilla."

+def _get_fp16_context(use_fp16=False):
+    if use_fp16:
+        return torch.cuda.amp.autocast()
+    else:
+        return contextlib.nullcontext()

-def test_correctness():
-    device, offload_device = _init()
-    model = _get_model().cuda()

-    def train(model, optimizer):
+def _train(model, optimizer, use_fp16, device):

-        input = torch.ones(2, 2).to(device)
-        labels = torch.ones(2, 2).to(device)
-        model.train()
+    input = torch.ones(2, 2).to(device)
+    labels = torch.ones(2, 2).to(device)
+    loss_fn = torch.nn.MSELoss(reduction="sum")
+    model.train()
+    with _get_fp16_context(use_fp16):
         pred = model(input)
-        loss_fn = torch.nn.MSELoss(reduction="sum")
         loss = loss_fn(pred, labels)
         loss.backward()
-        optimizer.step()
-        return model, optimizer, loss
-
-    def train_reg_model():
-        reg_model = copy.deepcopy(model)
-        reg_optimizer = torch.optim.SGD(reg_model.parameters(), lr=0.001)
-        return train(reg_model, reg_optimizer)
-
-    def train_offload_model():
-        omodel = copy.deepcopy(model)
-        offload_model = OffloadModel(model_cpu=omodel, device=device, offload_device=offload_device, n_slices=2,)
-        offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)
-        return train(offload_model, offload_optimizer)
-
-    rmodel, ropt, rloss = train_reg_model()
-    omodel, oopt, oloss = train_offload_model()
+    optimizer.step()
+    return model, optimizer, loss
+
+
+def _train_reg_model(model, device, offload_device, use_fp16):
+    reg_model = copy.deepcopy(model)
+    reg_model = reg_model.cuda()
+    reg_optimizer = torch.optim.SGD(reg_model.parameters(), lr=0.001)
+    return _train(reg_model, reg_optimizer, use_fp16, device)
+
+
+def _train_offload_model(model, device, offload_device, use_fp16):
+    omodel = copy.deepcopy(model)
+    offload_model = OffloadModel(model_cpu=omodel, device=device, offload_device=offload_device, n_slices=2,)
+    offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)
+    return _train(offload_model, offload_optimizer, use_fp16, device)
+
+
+def test_correctness():
+    device, offload_device = _init()
+    model = _get_model()
+    use_fp16 = False
+    rmodel, ropt, rloss = _train_reg_model(model, device, offload_device, use_fp16)
+    omodel, oopt, oloss = _train_offload_model(model, device, offload_device, use_fp16)
+    _check_parity(rmodel.cpu(), omodel.cpu(), ropt, oopt, rloss, oloss)
+
+
+def test_correctness_fp16():
+    device, offload_device = _init()
+    model = _get_model()
+    use_fp16 = True
+    rmodel, ropt, rloss = _train_reg_model(model, device, offload_device, use_fp16)
+    omodel, oopt, oloss = _train_offload_model(model, device, offload_device, use_fp16)
     _check_parity(rmodel.cpu(), omodel.cpu(), ropt, oopt, rloss, oloss)

From 901390fc0fd8c0660cf9f6cf37899d694913deee Mon Sep 17 00:00:00 2001
From: Anjali Sridhar
Date: Tue, 9 Feb 2021 12:57:29 -0800
Subject: [PATCH 15/16] fix lint errors

---
 benchmarks/offload.py         | 8 +++++---
 fairscale/nn/misc/offload.py  | 2 +-
 tests/nn/misc/test_offload.py | 5 +++--
 3 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/benchmarks/offload.py b/benchmarks/offload.py
index 856a48def..5ebbc78b3 100755
--- a/benchmarks/offload.py
+++ b/benchmarks/offload.py
@@ -3,15 +3,15 @@
 # Apply CPU offload and problem sharding to a big transformer model

 import argparse
+import contextlib
 import logging
 import time
-import contextlib
+
 import torch
 import torch.nn as nn
 from torch.utils.data.dataloader import DataLoader
 from torchvision.datasets import FakeData
 from torchvision.transforms import ToTensor
-from torch.cuda.amp import autocast

 OPTIM = torch.optim.SGD
 LR = 1e-3
@@ -88,7 +88,9 @@ def train_epoch(args):
                         loss = criterion(output, target=batch_outputs)
                         loss.backward()
                     optimizer.step()
-            logging.info("Memory stats are {:.2f}GB".format(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] / 2 ** 30))
+            logging.info(
+                "Memory stats are {:.2f}GB".format(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] / 2 ** 30)
+            )
             logging.info(
                 "Loss {:.2f} - throughput {:.2f}fps".format(
                     loss.item(), args.batch_size / (time.time_ns() - start) * 10 ** 9

diff --git a/fairscale/nn/misc/offload.py b/fairscale/nn/misc/offload.py
index 49dabbe30..e893517fe 100644
--- a/fairscale/nn/misc/offload.py
+++ b/fairscale/nn/misc/offload.py
@@ -14,7 +14,7 @@
 import torch
 from torch import nn
-from torch.cuda.amp import custom_fwd, custom_bwd
+from torch.cuda.amp import custom_bwd, custom_fwd

 def _split(modules: nn.Sequential, number_splits: int) -> List[List[nn.Module]]:

diff --git a/tests/nn/misc/test_offload.py b/tests/nn/misc/test_offload.py
index 9aaae4305..06cbeb9be 100644
--- a/tests/nn/misc/test_offload.py
+++ b/tests/nn/misc/test_offload.py
@@ -7,13 +7,13 @@
 Testing Offload Module
 """

-import copy
 import contextlib
+import copy
+
 import numpy as np
 import torch

 from fairscale.nn.misc.offload import OffloadModel
-from torch.cuda.amp import autocast

 def _init():
@@ -65,6 +65,7 @@ def _check_parity(rmodel, omodel, ropt, oopt, rloss, oloss):
     for o_buf, reg_buf in zip(omodel.buffers(), rmodel.buffers()):
         assert torch.allclose(o_buf, reg_buf, atol=1e-2), "Model buffers differ in between Offload and Vanilla."

+
 def _get_fp16_context(use_fp16=False):
     if use_fp16:
         return torch.cuda.amp.autocast()

From 1df261996ed4464682fd37e1f65cdbc14a4e7e9f Mon Sep 17 00:00:00 2001
From: Anjali Sridhar
Date: Tue, 9 Feb 2021 23:29:33 -0800
Subject: [PATCH 16/16] fix test failure

---
 fairscale/nn/misc/offload.py  | 7 +++----
 tests/nn/misc/test_offload.py | 2 ++
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/fairscale/nn/misc/offload.py b/fairscale/nn/misc/offload.py
index e893517fe..0328c1061 100644
--- a/fairscale/nn/misc/offload.py
+++ b/fairscale/nn/misc/offload.py
@@ -10,7 +10,7 @@
 from builtins import isinstance
 import logging
-from typing import Any, List
+from typing import Any, List, Tuple

 import torch
 from torch import nn
@@ -127,7 +127,7 @@ class ShardSyncLayer(torch.autograd.Function):
     @staticmethod
     @custom_fwd
-    def forward(ctx: Any, inputs: Any, index: int, model_slices: Any, model_instance: Any) -> Any:  # type: ignore
+    def forward(ctx: Any, inputs: Any, index: int, model_slices: Any, model_instance: Any) -> Any:
         drop_index = index
         load_index = index + 1
         max_slices = len(model_slices)
@@ -142,7 +142,6 @@ def forward(ctx: Any, inputs: Any, index: int, model_slices: Any, model_instance
             logging.info(f"Loading shard{load_index}")
             model_slices[load_index].forward_load()

-        ctx.inputs = inputs
         ctx.index = index
         ctx.model_slices = model_slices
         ctx.model_instance = model_instance
@@ -242,7 +241,7 @@ def __init__(
         self.model = torch.nn.Sequential(*self.model_slices)

         # intermediate actiavtions
-        self._activations = []
+        self._activations: List[Tuple] = []

diff --git a/tests/nn/misc/test_offload.py b/tests/nn/misc/test_offload.py
index 06cbeb9be..128202a86 100644
--- a/tests/nn/misc/test_offload.py
+++ b/tests/nn/misc/test_offload.py
@@ -111,6 +111,8 @@ def test_correctness():

 def test_correctness_fp16():
+    if not hasattr(torch.cuda.amp, "autocast"):
+        return
     device, offload_device = _init()
     model = _get_model()
     use_fp16 = True