diff --git a/xtuner/dataset/collate_fns/default_collate_fn.py b/xtuner/dataset/collate_fns/default_collate_fn.py
index 4ff6785a6..a2bf20caf 100644
--- a/xtuner/dataset/collate_fns/default_collate_fn.py
+++ b/xtuner/dataset/collate_fns/default_collate_fn.py
@@ -39,6 +39,7 @@ def default_collate_fn(instances: Sequence[Dict],
         if has_image:
             pixel_values.append(example['pixel_values'])
 
+    ori_length = [len(ids) for ids in input_ids]
     if len(instances) > 1:
         input_ids = pad_sequence(
             input_ids, batch_first=True, padding_value=pad_index)
@@ -53,7 +54,12 @@ def default_collate_fn(instances: Sequence[Dict],
         attention_mask = None
         position_ids = torch.stack(position_ids, dim=0)
     else:
-        attention_mask = input_ids.ne(pad_index)
+        # Some tokenizers have the same eos token and pad token, so input_ids
+        # cannot be masked directly based on the pad token id.
+        attention_mask = torch.zeros_like(input_ids).bool()
+        for i, length in enumerate(ori_length):
+            attention_mask[i, :length] = True
+
         bs, seq_len = input_ids.shape
         position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(
             bs, 1)
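
For reviewers, here is a minimal standalone sketch of the failure mode this patch addresses. The token ids are made up for illustration; the only assumption is a tokenizer configuration where the eos token and pad token share the same id, so `input_ids.ne(pad_index)` would also mask genuine eos tokens:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Assumed setup: pad_index doubles as the eos token id (ids are illustrative).
pad_index = 2

input_ids = [
    torch.tensor([5, 6, 7, 2]),  # ends with a genuine eos token (id 2)
    torch.tensor([8, 9]),
]
ori_length = [len(ids) for ids in input_ids]

padded = pad_sequence(input_ids, batch_first=True, padding_value=pad_index)
# padded: [[5, 6, 7, 2],
#          [8, 9, 2, 2]]

# Old behavior: the real eos token in the first sequence is wrongly masked
# out because it shares its id with the pad token.
old_mask = padded.ne(pad_index)
# tensor([[ True,  True,  True, False],
#         [ True,  True, False, False]])

# New behavior: build the mask from the original lengths recorded before
# padding, so the real eos token stays visible.
new_mask = torch.zeros_like(padded).bool()
for i, length in enumerate(ori_length):
    new_mask[i, :length] = True
# tensor([[ True,  True,  True,  True],
#         [ True,  True, False, False]])
```

Recording `ori_length` before `pad_sequence` runs is what makes this work: after padding, the per-sequence lengths can no longer be recovered from the ids alone when eos and pad collide.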