diff --git a/machamp/model/encoder.py b/machamp/model/encoder.py
index 5b7556a..11010bd 100644
--- a/machamp/model/encoder.py
+++ b/machamp/model/encoder.py
@@ -264,28 +264,26 @@ def embed(self,
                     splitted_idx += 1
                 else:
                     # first of the splits, keep as is
-                    num_subwords = self.max_input_length
-                    if self.end_token_id != None:
-                        num_subwords -= 1
-                    mlm_out_merged[sent_idx][0:num_subwords] = mlm_out_split[splitted_idx][
-                                                               0:num_subwords]
+                    end_idx = self.max_input_length-1 if self.end_token_id == None else self.max_input_length
+                    mlm_out_merged[sent_idx][0:end_idx] = mlm_out_split[splitted_idx][0:end_idx]
+                    num_subwords_per_batch = self.max_input_length - self.num_extra_tokens
                     splitted_idx += 1
 
                     # all except first and last, has no CLS/SEP
                     for i in range(1, amount_of_splits[sent_idx] - 1):
-                        beg = num_subwords + (i-1) * (self.max_input_length)
-                        end = beg + self.max_input_length - self.num_extra_tokens
+                        beg = end_idx + (i-1) * num_subwords_per_batch
+                        end = beg + num_subwords_per_batch
                         mlm_out_cursplit = mlm_out_split[splitted_idx]
                         if self.end_token_id != None:
                             mlm_out_cursplit = mlm_out_cursplit[:-1]
                         if self.start_token_id != None:
                             mlm_out_cursplit = mlm_out_cursplit[1:]
-
+
                         mlm_out_merged[sent_idx][beg:end] = mlm_out_cursplit
                         splitted_idx += 1
 
                     # last of the splits, keep the SEP
-                    beg = num_subwords + (amount_of_splits[sent_idx] - 2) * (self.max_input_length - self.num_extra_tokens)
+                    beg = end_idx + (amount_of_splits[sent_idx]-2) * num_subwords_per_batch
                     end = lengths[sent_idx]
                     mlm_out_merged[sent_idx][beg:end] = mlm_out_split[splitted_idx][0:end - beg]
                     splitted_idx += 1
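
For context, below is a minimal, self-contained sketch of the merge arithmetic this patch rewrites, assuming a window size of 8 with one CLS and one SEP per window. The helper name `merge_windows`, the flags, and the toy tensors are illustrative only, not MaChAmp's API: the first window keeps everything up to (but not including) its trailing SEP, each middle window is stripped of CLS/SEP and contributes a fixed slab of `max_input_length - num_extra_tokens` positions, and the last window is copied up to the sentence length so its SEP survives.

```python
# Minimal sketch (not MaChAmp's API): how per-window LM outputs are stitched
# back into one embedding sequence per over-long sentence. Names, window size
# and dimensions below are illustrative assumptions.
import torch

max_input_length = 8      # assumed window size, special tokens included
num_extra_tokens = 2      # assumed one CLS + one SEP per window
has_start_token = True    # each window starts with CLS
has_end_token = True      # each window ends with SEP


def merge_windows(windows, length):
    """Merge per-window embeddings (each [max_input_length, dim]) into a
    single [length, dim] tensor for one sentence split over len(windows) windows."""
    dim = windows[0].size(-1)
    merged = torch.zeros(length, dim)

    # First window: keep as is, except its trailing SEP.
    end_idx = max_input_length - 1 if has_end_token else max_input_length
    merged[0:end_idx] = windows[0][0:end_idx]

    # Middle windows: strip CLS/SEP; each contributes a fixed-size slab.
    num_subwords_per_window = max_input_length - num_extra_tokens
    for i in range(1, len(windows) - 1):
        beg = end_idx + (i - 1) * num_subwords_per_window
        end = beg + num_subwords_per_window
        cur = windows[i]
        if has_end_token:
            cur = cur[:-1]
        if has_start_token:
            cur = cur[1:]
        merged[beg:end] = cur

    # Last window: copy up to the sentence length, so the final SEP survives.
    beg = end_idx + (len(windows) - 2) * num_subwords_per_window
    merged[beg:length] = windows[-1][0:length - beg]
    return merged


# A 20-position sentence split into 3 windows of 8 (dim 4 for readability).
windows = [torch.randn(max_input_length, 4) for _ in range(3)]
print(merge_windows(windows, 20).shape)  # torch.Size([20, 4])
```

Running the sketch prints `torch.Size([20, 4])`: the first window fills positions 0:7, the single middle window fills 7:13, and the last window fills 13:20, mirroring the `end_idx` / `num_subwords_per_batch` bookkeeping in the patched code.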