forked from microsoft/nni
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #215 from microsoft/master
Filter prune algo implementation (microsoft#1655)
- Loading branch information
Showing
10 changed files
with
557 additions
and
141 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import tensorflow as tf | ||
from tensorflow import keras | ||
assert tf.__version__ >= "2.0" | ||
import numpy as np | ||
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout | ||
from nni.compression.tensorflow import FPGMPruner | ||
|
||
def get_data(): | ||
(X_train_full, y_train_full), _ = keras.datasets.mnist.load_data() | ||
X_train, X_valid = X_train_full[:-5000], X_train_full[-5000:] | ||
y_train, y_valid = y_train_full[:-5000], y_train_full[-5000:] | ||
|
||
X_mean = X_train.mean(axis=0, keepdims=True) | ||
X_std = X_train.std(axis=0, keepdims=True) + 1e-7 | ||
X_train = (X_train - X_mean) / X_std | ||
X_valid = (X_valid - X_mean) / X_std | ||
|
||
X_train = X_train[..., np.newaxis] | ||
X_valid = X_valid[..., np.newaxis] | ||
|
||
return X_train, X_valid, y_train, y_valid | ||
|
||
def get_model(): | ||
model = keras.models.Sequential([ | ||
Conv2D(filters=32, kernel_size=7, input_shape=[28, 28, 1], activation='relu', padding="SAME"), | ||
MaxPooling2D(pool_size=2), | ||
Conv2D(filters=64, kernel_size=3, activation='relu', padding="SAME"), | ||
MaxPooling2D(pool_size=2), | ||
Flatten(), | ||
Dense(units=128, activation='relu'), | ||
Dropout(0.5), | ||
Dense(units=10, activation='softmax'), | ||
]) | ||
model.compile(loss="sparse_categorical_crossentropy", | ||
optimizer=keras.optimizers.SGD(lr=1e-3), | ||
metrics=["accuracy"]) | ||
return model | ||
|
||
def main(): | ||
X_train, X_valid, y_train, y_valid = get_data() | ||
model = get_model() | ||
|
||
configure_list = [{ | ||
'sparsity': 0.5, | ||
'op_types': ['Conv2D'] | ||
}] | ||
pruner = FPGMPruner(model, configure_list) | ||
pruner.compress() | ||
|
||
update_epoch_callback = keras.callbacks.LambdaCallback(on_epoch_begin=lambda epoch, logs: pruner.update_epoch(epoch)) | ||
|
||
model.fit(X_train, y_train, epochs=10, validation_data=(X_valid, y_valid), callbacks=[update_epoch_callback]) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
from nni.compression.torch import FPGMPruner | ||
import torch | ||
import torch.nn.functional as F | ||
from torchvision import datasets, transforms | ||
|
||
|
||
class Mnist(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1) | ||
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1) | ||
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500) | ||
self.fc2 = torch.nn.Linear(500, 10) | ||
|
||
def forward(self, x): | ||
x = F.relu(self.conv1(x)) | ||
x = F.max_pool2d(x, 2, 2) | ||
x = F.relu(self.conv2(x)) | ||
x = F.max_pool2d(x, 2, 2) | ||
x = x.view(-1, 4 * 4 * 50) | ||
x = F.relu(self.fc1(x)) | ||
x = self.fc2(x) | ||
return F.log_softmax(x, dim=1) | ||
|
||
def _get_conv_weight_sparsity(self, conv_layer): | ||
num_zero_filters = (conv_layer.weight.data.sum((2,3)) == 0).sum() | ||
num_filters = conv_layer.weight.data.size(0) * conv_layer.weight.data.size(1) | ||
return num_zero_filters, num_filters, float(num_zero_filters)/num_filters | ||
|
||
def print_conv_filter_sparsity(self): | ||
conv1_data = self._get_conv_weight_sparsity(self.conv1) | ||
conv2_data = self._get_conv_weight_sparsity(self.conv2) | ||
print('conv1: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv1_data[0], conv1_data[1], conv1_data[2])) | ||
print('conv2: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv2_data[0], conv2_data[1], conv2_data[2])) | ||
|
||
def train(model, device, train_loader, optimizer): | ||
model.train() | ||
for batch_idx, (data, target) in enumerate(train_loader): | ||
data, target = data.to(device), target.to(device) | ||
optimizer.zero_grad() | ||
output = model(data) | ||
loss = F.nll_loss(output, target) | ||
if batch_idx % 100 == 0: | ||
print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item())) | ||
model.print_conv_filter_sparsity() | ||
loss.backward() | ||
optimizer.step() | ||
|
||
def test(model, device, test_loader): | ||
model.eval() | ||
test_loss = 0 | ||
correct = 0 | ||
with torch.no_grad(): | ||
for data, target in test_loader: | ||
data, target = data.to(device), target.to(device) | ||
output = model(data) | ||
test_loss += F.nll_loss(output, target, reduction='sum').item() | ||
pred = output.argmax(dim=1, keepdim=True) | ||
correct += pred.eq(target.view_as(pred)).sum().item() | ||
test_loss /= len(test_loader.dataset) | ||
|
||
print('Loss: {} Accuracy: {}%)\n'.format( | ||
test_loss, 100 * correct / len(test_loader.dataset))) | ||
|
||
|
||
def main(): | ||
torch.manual_seed(0) | ||
device = torch.device('cpu') | ||
|
||
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) | ||
train_loader = torch.utils.data.DataLoader( | ||
datasets.MNIST('data', train=True, download=True, transform=trans), | ||
batch_size=64, shuffle=True) | ||
test_loader = torch.utils.data.DataLoader( | ||
datasets.MNIST('data', train=False, transform=trans), | ||
batch_size=1000, shuffle=True) | ||
|
||
model = Mnist() | ||
model.print_conv_filter_sparsity() | ||
|
||
'''you can change this to LevelPruner to implement it | ||
pruner = LevelPruner(configure_list) | ||
''' | ||
configure_list = [{ | ||
'sparsity': 0.5, | ||
'op_types': ['Conv2d'] | ||
}] | ||
|
||
pruner = FPGMPruner(model, configure_list) | ||
pruner.compress() | ||
|
||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) | ||
for epoch in range(10): | ||
pruner.update_epoch(epoch) | ||
print('# Epoch {} #'.format(epoch)) | ||
train(model, device, train_loader, optimizer) | ||
test(model, device, test_loader) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.