Commit 6bfeaed
removing the ill-advised optimizer bits, better keep that orthogonal
blefaudeux committed Feb 5, 2021
1 parent 8e92a4c commit 6bfeaed
Showing 2 changed files with 6 additions and 19 deletions.
benchmarks/offload.py: 17 changes (5 additions, 12 deletions)
@@ -34,22 +34,15 @@ def train(args: argparse.Namespace):
     criterion = nn.CrossEntropyLoss()
     if args.offload:
         logging.info("Using sharded offloading for training")
-        offload_model = OffloadWrapperExperimental(
-            model_cpu=model,
-            optimizer=OPTIM,
-            optimizer_params={"lr": LR},
-            device=device,
-            offload_device=torch.device("cpu"),
-            n_slices=args.slices,
-        )
-
-        optimizer = offload_model.optimizer
-        model = offload_model  # type: ignore
+        model = OffloadWrapperExperimental(
+            model_cpu=model, device=device, offload_device=torch.device("cpu"), n_slices=args.slices,
+        )  # type: ignore
 
     else:
         logging.info("Using Pytorch for training")
         model = model.to(torch.device("cuda"))
-        optimizer = OPTIM(model.parameters(), lr=LR)
+
+    optimizer = OPTIM(model.parameters(), lr=LR)
 
     transform = ToTensor()
     dataloader = DataLoader(
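For context, a minimal sketch of the post-commit wiring shown above: the wrapper is constructed without any optimizer arguments, and the optimizer is built separately over the wrapped model's parameters. The toy Sequential model, the SGD choice, the learning rate, and the slice count below are placeholders standing in for the benchmark's own model and its OPTIM/LR/args.slices values; the import path is inferred from the file path in this diff.

import torch
from torch import nn

# Import path inferred from fairscale/nn/misc/offload.py in this commit.
from fairscale.nn.misc.offload import OffloadWrapperExperimental

OPTIM = torch.optim.SGD  # placeholder for the benchmark's OPTIM constant
LR = 1e-3                # placeholder for the benchmark's LR constant

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model_cpu must be an nn.Sequential (hard pre-requisite noted in the wrapper's signature).
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))

model = OffloadWrapperExperimental(
    model_cpu=model, device=device, offload_device=torch.device("cpu"), n_slices=2,
)  # type: ignore

# The optimizer is now created on a single code path, outside the wrapper.
optimizer = OPTIM(model.parameters(), lr=LR)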
fairscale/nn/misc/offload.py: 8 changes (1 addition, 7 deletions)
@@ -10,7 +10,7 @@
 
 from builtins import isinstance
 import logging
-from typing import Any, Dict, List, Tuple, Type
+from typing import Any, List, Tuple
 
 import torch
 from torch import nn
@@ -182,8 +182,6 @@ class OffloadWrapperExperimental(nn.Module):
     def __init__(
         self,
         model_cpu: nn.Sequential,  # hard pre-requisite for now, easier model slicing
-        optimizer: Type[torch.optim.Optimizer],
-        optimizer_params: Dict[str, Any],
         device: torch.device,
         offload_device: torch.device = torch.device("cpu"),
         n_slices: int = 5,
@@ -205,10 +203,6 @@ def __init__(
                 ModelShard(cpu_model_shard=nn.Sequential(*split), device=device, offload_device=offload_device,)
             )
 
-            # Use one normal optimizer per slice
-            # TODO: Keep all optimizers, return a wrap which will distribute the steps()
-            self.optimizer = optimizer(nn.Sequential(*split).parameters(), **optimizer_params)
-
         # Expose a unified view of the slices
         self.model = torch.nn.Sequential(*self.model_slices)
 
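A short sketch of what keeping the optimizer orthogonal means for callers, assuming the wrapper behaves like a standard nn.Module during forward and backward: the training step drives any stock torch.optim optimizer itself, since the wrapper no longer creates or exposes one. The train_step helper, the loss choice, and the tensor shapes are illustrative and not part of fairscale.

import torch
from torch import nn

criterion = nn.CrossEntropyLoss()

def train_step(model: nn.Module, optimizer: torch.optim.Optimizer,
               batch: torch.Tensor, target: torch.Tensor) -> float:
    # Hypothetical helper: the wrapper is driven like any other module, and the
    # optimizer step is entirely the caller's responsibility after this commit.
    optimizer.zero_grad()
    output = model(batch)             # forward runs through the offloaded slices
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()                  # plain torch.optim step, no wrapper involvement
    return loss.item()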

0 comments on commit 6bfeaed
