-
Notifications
You must be signed in to change notification settings - Fork 0
/
caption_generation.py
111 lines (80 loc) · 3.02 KB
/
caption_generation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import pandas as pd
import torch
import torchvision.transforms as T
from collections import Counter
from models.caption.model import EncoderDecoder
from PIL import Image
import requests
import matplotlib.pyplot as plt
import translators as ts
import spacy
from generation import generate_sentence
# Progress marker emitted once the imports above have completed.
print('FINISH IMPORTS')

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# spaCy English pipeline; its tokenizer is used by Vocabulary.tokenize.
spacy_eng = spacy.load("en_core_web_sm")
class Vocabulary:
    """
    Word <-> index mapping for caption text.

    Indices 0-3 are reserved for the special tokens <PAD>, <SOS>, <EOS> and
    <UNK>; ordinary words are appended after them by ``build_vocab``.
    """

    def __init__(self, freq_threshold):
        """
        Input:
            freq_threshold (int): number of occurrences a word needs within a
                single ``build_vocab`` call before it is added to the vocabulary
        """
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {v: k for k, v in self.itos.items()}
        self.freq_threshold = freq_threshold

    def __len__(self):
        """Return the number of entries in the vocabulary, special tokens included."""
        return len(self.itos)

    @staticmethod
    def tokenize(text):
        """Return the lower-cased token texts produced by the spaCy English tokenizer."""
        return [token.text.lower() for token in spacy_eng.tokenizer(text)]

    def build_vocab(self, sentence_list):
        """
        Add every word that reaches the frequency threshold to the vocabulary.

        Words receive indices in the order in which they reach the threshold
        while scanning ``sentence_list``. That order is kept on purpose:
        counting everything first and adding afterwards would renumber the
        words (presumably the trained weights rely on this numbering - confirm
        before changing it).

        Frequencies are counted per call. Numbering continues after the
        entries already present, so calling this method again extends the
        vocabulary instead of overwriting indices 4 and up.

        Input:
            sentence_list (iterable of str): the sentences to scan
        """
        frequencies = Counter()
        # Continue after the existing entries (4 on a fresh instance).
        idx = len(self.itos)

        for sentence in sentence_list:
            for word in self.tokenize(sentence):
                frequencies[word] += 1

                # Add the word to the vocab once it reaches the minimum
                # frequency threshold; skip words that already have an index.
                if frequencies[word] == self.freq_threshold and word not in self.stoi:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1
# --- Vocabulary -------------------------------------------------------------
# Words must occur at least `freq_threshold` times in the caption file to get
# their own index in the vocabulary.
freq_threshold = 5
vocab = Vocabulary(freq_threshold)
captions = pd.read_csv('models/caption/captions.txt')['caption'].to_list()
vocab.build_vocab(captions)


def _restore_weights(model, checkpoint_path):
    """Load the 'state_dict' entry of the checkpoint at `checkpoint_path` into `model`."""
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])


# --- Caption model ----------------------------------------------------------
# The vocabulary size is taken from the vocabulary built above; the remaining
# sizes are fixed hyper-parameters.
caption_model = EncoderDecoder(
    vocab_size=len(vocab),
    embed_size=300,
    attention_dim=256,
    encoder_dim=2048,
    decoder_dim=512,
    device=device,
)
caption_model = caption_model.to(device)
_restore_weights(caption_model, 'models/caption/attention_model_state.pt')

# --- Image preprocessing ----------------------------------------------------
# Resize, convert to a tensor, then normalise each of the three channels.
transform = T.Compose([
    T.Resize(226),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
def generate_caption(img):
    """
    Given an image, generate its Arabic caption.

    The image is preprocessed with the module-level ``transform``, passed
    through the encoder and decoder of ``caption_model`` to obtain an English
    caption, and the English text is translated to Arabic with ``ts.google``.

    Input:
        img: the image to caption. NOTE(review): expected to be a PIL image,
            inferred from the Resize -> ToTensor pipeline - confirm against
            callers. A caller holding a URL should load it first, e.g.
            ``Image.open(requests.get(url, stream=True).raw)``.
    Output:
        ar_caption (str): the Arabic caption
    """
    # The normalisation in `transform` is defined for three channels, so PIL
    # images in any other mode (RGBA, palette, grayscale, ...) are converted
    # to RGB first. Other input types are passed through unchanged.
    if isinstance(img, Image.Image) and img.mode != "RGB":
        img = img.convert("RGB")

    # Add a leading batch dimension of size 1.
    batch = transform(img).unsqueeze(0)

    caption_model.eval()
    with torch.no_grad():
        features = caption_model.encoder(batch.to(device))
        caps, _ = caption_model.decoder.generate_caption(features, vocab=vocab)

    # The last two generated tokens are dropped before joining (presumably the
    # trailing punctuation and end-of-sentence marker - confirm against the
    # decoder).
    caption = ' '.join(caps[:-2])

    ar_caption = ts.google(caption, from_language='en', to_language='ar')
    return ar_caption
def generate_caption_sentence(img, max_lines, rhyme):
    """
    Caption an image in Arabic and continue the caption with ``generate_sentence``.

    Input:
        img: the image, passed unchanged to ``generate_caption``
        max_lines: forwarded as ``max_lines``; ``max_length`` is set to
            ``max_lines * 50``
        rhyme: forwarded as ``rhyme``
    Output:
        the value returned by ``generate_sentence`` for the meter 'الكامل',
        started with the Arabic caption of the image
    """
    seed_text = generate_caption(img)
    generation_args = {
        'meter': 'الكامل',
        'rhyme': rhyme,
        'max_lines': max_lines,
        'max_length': max_lines * 50,
        'start_with': seed_text,
    }
    return generate_sentence(**generation_args)
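# Example usage (generate_caption expects an image, not a URL, so the image is loaded first):
# url = 'https://www.preventivevet.com/hubfs/Three%20dogs%20playing%20in%20the%20yard%20600%20canva.jpg'
# img = Image.open(requests.get(url, stream=True).raw).convert("RGB")
# arabe_caption = generate_caption(img)
# res = generate_sentence(meter='الكامل', rhyme='ر', start_with=arabe_caption)
# print(res)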