Skip to content

Commit

Permalink
[BugFix] Fix bugs in oagbert.encode_paper (#385)
Browse files Browse the repository at this point in the history
  • Loading branch information
THINK2TRY authored Sep 3, 2022
1 parent c3d23bb commit e4957e9
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions cogdl/oag/oagbert_metainfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,14 +342,15 @@ def encode_paper(

split_index = {"text": [], "venue": [], "authors": [], "concepts": [], "affiliations": []}

device = next(self.parameters()).device
sequence_output, pooled_output = self.bert.forward(
input_ids=torch.LongTensor(input_ids).unsqueeze(0),
token_type_ids=torch.LongTensor(token_type_ids).unsqueeze(0),
attention_mask=torch.LongTensor(input_masks).unsqueeze(0),
input_ids=torch.LongTensor(input_ids).unsqueeze(0).to(device),
token_type_ids=torch.LongTensor(token_type_ids).unsqueeze(0).to(device),
attention_mask=torch.LongTensor(input_masks).unsqueeze(0).to(device),
output_all_encoded_layers=False,
checkpoint_activations=False,
position_ids=torch.LongTensor(position_ids).unsqueeze(0),
position_ids_second=torch.LongTensor(position_ids_second).unsqueeze(0),
position_ids=torch.LongTensor(position_ids).unsqueeze(0).to(device),
position_ids_second=torch.LongTensor(position_ids_second).unsqueeze(0).to(device),
)

entities = {0: "text", 2: "venue", 1: "authors", 4: "concepts", 3: "affiliations"}
Expand Down

0 comments on commit e4957e9

Please sign in to comment.