Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Dev compression speedup #1999

Merged
merged 36 commits into from
Feb 10, 2020
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
de817c2
update
QuanluZhang Jan 1, 2020
9e4a3d9
update
QuanluZhang Jan 10, 2020
1db91e3
Merge branch 'master' of https://github.com/Microsoft/nni into dev-co…
QuanluZhang Jan 13, 2020
e401f2b
update
QuanluZhang Jan 15, 2020
2de52a8
add data parallel proposal (#1923)
Jan 16, 2020
7ba30c3
update
QuanluZhang Jan 20, 2020
dc865fe
update
QuanluZhang Jan 21, 2020
10c0510
update
QuanluZhang Jan 21, 2020
df1dda7
update
QuanluZhang Jan 22, 2020
9680f3e
update
QuanluZhang Jan 24, 2020
e51f288
update
Jan 28, 2020
f830430
update
Jan 28, 2020
ab7f23d
update
Jan 28, 2020
98e75c2
passes eval result validation, but has a very small difference
Jan 29, 2020
ff413d1
add model_speedup.py
Jan 29, 2020
d83f190
update
Jan 30, 2020
ff7e79d
pass fpgm test
Jan 31, 2020
e1240fe
add doc for speedup
Jan 31, 2020
8d333f2
pass l1filter
Jan 31, 2020
b1b2b14
update
Jan 31, 2020
e988f19
update
Feb 1, 2020
6b0ecee
fix compressor ut (#1997)
chicm-ms Feb 5, 2020
b8da18d
Merge branch 'dev-pruner-dataparallel' of https://github.com/microsof…
Feb 5, 2020
d452a16
update lottery ticket pruner based on refactored compression code (#1…
QuanluZhang Feb 5, 2020
12485c7
Merge branch 'dev-pruner-dataparallel' of https://github.com/microsof…
Feb 5, 2020
fbb6d48
remove test files
Feb 5, 2020
1ce3c72
update
Feb 5, 2020
4db78f7
update
Feb 5, 2020
3d51727
update
Feb 5, 2020
70d3b1e
add comments
Feb 5, 2020
c80c7a9
add comments
Feb 6, 2020
005a664
add comments
Feb 6, 2020
d11a54a
add comments
Feb 6, 2020
49e0de1
update
Feb 6, 2020
951b014
resolve comments
Feb 8, 2020
280fb1b
update doc
Feb 10, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions examples/model_compress/MeanActivation_torch_cifar10.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -40,6 +41,12 @@ def test(model, device, test_loader):


def main():
parser = argparse.ArgumentParser("multiple gpu with pruning")
parser.add_argument("--epochs", type=int, default=160)
parser.add_argument("--retrain", default=False, action="store_true")
parser.add_argument("--parallel", default=False, action="store_true")

args = parser.parse_args()
torch.manual_seed(0)
device = torch.device('cuda')
train_loader = torch.utils.data.DataLoader(
Expand All @@ -63,14 +70,15 @@ def main():
model.to(device)

# Train the base VGG-16 model
print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 160, 0)
for epoch in range(160):
train(model, device, train_loader, optimizer)
test(model, device, test_loader)
lr_scheduler.step(epoch)
torch.save(model.state_dict(), 'vgg16_cifar10.pth')
if args.retrain:
print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 160, 0)
for epoch in range(args.epochs):
train(model, device, train_loader, optimizer)
test(model, device, test_loader)
lr_scheduler.step(epoch)
torch.save(model.state_dict(), 'vgg16_cifar10.pth')

# Test base model accuracy
print('=' * 10 + 'Test on the original model' + '=' * 10)
Expand All @@ -90,6 +98,14 @@ def main():
print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10)
pruner = L1FilterPruner(model, configure_list)
model = pruner.compress()
if args.parallel:
if torch.cuda.device_count() > 1:
print("use {} gpus for pruning".format(torch.cuda.device_count()))
model = nn.DataParallel(model)
else:
print("only detect 1 gpu, fall back")

model.to(device)
test(model, device, test_loader)
# top1 = 88.19%

Expand Down
3 changes: 2 additions & 1 deletion examples/model_compress/fpgm_torch_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ def forward(self, x):
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4 * 4 * 50)
#x = x.view(-1, 4 * 4 * 50)
x = x.view(64, -1)
QuanluZhang marked this conversation as resolved.
Show resolved Hide resolved
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
Expand Down
2 changes: 2 additions & 0 deletions examples/model_compress/lottery_torch_mnist_fc.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def test(model, test_loader, criterion):
pruner = LotteryTicketPruner(model, configure_list, optimizer)
pruner.compress()

#model = nn.DataParallel(model)

for i in pruner.get_prune_iterations():
pruner.prune_iteration_start()
loss = 0
Expand Down
212 changes: 212 additions & 0 deletions examples/model_compress/model_speedup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
import argparse
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from models.cifar10.vgg import VGG
from nni.compression.speedup.torch import ModelSpeedup
from nni.compression.torch import apply_compression_results

torch.manual_seed(0)
use_mask = False

def apoz_speedup(masks_file, model_checkpoint):
    """Benchmark inference of a VGG-16 CIFAR-10 model pruned with APoZ masks.

    When the module-level flag ``use_mask`` is True the masks are applied in
    place (weights zeroed, tensor shapes unchanged); otherwise ``ModelSpeedup``
    rewrites the graph so pruned channels are physically removed.  Either way
    the elapsed time of 32 forward passes on a random batch is printed.

    masks_file: path to the masks file produced by the pruner.
    model_checkpoint: checkpoint path; currently unused — TODO load it before
        benchmarking so the timings run on trained weights.
    """
    device = torch.device('cuda')
    model = VGG(depth=16)
    model.to(device)
    model.eval()

    # Real data is not needed for timing: a random CIFAR-10-shaped batch suffices.
    # (The previous version built train/test DataLoaders here — downloading the
    # dataset — without ever using them; that dead code has been removed.)
    dummy_input = torch.randn(64, 3, 32, 32).to(device)
    if use_mask:
        apply_compression_results(model, masks_file)
        start = time.time()
        for _ in range(32):
            model(dummy_input)
        print('mask elapsed time: ', time.time() - start)
    else:
        m_speedup = ModelSpeedup(model, dummy_input, masks_file)
        m_speedup.speedup_model()
        start = time.time()
        for _ in range(32):
            model(dummy_input)
        print('speedup elapsed time: ', time.time() - start)

def l1filter_speedup(masks_file, model_checkpoint):
    """Benchmark inference of a VGG-16 CIFAR-10 model pruned with L1-filter masks.

    When the module-level flag ``use_mask`` is True the masks are applied in
    place (weights zeroed, tensor shapes unchanged); otherwise ``ModelSpeedup``
    rewrites the graph so pruned channels are physically removed.  Either way
    the elapsed time of 32 forward passes on a random batch is printed.

    masks_file: path to the masks file produced by the pruner.
    model_checkpoint: checkpoint path; currently unused — TODO load it before
        benchmarking so the timings run on trained weights.
    """
    device = torch.device('cuda')
    model = VGG(depth=16)
    model.to(device)
    model.eval()

    # Real data is not needed for timing: a random CIFAR-10-shaped batch suffices.
    # (The previous version built train/test DataLoaders here — downloading the
    # dataset — without ever using them; that dead code has been removed.)
    dummy_input = torch.randn(64, 3, 32, 32).to(device)
    if use_mask:
        apply_compression_results(model, masks_file)
        start = time.time()
        for _ in range(32):
            model(dummy_input)
        print('mask elapsed time: ', time.time() - start)
    else:
        m_speedup = ModelSpeedup(model, dummy_input, masks_file)
        m_speedup.speedup_model()
        start = time.time()
        for _ in range(32):
            model(dummy_input)
        print('speedup elapsed time: ', time.time() - start)

def fpgm_speedup(masks_file, model_checkpoint):
    """Benchmark inference of the FPGM-pruned MNIST model on CPU.

    When the module-level flag ``use_mask`` is True the masks are applied in
    place; otherwise ``ModelSpeedup`` physically shrinks the pruned layers.
    Prints per-layer conv sparsity first, then the elapsed time of 40 forward
    passes on a random MNIST-shaped batch.

    masks_file: path to the masks file produced by the pruner.
    model_checkpoint: checkpoint path; currently unused — TODO load it before
        benchmarking so the timings run on trained weights.
    """
    # Local import: the Mnist model lives in a sibling example script.
    from fpgm_torch_mnist import Mnist
    device = torch.device('cpu')
    model = Mnist()
    model.to(device)
    model.print_conv_filter_sparsity()

    # Real data is not needed for timing; the previously-built MNIST
    # DataLoaders were never used and have been removed.
    dummy_input = torch.randn(64, 1, 28, 28).to(device)
    if use_mask:
        apply_compression_results(model, masks_file)
        start = time.time()
        for _ in range(40):
            model(dummy_input)
        print('mask elapsed time: ', time.time() - start)
    else:
        m_speedup = ModelSpeedup(model, dummy_input, masks_file)
        m_speedup.speedup_model()
        start = time.time()
        for _ in range(40):
            model(dummy_input)
        print('speedup elapsed time: ', time.time() - start)

def slim_speedup(masks_file, model_checkpoint):
    """Benchmark inference of a VGG-19 CIFAR-10 model pruned with Slim masks.

    When the module-level flag ``use_mask`` is True the masks are applied in
    place (weights zeroed, tensor shapes unchanged); otherwise ``ModelSpeedup``
    rewrites the graph so pruned channels are physically removed.  Either way
    the elapsed time of 32 forward passes on a random batch is printed.

    masks_file: path to the masks file produced by the pruner.
    model_checkpoint: checkpoint path; currently unused — TODO load it before
        benchmarking so the timings run on trained weights.
    """
    device = torch.device('cuda')
    model = VGG(depth=19)
    model.to(device)
    model.eval()

    # Real data is not needed for timing: a random CIFAR-10-shaped batch suffices.
    # (The previous version built train/test DataLoaders here — downloading the
    # dataset — without ever using them; that dead code has been removed.)
    dummy_input = torch.randn(64, 3, 32, 32).to(device)
    if use_mask:
        apply_compression_results(model, masks_file)
        start = time.time()
        for _ in range(32):
            model(dummy_input)
        print('mask elapsed time: ', time.time() - start)
    else:
        m_speedup = ModelSpeedup(model, dummy_input, masks_file)
        m_speedup.speedup_model()
        start = time.time()
        for _ in range(32):
            model(dummy_input)
        print('speedup elapsed time: ', time.time() - start)

if __name__ == '__main__':
    parser = argparse.ArgumentParser("speedup")
    parser.add_argument("--example_name", type=str, default="slim", help="the name of pruning example")
    parser.add_argument("--masks_file", type=str, default=None, help="the path of the masks file")
    parser.add_argument("--model_checkpoint", type=str, default=None, help="the path of checkpointed model")
    args = parser.parse_args()

    # Dispatch table: example name -> (speedup entry point, default masks file).
    # Replaces a repetitive if/elif chain; behavior is unchanged.
    examples = {
        'slim': (slim_speedup, 'mask_vgg19_cifar10.pth'),
        'fpgm': (fpgm_speedup, 'mask.pth'),
        'l1filter': (l1filter_speedup, 'mask_vgg16_cifar10.pth'),
        'apoz': (apoz_speedup, 'mask_vgg16_cifar10.pth'),
    }
    if args.example_name not in examples:
        raise ValueError('unsupported example_name: {}'.format(args.example_name))
    speedup_fn, default_masks = examples[args.example_name]
    # Fall back to the example's conventional masks file only when none given.
    masks_file = args.masks_file if args.masks_file is not None else default_masks
    speedup_fn(masks_file, args.model_checkpoint)
101 changes: 101 additions & 0 deletions examples/model_compress/multi_gpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from nni.compression.torch import SlimPruner

class fc1(nn.Module):
    """Small classifier for 28x28 single-channel images (MNIST-sized).

    One conv+BN+ReLU stage (spatial size preserved by padding=1) followed by a
    three-layer fully-connected head on the flattened 32x28x28 feature map.
    """

    def __init__(self, num_classes=10):
        super(fc1, self).__init__()
        # Convolutional stem: 1 -> 32 channels, stride 1, same spatial size.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU(inplace=True)
        # Fully-connected head: 32*28*28 -> 300 -> 100 -> num_classes.
        self.linear1 = nn.Linear(32 * 28 * 28, 300)
        self.relu2 = nn.ReLU(inplace=True)
        self.linear2 = nn.Linear(300, 100)
        self.relu3 = nn.ReLU(inplace=True)
        self.linear3 = nn.Linear(100, num_classes)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        out = self.relu1(self.bn1(self.conv1(x)))
        out = torch.flatten(out, 1)
        out = self.relu2(self.linear1(out))
        out = self.relu3(self.linear2(out))
        return self.linear3(out)

def train(model, train_loader, optimizer, criterion, device):
    """Run one training epoch and return the loss of the final batch.

    Moves each batch to `device`, does a standard zero-grad / forward /
    backward / step cycle per batch.
    """
    model.train()
    for batch_imgs, batch_targets in train_loader:
        optimizer.zero_grad()
        batch_imgs = batch_imgs.to(device)
        batch_targets = batch_targets.to(device)
        batch_loss = criterion(model(batch_imgs), batch_targets)
        batch_loss.backward()
        optimizer.step()
    # Note: only the last batch's loss is reported, matching the caller's use.
    return batch_loss.item()

def test(model, test_loader, criterion, device):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
correct += pred.eq(target.data.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
accuracy = 100. * correct / len(test_loader.dataset)
return accuracy


if __name__ == '__main__':
    # MNIST pipeline with the standard per-channel mean/std normalization.
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    traindataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    testdataset = datasets.MNIST('./data', train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(traindataset, batch_size=60, shuffle=True, num_workers=10, drop_last=False)
    test_loader = torch.utils.data.DataLoader(testdataset, batch_size=60, shuffle=False, num_workers=10, drop_last=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = fc1()
    criterion = nn.CrossEntropyLoss()

    # Prune BatchNorm scaling factors with SlimPruner before any GPU wrapping.
    configure_list = [{
        'prune_iterations': 5,
        'sparsity': 0.86,
        'op_types': ['BatchNorm2d']
    }]
    pruner = SlimPruner(model, configure_list)
    pruner.compress()

    # Wrap with DataParallel only when more than one GPU is available.
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1.2e-3)
    # List parameter names so the (possibly DataParallel-prefixed) layout is visible.
    for name, par in model.named_parameters():
        print(name)

    loss = 0
    accuracy = 0
    for epoch in range(10):
        loss = train(model, train_loader, optimizer, criterion, device)
        accuracy = test(model, test_loader, criterion, device)
        print('current epoch: {0}, loss: {1}, accuracy: {2}'.format(epoch, loss, accuracy))
    # Persist the pruned weights and the masks for later speedup.
    pruner.export_model('model.pth', 'mask.pth')
Loading