semantic_search.py
from sentence_transformers import SentenceTransformer, util
import torch
import pickle
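
# A minimal sketch of how 'embeddings.pkl' could have been produced, assuming a
# SentenceTransformer model ('all-MiniLM-L6-v2' is an illustrative choice, not
# taken from this script) and a list of input sentences:
#
#   model = SentenceTransformer('all-MiniLM-L6-v2')
#   embeddings = model.encode(sentences, convert_to_numpy=True)
#   with open('embeddings.pkl', 'wb') as fOut:
#       pickle.dump({'sentences': sentences, 'embeddings': embeddings}, fOut)
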
# Load the pre-computed sentences and their embeddings from disk.
with open('embeddings.pkl', "rb") as fIn:
    stored_data = pickle.load(fIn)
    stored_sentences = stored_data['sentences']
    stored_embeddings = stored_data['embeddings']

# Split the stored embeddings: the first 13 * 2000 = 26,000 entries serve as
# queries, the remainder as the corpus to search against.
corpus_embeddings = torch.tensor(stored_embeddings[2000 * 13:])
query_embeddings = torch.tensor(stored_embeddings[:2000 * 13])

# Move both sets to the GPU and normalize them so that the dot product
# equals cosine similarity.
corpus_embeddings = util.normalize_embeddings(corpus_embeddings.to('cuda'))
query_embeddings = util.normalize_embeddings(query_embeddings.to('cuda'))

# For every query, retrieve the 64 most similar corpus entries.
hits = util.semantic_search(query_embeddings, corpus_embeddings, score_function=util.dot_score, top_k=64)

# Print, for each query, the indices of its top hits mapped back to positions
# in the original embedding list (corpus ids start after the 26,000 queries).
for i, hit in enumerate(hits):
    # print("Query:", stored_sentences[i])
    for k in range(len(hit)):
        print(hit[k]['corpus_id'] + 2000 * 13, end=" ")
        # print(hit[k]['score'])
    print("")