-
Notifications
You must be signed in to change notification settings - Fork 6
/
general_tf_model.py
55 lines (46 loc) · 1.92 KB
/
general_tf_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import numpy as np
import torch.nn as nn
import tensorflow as tf
# import tensorflow.compat.v1 as tf
import tensorflow_hub as hub
class GeneralTFModel(nn.Module):
    """Expose a TensorFlow session-based classifier through a PyTorch ``nn.Module``.

    Images arrive as PyTorch tensors, are converted to NumPy with axis 1 moved
    to the last position, and are fed to ``x_input`` via ``sess.run``; the
    fetched ``model_logits`` are returned as a PyTorch tensor.

    Args:
        model_logits: Fetch passed to ``sess.run`` (the model's logits).
        x_input: Feed-dict key for the input images (e.g. a TF placeholder).
        sess: Session-like object exposing ``run(fetches, feed_dict)``.
        n_class: Number of classes; stored on the instance but not used here.
        im_mean: Optional per-channel mean used by :meth:`preprocess`.
        im_std: Optional per-channel std used by :meth:`preprocess`.
    """

    def __init__(self, model_logits, x_input, sess, n_class=10, im_mean=None, im_std=None):
        super().__init__()
        self.model_logits = model_logits
        self.x_input = x_input
        self.sess = sess
        # Counts images sent through predict_prob / predict_label only;
        # forward() does not increment it.
        self.num_queries = 0
        self.im_mean = im_mean
        self.im_std = im_std
        self.n_class = n_class

    @staticmethod
    def _output_device():
        """Return the device results are placed on: CUDA if available, else CPU."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _run_tf(self, image):
        """Run the TF graph on a 4-D tensor and return the logits as a torch tensor."""
        # detach(): .numpy() raises on tensors that require grad, and gradients
        # cannot flow through sess.run anyway.
        # Axis 1 is moved to the end -- presumably NCHW -> NHWC, the layout the
        # TF graph expects (TODO confirm against the model).
        image_tf = np.moveaxis(image.detach().cpu().numpy(), 1, 3)
        logits = self.sess.run(self.model_logits, {self.x_input: image_tf})
        return torch.from_numpy(logits).to(self._output_device())

    def forward(self, image):
        """Return raw logits for ``image`` without preprocessing.

        A non-4-D input gets a leading batch dimension added. NOTE(review):
        unlike :meth:`predict_prob`, this does not apply mean/std
        normalization and does not update ``num_queries``.
        """
        if image.dim() != 4:
            image = image.unsqueeze(0)
        return self._run_tf(image)

    def preprocess(self, image):
        """Convert ``image`` to a float tensor and optionally normalize it.

        NumPy arrays are converted to a CPU float32 tensor; tensors pass
        through unchanged. When both ``im_mean`` and ``im_std`` are set, they
        are reshaped to ``(1, C, 1, 1)`` with ``C = image.shape[1]`` and
        applied as ``(image - mean) / std``.
        """
        if isinstance(image, np.ndarray):
            processed = torch.from_numpy(image).type(torch.FloatTensor)
        else:
            processed = image
        if self.im_mean is not None and self.im_std is not None:
            # Build mean/std on the input's own device so CPU inputs (such as
            # converted NumPy arrays) do not hit a CPU/CUDA mismatch.
            # Broadcasting covers the batch and spatial axes.
            n_channels = processed.shape[1]
            im_mean = torch.as_tensor(self.im_mean, device=processed.device).reshape(1, n_channels, 1, 1)
            im_std = torch.as_tensor(self.im_std, device=processed.device).reshape(1, n_channels, 1, 1)
            processed = (processed - im_mean) / im_std
        return processed

    def predict_prob(self, image):
        """Preprocess ``image``, count it as a query, and return the model output.

        Despite the name, this returns whatever ``model_logits`` yields (raw
        logits; no softmax is applied here). A non-4-D input gets a leading
        batch dimension added.
        """
        if image.dim() != 4:
            image = image.unsqueeze(0)
        image = self.preprocess(image)
        self.num_queries += image.size(0)
        return self._run_tf(image)

    def predict_label(self, image):
        """Return the argmax class index over dim 1 of :meth:`predict_prob`'s output."""
        logits = self.predict_prob(image)
        _, predict = torch.max(logits, 1)
        return predict