-
Notifications
You must be signed in to change notification settings - Fork 1
/
caption.py
120 lines (96 loc) · 4.41 KB
/
caption.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from nltk.translate.bleu_score import corpus_bleu
from tqdm import tqdm
from dataset import Vocabulary
from skimage import transform
from model import *
from utils import *
import torchvision.transforms as T
from PIL import Image
import argparse
# Select the compute device once at import time: the first CUDA GPU when one is
# available, otherwise the CPU. Functions in this module read this global.
# NOTE(review): `torch` is not imported explicitly among the visible imports;
# presumably it is re-exported by a star import (`model` / `utils`) - confirm.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def get_all_captions(img, model, vocab=None):
    """Generate caption tokens for the first image of a batch.

    Only ``img[0:1]`` is encoded and decoded, so this will only work for
    batch size 1: any further images in the batch are ignored.

    Args:
        img: Indexable batch of images; the first one is moved to ``device``.
        model: Object exposing ``EncoderCNN`` and ``DecoderLSTM.gen_captions``.
        vocab: Vocabulary forwarded unchanged to ``gen_captions``.

    Returns:
        The generated caption sequence with its last two entries removed.
    """
    # Caption generation is inference only, so skip autograd bookkeeping
    # (same approach as get_caps_from).
    with torch.no_grad():
        features = model.EncoderCNN(img[0:1].to(device))
        caps, _alphas = model.DecoderLSTM.gen_captions(features, vocab=vocab)
    # NOTE(review): the two trailing entries are presumably the final
    # punctuation and end-of-sequence tokens - confirm against gen_captions.
    return caps[:-2]
def calculate_bleu_score(dataloader, model, vocab):
    """Caption every batch of ``dataloader`` and print corpus BLEU-1..4.

    Each batch is unpacked as ``(img, cap, all_caps)``. A caption is generated
    for ``img`` via ``get_all_captions`` (first image of the batch only) and
    paired with ``all_caps[0]`` as its references.

    Args:
        dataloader: Sized iterable yielding ``(img, cap, all_caps)`` batches.
            NOTE(review): ``all_caps[0]`` is presumably the list of tokenized
            reference captions for the first image - confirm with the dataset.
        model: Captioning model handed to ``get_all_captions``.
        vocab: Vocabulary handed to ``get_all_captions``.

    Returns:
        None. The four scores are printed to stdout.
    """
    candidate_corpus = []
    references_corpus = []

    # Score in eval mode (the model is built with dropout) and without
    # autograd, then restore whatever mode the caller had the model in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # `cap` is not needed for scoring and get_all_captions moves the
            # image to the device itself, so no transfers happen here.
            for img, _cap, all_caps in tqdm(dataloader, total=len(dataloader)):
                candidate_corpus.append(get_all_captions(img, model, vocab))
                references_corpus.append(all_caps[0])
    finally:
        model.train(was_training)

    assert len(candidate_corpus) == len(references_corpus)

    # BLEU-3 uses exact thirds so the n-gram weights sum to 1.
    third = 1 / 3
    print(f"\nBLEU1 = {corpus_bleu(references_corpus, candidate_corpus, (1, 0, 0, 0))}")
    print(f"BLEU2 = {corpus_bleu(references_corpus, candidate_corpus, (0.5, 0.5, 0, 0))}")
    print(f"BLEU3 = {corpus_bleu(references_corpus, candidate_corpus, (third, third, third, 0))}")
    print(f"BLEU4 = {corpus_bleu(references_corpus, candidate_corpus, (0.25, 0.25, 0.25, 0.25))}")
def get_caps_from(features_tensors, model, vocab=None):
    """Caption the first image of ``features_tensors`` and show the result.

    The model is switched to eval mode and run without gradient tracking.
    The generated tokens are joined with single spaces and passed, together
    with the first image, to ``show_img``.

    Args:
        features_tensors: Batch of image tensors; only index 0 is used.
        model: Object exposing ``EncoderCNN`` and ``DecoderLSTM.gen_captions``.
        vocab: Vocabulary forwarded unchanged to ``gen_captions``.

    Returns:
        ``(caps, alphas)`` exactly as returned by ``gen_captions``.
    """
    model.eval()
    with torch.no_grad():
        first_image = features_tensors[0:1].to(device)
        encoded = model.EncoderCNN(first_image)
        caps, alphas = model.DecoderLSTM.gen_captions(encoded, vocab=vocab)
        show_img(features_tensors[0], ' '.join(caps))
    return caps, alphas
def plot_attention(img, target, attention_plot):
    """Draw one panel per caption token with its attention map over the image.

    Args:
        img: Image tensor; converted to a NumPy array and transposed with
            ``(1, 2, 0)`` before display.
        target: Sequence of caption tokens; one subplot is drawn per token.
        attention_plot: Indexable of per-token attention weights, each
            reshapeable to ``(7, 7)``.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    import math

    image = img.to('cpu').numpy().transpose((1, 2, 0))
    n_tokens = len(target)

    # Near-square grid that always has at least n_tokens cells. The previous
    # (n // 2, n // 2) grid was too small for 1, 2, 3 or 5 tokens, which made
    # add_subplot raise.
    n_cols = max(1, math.ceil(math.sqrt(n_tokens)))
    n_rows = max(1, math.ceil(n_tokens / n_cols))

    fig = plt.figure(figsize=(15, 15))
    for idx, token in enumerate(target):
        # Upsample the 7x7 attention grid (7 * 24 = 168 px per side) and
        # smooth it before overlaying.
        attention = attention_plot[idx].reshape(7, 7)
        attention = transform.pyramid_expand(attention, upscale=24, sigma=8)

        ax = fig.add_subplot(n_rows, n_cols, idx + 1)
        ax.set_title(token)
        shown = ax.imshow(image)
        # Reuse the image's extent so the overlay covers the same area.
        ax.imshow(attention, cmap='gray', alpha=0.5, extent=shown.get_extent())

    plt.tight_layout()
    plt.show()
def plot_caption_with_attention(img_pth, model, transforms_=None, vocab=None):
    """Load an image file, caption it and plot the per-token attention maps.

    Args:
        img_pth: Path (or file object) accepted by ``PIL.Image.open``.
        model: Captioning model handed to ``get_caps_from``.
        transforms_: Callable turning a PIL image into an image tensor.
            Required despite the ``None`` default.
        vocab: Vocabulary handed to ``get_caps_from``.

    Raises:
        TypeError: If ``transforms_`` is ``None``.
    """
    if transforms_ is None:
        # Same exception type as calling None, but with a usable message.
        raise TypeError(
            "transforms_ must be a callable that converts a PIL image to a tensor"
        )

    # Close the file promptly. Force three channels so RGBA / grayscale /
    # palette images work with a 3-channel normalisation pipeline.
    with Image.open(img_pth) as source:
        rgb_image = source.convert('RGB')

    batch = transforms_(rgb_image).unsqueeze(0)  # add the batch dimension
    caps, attention = get_caps_from(batch, model, vocab)
    plot_attention(batch[0], caps, attention)
def main(arguments):
    """Rebuild the captioning model from a checkpoint and caption one image.

    Args:
        arguments: Namespace providing ``state_checkpoint``, ``image``,
            ``embed_size``, ``attention_dim``, ``encoder_dim``,
            ``decoder_dim`` and ``fc_dims``.
    """
    # The CLI defines --state_checkpoint; the old code read the misspelled
    # attribute `state_chechpoint` and always failed with AttributeError.
    # The misspelling is still honoured for callers that build their own
    # namespace around the old name.
    checkpoint_path = getattr(arguments, 'state_checkpoint', None)
    if checkpoint_path is None:
        checkpoint_path = arguments.state_chechpoint

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source. The checkpoint also stores a `vocab` object, so
    # weights-only loading is presumably not an option here - confirm.
    state_checkpoint = torch.load(checkpoint_path, map_location=device)
    vocab = state_checkpoint['vocab']

    # Model hyper-parameters: sizes come from the CLI, the vocabulary size
    # from the checkpoint. No pre-trained embedding weights are passed.
    model = EncoderDecoder(
        arguments.embed_size,
        state_checkpoint['vocab_size'],
        arguments.attention_dim,
        arguments.encoder_dim,
        arguments.decoder_dim,
        arguments.fc_dims,
        p=0.3,
        embeddings=None,
    ).to(device)
    model.load_state_dict(state_checkpoint['state_dict'])

    # Resize to 224x224, convert to a tensor and normalise each channel.
    preprocess = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    plot_caption_with_attention(arguments.image, model, preprocess, vocab)
if __name__ == "__main__":
    # Command-line entry point: two required string options followed by five
    # integer model dimensions with defaults, then hand off to main().
    parser = argparse.ArgumentParser()

    for option, help_text in (
        ('--image', 'input image for generating caption'),
        ('--state_checkpoint', 'path for state checkpoint'),
    ):
        parser.add_argument(option, type=str, required=True, help=help_text)

    for option, default_value, help_text in (
        ('--embed_size', 300, 'dimension of word embedding vectors'),
        ('--attention_dim', 256, 'dimension of attention layer'),
        ('--encoder_dim', 2048, 'dimension of encoder layer'),
        ('--decoder_dim', 512, 'dimension of decoder layer'),
        ('--fc_dims', 256, 'dimension of fully connected layer'),
    ):
        parser.add_argument(option, type=int, default=default_value, help=help_text)

    args = parser.parse_args()
    main(args)