From 84dcaccefac173d279c051fa20f18a04b8589d79 Mon Sep 17 00:00:00 2001 From: think2try Date: Sat, 3 Sep 2022 15:56:55 +0800 Subject: [PATCH] fix bugs in oagbert.encode_paper --- cogdl/oag/oagbert_metainfo.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/cogdl/oag/oagbert_metainfo.py b/cogdl/oag/oagbert_metainfo.py index 23f13782..7c1f1142 100644 --- a/cogdl/oag/oagbert_metainfo.py +++ b/cogdl/oag/oagbert_metainfo.py @@ -342,14 +342,15 @@ def encode_paper( split_index = {"text": [], "venue": [], "authors": [], "concepts": [], "affiliations": []} + device = next(self.parameters()).device sequence_output, pooled_output = self.bert.forward( - input_ids=torch.LongTensor(input_ids).unsqueeze(0), - token_type_ids=torch.LongTensor(token_type_ids).unsqueeze(0), - attention_mask=torch.LongTensor(input_masks).unsqueeze(0), + input_ids=torch.LongTensor(input_ids).unsqueeze(0).to(device), + token_type_ids=torch.LongTensor(token_type_ids).unsqueeze(0).to(device), + attention_mask=torch.LongTensor(input_masks).unsqueeze(0).to(device), output_all_encoded_layers=False, checkpoint_activations=False, - position_ids=torch.LongTensor(position_ids).unsqueeze(0), - position_ids_second=torch.LongTensor(position_ids_second).unsqueeze(0), + position_ids=torch.LongTensor(position_ids).unsqueeze(0).to(device), + position_ids_second=torch.LongTensor(position_ids_second).unsqueeze(0).to(device), ) entities = {0: "text", 2: "venue", 1: "authors", 4: "concepts", 3: "affiliations"}