Skip to content

Commit

Permalink
siompler sample script
Browse files Browse the repository at this point in the history
  • Loading branch information
alvinwan committed Mar 23, 2020
1 parent 71555e3 commit c8926e0
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 20 deletions.
26 changes: 6 additions & 20 deletions nbdt/bin/nbdt
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,12 @@

from nbdt.model import SoftNBDT
from pytorchcv.models.wrn_cifar import wrn28_10_cifar10
from PIL import Image
from urllib.request import urlopen, Request
from torchvision import transforms
import io
from nbdt.utils import DATASET_TO_CLASSES
import sys

assert len(sys.argv) > 1, "Need to pass image URL or image path as argument"

classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# load pretrained NBDT
model = wrn28_10_cifar10()
model = SoftNBDT(
Expand All @@ -22,27 +18,17 @@ model = SoftNBDT(
pretrained=True,
arch='wrn28_10_cifar10')

# load image
path = sys.argv[1]
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2228.0 Safari/537.3'
}
if 'http' in path:
request = Request(path, headers=headers)
file = io.BytesIO(urlopen(request).read())
else:
file = path
im = Image.open(file)

# transform image
# load + transform image
im = load_image_from_path(sys.argv[1])
transforms = transforms.Compose([
transforms.Resize(32),
transforms.CenterCrop(32),
transforms.ToTensor()
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
x = transforms(im)[None]

# run inference
outputs = model(x)
cls = classes[outputs[0]]
cls = DATASET_TO_CLASSES['CIFAR10'][outputs[0]]
print(cls)
22 changes: 22 additions & 0 deletions nbdt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
import math
import numpy as np

from urllib.request import urlopen, Request
from PIL import Image
import torch.nn as nn
import torch.nn.init as init
from pathlib import Path
import io

# tree-generation consntants
METHODS = ('wordnet', 'random', 'induced')
Expand All @@ -22,6 +25,12 @@
'TinyImagenet200': 200,
'Imagenet1000': 1000
}
DATASET_TO_CLASSES = {
'CIFAR10': [
'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
'horse', 'ship', 'truck'
]
}


def fwd():
Expand Down Expand Up @@ -61,6 +70,19 @@ def populate_kwargs(args, kwargs, object, name='Dataset', keys=(), globals={}):
f'{key}: {value}')


def load_image_from_path(path):
"""Path can be local or a URL"""
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2228.0 Safari/537.3'
}
if 'http' in path:
request = Request(path, headers=headers)
file = io.BytesIO(urlopen(request).read())
else:
file = path
return Image.open(file)


class Colors:
RED = '\x1b[31m'
GREEN = '\x1b[32m'
Expand Down

0 comments on commit c8926e0

Please sign in to comment.