diff --git a/benchmarks/offload.py b/benchmarks/offload.py index 9eb4ce16f..1a02a6018 100755 --- a/benchmarks/offload.py +++ b/benchmarks/offload.py @@ -34,22 +34,15 @@ def train(args: argparse.Namespace): criterion = nn.CrossEntropyLoss() if args.offload: logging.info("Using sharded offloading for training") - offload_model = OffloadWrapperExperimental( - model_cpu=model, - optimizer=OPTIM, - optimizer_params={"lr": LR}, - device=device, - offload_device=torch.device("cpu"), - n_slices=args.slices, - ) - - optimizer = offload_model.optimizer - model = offload_model # type: ignore + model = OffloadWrapperExperimental( + model_cpu=model, device=device, offload_device=torch.device("cpu"), n_slices=args.slices, + ) # type: ignore else: logging.info("Using Pytorch for training") model = model.to(torch.device("cuda")) - optimizer = OPTIM(model.parameters(), lr=LR) + + optimizer = OPTIM(model.parameters(), lr=LR) transform = ToTensor() dataloader = DataLoader( diff --git a/fairscale/nn/misc/offload.py b/fairscale/nn/misc/offload.py index f7abd452c..8d8fe8f78 100644 --- a/fairscale/nn/misc/offload.py +++ b/fairscale/nn/misc/offload.py @@ -10,7 +10,7 @@ from builtins import isinstance import logging -from typing import Any, Dict, List, Tuple, Type +from typing import Any, List, Tuple import torch from torch import nn @@ -182,8 +182,6 @@ class OffloadWrapperExperimental(nn.Module): def __init__( self, model_cpu: nn.Sequential, # hard pre-requisite for now, easier model slicing - optimizer: Type[torch.optim.Optimizer], - optimizer_params: Dict[str, Any], device: torch.device, offload_device: torch.device = torch.device("cpu"), n_slices: int = 5, @@ -205,10 +203,6 @@ def __init__( ModelShard(cpu_model_shard=nn.Sequential(*split), device=device, offload_device=offload_device,) ) - # Use one normal optimizer per slice - # TODO: Keep all optimizers, return a wrap which will distribute the steps() - self.optimizer = optimizer(nn.Sequential(*split).parameters(), **optimizer_params) - # Expose a unified view of the slices self.model = torch.nn.Sequential(*self.model_slices)