-
Notifications
You must be signed in to change notification settings - Fork 1
/
app_streamlit.py
176 lines (141 loc) · 5.37 KB
/
app_streamlit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import torch
import torch.nn as nn
import torchvision.transforms as T
import torch.nn.functional as F
from skimage import transform
from PIL import Image, ImageOps
import io
import streamlit as st
import os
from download_files import *
from matplotlib import pyplot as plt
import wget
from model import *
from dataset import Vocabulary
# Run on the first GPU when available, otherwise fall back to CPU;
# all tensors and the model are moved to this device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Page configuration must be the first Streamlit call in the script
# (Streamlit raises if any other st.* call precedes it).
st.set_page_config(
    initial_sidebar_state="expanded",
    page_title="Explainable Image Caption Bot"
)
def transform_img(img):
    """Preprocess a PIL image for the encoder.

    Resizes to 224x224, converts to a CHW float tensor, and normalizes
    with the standard ImageNet channel statistics.
    """
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)
    pipeline = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(imagenet_mean, imagenet_std),
    ])
    return pipeline(img)
@st.cache(ttl=86400, max_entries=15)
def download_checkpoints():
    """Download the pretrained model checkpoint if it is not already on disk.

    Cached by Streamlit for 24h so repeated reruns skip the filesystem check.
    Side effect: writes ./attention_model_state.pth in the working directory.
    """
    import subprocess

    path = "./attention_model_state.pth"
    if not os.path.exists(path):
        with st.spinner('Downloading state checkpoint...'):
            url = "https://www.dropbox.com/s/6qw5jhumzuu4zzl/attention_model_state.pth?dl=0"
            # Argument-list subprocess call (shell=False) instead of
            # os.system with a shell string: no shell injection surface,
            # and the exit status is actually checked instead of ignored.
            result = subprocess.run(["wget", "-O", path, url])
            if result.returncode != 0:
                st.error("Checkpoint download failed — please reload the app to retry.")
            else:
                print("Model Downloaded")
@st.cache(ttl=21600, max_entries=15)
def load_model():
    """Load the trained EncoderDecoder captioning model and its vocabulary.

    Reads ./attention_model_state.pth (see download_checkpoints) and returns
    (model, vocab) with the model already moved to the module-level device.
    """
    checkpoint = torch.load("./attention_model_state.pth", map_location=device)  # change paths
    vocab = checkpoint['vocab']
    model = EncoderDecoder(
        300,                        # embedding size
        checkpoint['vocab_size'],   # vocabulary size from the checkpoint
        256,                        # attention dim
        2048,                       # encoder feature dim
        512,                        # decoder hidden dim
        256,                        # fully-connected dims
        p=0.3,                      # dropout probability
        embeddings=None,            # no pretrained word embeddings
    ).to(device)
    model.load_state_dict(checkpoint['state_dict'])
    return model, vocab
def get_caps_from(features_tensors, model, vocab=None):
    """Generate a caption and attention weights for one image tensor.

    Args:
        features_tensors: batched image tensor; only the first item is used.
        model: EncoderDecoder with .EncoderCNN and .DecoderLSTM.gen_captions.
        vocab: vocabulary passed through to caption generation.

    Returns:
        (caption_string, raw_token_list, attention_weights)
    """
    model.eval()
    with torch.no_grad():
        features = model.EncoderCNN(features_tensors[0:1].to(device))
        caps, alphas = model.DecoderLSTM.gen_captions(features, vocab=vocab)
        # BUG FIX: the original used `c not in "<EOS>"`, which is a substring
        # test and silently dropped any token that happens to be a substring
        # of "<EOS>" (e.g. "E", "S", "<", "EOS"). We only want to strip the
        # end-of-sequence token itself.
        caps_temp = [c for c in caps if c != "<EOS>"]
        caption = ' '.join(caps_temp)
    return caption, caps, alphas
def plot_attention(img, target, attention_plot):
    """Show the image once per caption token with that token's attention map
    overlaid, rendered as a grid via st.pyplot.

    Args:
        img: normalized CHW image tensor. NOTE: it is un-normalized IN PLACE.
        target: list of caption tokens (one subplot per token).
        attention_plot: per-token attention weights, each reshapeable to 7x7.
    """
    import math

    # Undo ImageNet normalization channel by channel (mutates the caller's tensor).
    img[0] = img[0] * 0.229
    img[1] = img[1] * 0.224
    img[2] = img[2] * 0.225
    img[0] += 0.485
    img[1] += 0.456
    img[2] += 0.406
    img = img.to('cpu').numpy().transpose((1, 2, 0))
    temp_image = img
    fig = plt.figure(figsize=(10, 10))
    len_caps = len(target)
    # BUG FIX: the original grid was (len_caps // 2) x (len_caps // 2), which
    # has too few cells whenever (len_caps // 2)**2 < len_caps (e.g. 2, 3 or
    # 5 tokens) and made add_subplot raise. A ceil-sqrt square grid always
    # has at least len_caps cells.
    grid = max(1, math.ceil(math.sqrt(len_caps)))
    for i in range(len_caps):
        temp_att = attention_plot[i].reshape(7, 7)
        # Smoothly upscale the 7x7 attention map toward image resolution.
        temp_att = transform.pyramid_expand(temp_att, upscale=24, sigma=8)
        ax = fig.add_subplot(grid, grid, i + 1)
        ax.set_axis_off()
        ax.set_title(target[i])
        shown = ax.imshow(temp_image)
        ax.imshow(temp_att, cmap='gray', alpha=0.8, extent=shown.get_extent())
    plt.tight_layout()
    st.pyplot(fig)
def plot_caption_with_attention(img_pth, model, transforms_=None, vocab=None):
    """Open an image, caption it, and display the caption plus attention grid.

    Args:
        img_pth: path or file-like object accepted by PIL.Image.open.
        model: EncoderDecoder captioning model.
        transforms_: callable mapping a PIL image to a CHW tensor (required
            in practice, e.g. transform_img).
        vocab: vocabulary forwarded to caption generation.

    Raises:
        ValueError: if transforms_ is None (the original crashed with an
            opaque ``TypeError: 'NoneType' object is not callable``).
    """
    if transforms_ is None:
        raise ValueError(
            "plot_caption_with_attention requires a transform that converts "
            "a PIL image to a tensor (e.g. transform_img)")
    img = Image.open(img_pth)
    img = transforms_(img)
    img.unsqueeze_(0)  # add batch dimension in place
    caption, caps, attention = get_caps_from(img, model, vocab)
    st.markdown(f"## Image Caption:\n"
                f" #### {caption}\n\n")
    plot_attention(img[0], caps, attention)
@st.cache(ttl=3600, max_entries=10)
def load_output_image(img):
    """Return a PIL RGB image from a filesystem path or an uploaded file.

    Args:
        img: str path to an image file, or a file-like object with .read()
            (e.g. a Streamlit UploadedFile).

    Returns:
        PIL.Image in RGB mode with EXIF orientation applied.
    """
    if isinstance(img, str):
        # BUG FIX: the original skipped .convert("RGB") on the path branch,
        # so grayscale/RGBA files from disk broke downstream normalization
        # which expects exactly 3 channels.
        image = Image.open(img).convert("RGB")
    else:
        img_bytes = img.read()
        image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    # Respect EXIF orientation (phone photos are often stored rotated).
    image = ImageOps.exif_transpose(image)
    return image
if __name__ == '__main__':
    # One-time setup: fetch the checkpoint (no-op if already on disk),
    # then load the model and vocabulary. Both calls are Streamlit-cached.
    download_checkpoints()
    model, vocab = load_model()
    st.title("The Explainable Image Captioning Bot")
    st.text("")
    st.text("")
    st.success("Welcome! Please upload an image!"
    )
    img_upload = st.file_uploader(label='Upload Image', type=['png', 'jpg', 'jpeg', 'webp'])
    # Fall back to a bundled demo image until the user uploads one.
    img_pt = "imgs/test2.jpeg" if img_upload is None else img_upload
    image = load_output_image(img_pt)
    # Sidebar: app introduction and usage tips (markdown kept verbatim).
    st.sidebar.markdown('''
### Introduction:
Hello! :hand: and Welcome,
This is a Image caption bot:\n
Its main job is to give captions :speech_balloon: or description for your
input image.\n
But we have tried something different here \n
This app gives 2 outputs
- The Caption for your image duh? :upside_down_face:
- Explaination as in a image grid i.e. the parts of the image
where the AI looks when trying to caption your image :nerd_face: \n
### Explaination of Captions:
The white regions in each image of the resultant grid of images shows the regions where the deep learning models
focuses to get that perticular word mentioned in the image title.\n
### Instructions to use
If you are getting random captions, then try :-
- Using a PC
- Try images of bicycles or motarbikes
- Try images with children in it
- Try images of dogs
''')
    st.sidebar.markdown('''Check the model details [here](https://github.com/mrFahrenhiet/Explainable_Image_Caption_bot)
\n Liked it? Give a :star: on GitHub ''')
    st.image(image, use_column_width=True)
    # Captioning is expensive, so run it only on an explicit button press.
    if st.button('Generate captions!'):
        plot_caption_with_attention(img_pt, model, transform_img, vocab)
        st.success("Try a different image by uploading")
        st.balloons()