This repository has been archived by the owner on Mar 1, 2024. It is now read-only.

KeyError #115

Open
JLUGQQ opened this issue Apr 17, 2022 · 7 comments

Comments

@JLUGQQ

JLUGQQ commented Apr 17, 2022

When I run eval_biencoder, I encounter this problem:

Traceback (most recent call last):
  File "/data/gavin/blink-el/blink/biencoder/eval_biencoder.py", line 336, in <module>
    main(new_params)
  File "/data/gavin/blink-el/blink/biencoder/eval_biencoder.py", line 289, in main
    save_results,
  File "/data/gavin/blink-el/blink/biencoder/nn_prediction.py", line 65, in get_topk_predictions
    cand_encs=cand_encode_list[src].to(device)
KeyError: 9
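
One quick way to see what is going on at the failing line (a debugging sketch, not repo code; the helper name report_src is hypothetical, and it assumes cand_encode_list is a dict keyed by world id, which the integer KeyError suggests) is to print the cached keys just before the lookup in nn_prediction.py:

def report_src(cand_encode_list, src):
    # Debugging sketch: list which world ids ("src") have cached candidate
    # encodings versus the one the current batch asks for. Call it right
    # before the cand_encode_list[src] lookup in get_topk_predictions.
    keys = sorted(cand_encode_list.keys())
    print(f"requested src: {src}")
    print(f"cached src keys: {keys}")
    if src not in cand_encode_list:
        print("-> this src has no cached encodings; the next lookup raises KeyError")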

@abhinavkulkarni

abhinavkulkarni commented May 5, 2022

@JLUGQQ: Yes, this is a bug; you can look at issue #95 for the solution.

@JLUGQQ

JLUGQQ commented May 6, 2022

#95

Thank you. I have tried this solution before, but it didn't work. Maybe I should change my package versions according to requirements.txt.

@abhinavkulkarni

@JLUGQQ: I am able to successfully run eval on both zeshel and non-zeshel datasets. Feel free to copy and paste your error message here; I'd be glad to take a look.

@JLUGQQ

JLUGQQ commented May 6, 2022

#95

Thank you very much for your help!
I could successfully run train_biencoder, but when I ran eval_biencoder I encountered this problem. I have already changed the code according to issue #95.

05/06/2022 13:33:00 - INFO - Blink - Getting top 64 predictions.
  0%|          | 0/2500 [00:00<?, ?it/s]
05/06/2022 13:33:00 - INFO - Blink - World size : 16
  0%|          | 0/2500 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/data/gavin/BLINK-main/blink/biencoder/eval_biencoder.py", line 337, in <module>
    main(new_params)
  File "/data/gavin/BLINK-main/blink/biencoder/eval_biencoder.py", line 289, in main
    save_results,
  File "/data/gavin/BLINK-main/blink/biencoder/nn_prediction.py", line 65, in get_topk_predictions
    cand_encs=cand_encode_list[src].to(device)
KeyError: 12

@abhinavkulkarni

@JLUGQQ: Here's what I have in my git diff. Let me know if this helps.

diff --git a/blink/biencoder/nn_prediction.py b/blink/biencoder/nn_prediction.py
index eab90a8..18e50cd 100644
--- a/blink/biencoder/nn_prediction.py
+++ b/blink/biencoder/nn_prediction.py
@@ -55,13 +55,20 @@ def get_topk_predictions(
     oid = 0
     for step, batch in enumerate(iter_):
         batch = tuple(t.to(device) for t in batch)
-        context_input, _, srcs, label_ids = batch
+        if is_zeshel:
+            context_input, _, srcs, label_ids = batch
+        else:
+            context_input, _, label_ids = batch
+            srcs = torch.tensor([0] * context_input.size(0), device=device)
+
         src = srcs[0].item()
+        cand_encode_list[src] = cand_encode_list[src].to(device)
         scores = reranker.score_candidate(
             context_input, 
             None, 
-            cand_encs=cand_encode_list[src].to(device)
+            cand_encs=cand_encode_list[src]
         )
+
         values, indicies = scores.topk(top_k)
         old_src = src
         for i in range(context_input.size(0)):
@@ -93,7 +100,7 @@ def get_topk_predictions(
                 continue
 
             # add examples in new_data
-            cur_candidates = candidate_pool[src][inds]
+            cur_candidates = candidate_pool[srcs[i].item()][inds]
             nn_context.append(context_input[i].cpu().tolist())
             nn_candidates.append(cur_candidates.cpu().tolist())
             nn_labels.append(pointer)
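
Independent of the patch above, a small up-front check can turn the late KeyError into an immediate, readable error. This is only a sketch for the zeshel case (the helper name check_src_coverage is mine, and it assumes cand_encode_list is a dict keyed by world id, following the batch layout in the diff):

def check_src_coverage(dataloader, cand_encode_list):
    # Sketch of an up-front coverage check (hypothetical helper): confirm that
    # every world id ("src") appearing in the eval batches has an entry in
    # cand_encode_list before entering the prediction loop.
    seen = set()
    for batch in dataloader:
        _, _, srcs, _ = batch  # zeshel batches: context_input, _, srcs, label_ids
        seen.update(srcs.tolist())
    missing = seen - set(cand_encode_list.keys())
    if missing:
        raise ValueError(
            f"No candidate encodings for src ids {sorted(missing)}; "
            f"available ids: {sorted(cand_encode_list.keys())}"
        )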

@JLUGQQ

JLUGQQ commented May 6, 2022

@JLUGQQ: Here's what I have in git diff. Let me know if this helps.


Pity, it still doesn't work. Thanks for your reply. I think I should take some time to debug and find the exact reason, and I will comment if I solve this problem.

@yc-song

yc-song commented Nov 11, 2022

The KeyError might happen because the validation or test set tries to find its encodings in the training-set encodings (e.g., the crash occurs when validation data with src value 9 attempts to look up its encodings in the training encodings, whose src values run from 0 to 8; that is why there is a KeyError for value 9).
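
A toy illustration of that mismatch (made-up sizes; only the keying of cand_encode_list by world id mirrors the real code):

import torch

# "Training" encodings cached for worlds 0..8 only.
cand_encode_list = {src: torch.randn(100, 1024) for src in range(9)}

val_src = 9                  # world id coming from a validation example
cand_encode_list[val_src]    # raises KeyError: 9, exactly as reported above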
Although there might be multiple solutions, I recommend saving each split's encodings in a separate file, i.e., the following shell commands worked in my case:

python blink/biencoder/eval_biencoder.py --path_to_model models/zeshel/biencoder/pytorch_model.bin --data_path data/zeshel/blink_format --output_path models/zeshel --encode_batch_size 128 --eval_batch_size 1 --top_k 64 --save_topk_result --bert_model bert-large-uncased --mode train --zeshel True --data_parallel --cand_encode_path data/zeshel/cand_enc/cand_enc_train.pt --cand_pool_path data/zeshel/cand_pool/cand_pool_train.pt

python blink/biencoder/eval_biencoder.py --path_to_model models/zeshel/biencoder/pytorch_model.bin --data_path data/zeshel/blink_format --output_path models/zeshel --encode_batch_size 128 --eval_batch_size 1 --top_k 64 --save_topk_result --bert_model bert-large-uncased --mode valid --zeshel True --data_parallel --cand_encode_path data/zeshel/cand_enc/cand_enc_valid.pt --cand_pool_path data/zeshel/cand_pool/cand_pool_valid.pt

python blink/biencoder/eval_biencoder.py --path_to_model models/zeshel/biencoder/pytorch_model.bin --data_path data/zeshel/blink_format --output_path models/zeshel --encode_batch_size 128 --eval_batch_size 1 --top_k 64 --save_topk_result --bert_model bert-large-uncased --mode test --zeshel True --data_parallel --cand_encode_path data/zeshel/cand_enc/cand_enc_test.pt --cand_pool_path data/zeshel/cand_pool/cand_pool_test.pt
