Skip to content

Commit

Permalink
Fixed indexing of split input
Browse files (browse the repository at this point in the history)
  • Loading branch information
robvanderg committed Sep 22, 2022
1 parent 9874a41 commit db4d4ef
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions machamp/model/encoder.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -264,28 +264,26 @@ def embed(self,
splitted_idx += 1
else:
# first of the splits, keep as is
num_subwords = self.max_input_length
if self.end_token_id != None:
num_subwords -= 1
mlm_out_merged[sent_idx][0:num_subwords] = mlm_out_split[splitted_idx][
0:num_subwords]
end_idx = self.max_input_length-1 if self.end_token_id == None else self.max_input_length
mlm_out_merged[sent_idx][0:end_idx] = mlm_out_split[splitted_idx][0:end_idx]
num_subwords_per_batch = self.max_input_length - self.num_extra_tokens

splitted_idx += 1
# all except first and last, has no CLS/SEP
for i in range(1, amount_of_splits[sent_idx] - 1):
beg = num_subwords + (i-1) * (self.max_input_length)
end = beg + self.max_input_length - self.num_extra_tokens
beg = end_idx + (i-1) * num_subwords_per_batch
end = beg + num_subwords_per_batch
mlm_out_cursplit = mlm_out_split[splitted_idx]
if self.end_token_id != None:
mlm_out_cursplit = mlm_out_cursplit[:-1]
if self.start_token_id != None:
mlm_out_cursplit = mlm_out_cursplit[1:]

mlm_out_merged[sent_idx][beg:end] = mlm_out_cursplit
splitted_idx += 1

# last of the splits, keep the SEP
beg = num_subwords + (amount_of_splits[sent_idx] - 2) * (self.max_input_length - self.num_extra_tokens)
beg = end_idx + (amount_of_splits[sent_idx]-2) * num_subwords_per_batch
end = lengths[sent_idx]
mlm_out_merged[sent_idx][beg:end] = mlm_out_split[splitted_idx][0:end - beg]
splitted_idx += 1
Expand Down

0 comments on commit db4d4ef

Please sign in to comment.