-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
43 lines (34 loc) · 1.7 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from caption import generate_text_embedding, generate_embedding_text
import numpy as np
import nltk.data
from scipy import spatial
# import torch
from itertools import islice
# --- Build the searchable sentence index ---------------------------------
filename = input(">> Enter file to search >: ")
tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

# Read the whole document; `with` guarantees the handle is closed
# (the original opened the file and never closed it).
with open(filename, 'r') as search_file:
    search_text = search_file.read()

# One entry per Punkt-tokenized sentence, with internal newlines/tabs
# collapsed to spaces.
sentences = [sent.replace('\n', ' ').replace('\t', ' ').strip()
             for sent in tokenizer.tokenize(search_text)]

print(">> Building search cache...")
# Precompute one embedding per short sentence; sentences of 70+ chars are
# skipped to bound cache size. NOTE(review): the 70-char cutoff is
# arbitrary — confirm it matches the embedding model's input limit.
embeddings = {sentence: generate_text_embedding(sentence).cpu().detach().numpy()
              for sentence in sentences if len(sentence) < 70}
# Map each cached sentence back to its ordinal position in the document.
locations = {sentence: i for i, sentence in enumerate(sentences)
             if len(sentence) < 70}
print(">> Completed.\n")
def search(query):
    """Print the five cached sentences closest to *query*.

    Closeness is cosine distance between the query embedding and each
    precomputed sentence embedding. Reads the module-level ``embeddings``
    and ``locations`` dicts; prints the results and returns ``None``.
    """
    # Flatten once, outside the comparator — the original re-flattened the
    # query embedding for every candidate sentence.
    query_embedding = generate_text_embedding(query).cpu().detach().numpy().flatten()

    def distance(sentence):
        # Cosine distance in [0, 2]; 0 means identical direction.
        return spatial.distance.cosine(query_embedding,
                                       embeddings[sentence].flatten())

    # Five nearest sentences, each annotated with its document position.
    nearest = sorted(embeddings, key=distance)[:5]
    results = [f"{s} #{locations[s]}" for s in nearest]

    print(f">> Search query: {query}")
    print('> ' + "\n> ".join(results))
    print("\n")
# Interactive query loop: read searches until the user enters 'q'.
while True:
    user_query = input(">: ")
    if user_query == 'q':
        break
    # Blank line separates the prompt from the result listing.
    print()
    search(user_query)