-
Notifications
You must be signed in to change notification settings - Fork 49
/
utils.py
65 lines (51 loc) · 1.77 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import numpy as np
from matplotlib import pyplot as plt
import torch
from torch.utils.data.sampler import Sampler
from torchvision import transforms, datasets
from PIL import Image
# Empty class used as a bag of attributes for storing arguments
class Dummy:
    """Attribute container with no behaviour of its own.

    Instances start empty; arbitrary attributes can be assigned on them
    after construction.
    """
# Callable pipeline: opens an image from disk, resizes it to 224x224,
# normalizes it with the ImageNet statistics and returns a tensor with a
# leading batch axis of size 1.
read_tensor = transforms.Compose([
    # Force three channels: grayscale, RGBA or palette files would otherwise
    # reach the 3-channel Normalize step with the wrong channel count.
    lambda x: Image.open(x).convert('RGB'),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
    # Add the batch axis in front of (channel, height, width).
    lambda x: torch.unsqueeze(x, 0),
])
# Plots image from tensor
def tensor_imshow(inp, title=None, **kwargs):
    """Show an ImageNet-normalized image tensor with matplotlib.

    Args:
        inp: torch tensor indexed as (channel, height, width). After the
            channel axis is moved last it must broadcast against the
            3-element mean/std arrays used below.
        title: optional title for the current axes; skipped when ``None``.
        **kwargs: forwarded unchanged to ``plt.imshow``.
    """
    # detach() drops autograd tracking and cpu() moves the data to host
    # memory; a bare .numpy() raises for tensors that require grad or that
    # live on another device.
    inp = inp.detach().cpu().numpy().transpose((1, 2, 0))
    # Mean and std for ImageNet
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    # Undo the normalization (x - mean) / std, then clamp to the displayable
    # [0, 1] range.
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp, **kwargs)
    if title is not None:
        plt.title(title)
# Maps a label number to its class name
def get_class_name(c):
    """Return the class name stored at index ``c`` of ``synset_words.txt``.

    The file is looked up relative to the current working directory and is
    re-read on every call. The parsing assumes each line looks like
    ``<id> <name>, <synonym>, ...``: everything after the first comma is
    dropped, then the first whitespace-separated token is dropped, and the
    remaining words are joined with single spaces.
    """
    synsets = np.loadtxt('synset_words.txt', dtype=str, delimiter='\t')
    # Keep only the text in front of the first comma.
    head, _, _ = synsets[c].partition(',')
    # Skip the leading token and keep the remaining words.
    name_words = head.split()[1:]
    return ' '.join(name_words)
# Image preprocessing pipeline:
# resize to 224x224 -> convert to tensor -> per-channel normalization.
preprocess = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    # Normalization for ImageNet
    transforms.Normalize(
        [0.485, 0.456, 0.406],  # mean
        [0.229, 0.224, 0.225],  # std
    ),
])
# Sampler meant to be handed to a pytorch data loader so that only the
# indices contained in r are visited instead of the whole dataset.
class RangeSampler(Sampler):
    """Sampler that yields exactly the indices held in ``r``, in ``r``'s order."""

    def __init__(self, r):
        # Only the index collection is stored; it must support iter() and len().
        self.r = r

    def __len__(self):
        """Return the number of indices in ``r``."""
        return len(self.r)

    def __iter__(self):
        """Return a fresh iterator over ``r``."""
        return iter(self.r)