From 4ed976e2819a1695641357137475efdcbdd5552b Mon Sep 17 00:00:00 2001
From: Doxie
Date: Mon, 29 Jul 2024 19:48:35 +0800
Subject: [PATCH 1/4] fix dpo_trainer bug for LLMs without bos_token in config

---
 trl/trainer/dpo_trainer.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 8674ceff57..61dfa0c4ca 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -858,15 +858,16 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module
 
             # add BOS token to head of prompt. Avoid adding if it's already there
             bos_token_id = self.tokenizer.bos_token_id
-            if prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]:
-                prompt_tokens["prompt_input_ids"] = [bos_token_id] + prompt_tokens["prompt_input_ids"]
-                prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
-            if chosen_prompt_len_input_ids == 0 or bos_token_id != chosen_tokens["prompt_input_ids"][0]:
-                chosen_tokens["prompt_input_ids"] = [bos_token_id] + chosen_tokens["prompt_input_ids"]
-                chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
-            if rejected_prompt_len_input_ids == 0 or bos_token_id != rejected_tokens["prompt_input_ids"][0]:
-                rejected_tokens["prompt_input_ids"] = [bos_token_id] + rejected_tokens["prompt_input_ids"]
-                rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
+            if bos_token_id is not None:
+                if prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]:
+                    prompt_tokens["prompt_input_ids"] = [bos_token_id] + prompt_tokens["prompt_input_ids"]
+                    prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
+                if chosen_prompt_len_input_ids == 0 or bos_token_id != chosen_tokens["prompt_input_ids"][0]:
+                    chosen_tokens["prompt_input_ids"] = [bos_token_id] + chosen_tokens["prompt_input_ids"]
+                    chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
+                if rejected_prompt_len_input_ids == 0 or bos_token_id != rejected_tokens["prompt_input_ids"][0]:
+                    rejected_tokens["prompt_input_ids"] = [bos_token_id] + rejected_tokens["prompt_input_ids"]
+                    rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
 
             # add EOS token to end of answer. Avoid adding if it's already there
             eos_token_id = self.tokenizer.eos_token_id

From 6ce5df003f052203c5863315813fec195e753b79 Mon Sep 17 00:00:00 2001
From: DZ9
Date: Wed, 31 Jul 2024 14:16:25 +0800
Subject: [PATCH 2/4] fix adding bos_token_id bug in dpo,orpo,cpo trainers

---
 trl/trainer/cpo_trainer.py  | 33 ++++++++++++++++-----------------
 trl/trainer/dpo_trainer.py  | 34 ++++++++++++++++------------------
 trl/trainer/orpo_trainer.py | 33 ++++++++++++++++-----------------
 trl/trainer/utils.py        | 36 ++++++++++++++++++++++++++++++++++++
 4 files changed, 84 insertions(+), 52 deletions(-)

diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py
index 8248c7c75d..b41585e984 100644
--- a/trl/trainer/cpo_trainer.py
+++ b/trl/trainer/cpo_trainer.py
@@ -41,6 +41,8 @@
     pad_to_length,
     peft_module_casting_to_bf16,
     trl_sanitze_kwargs_for_tagging,
+    add_bos_token_if_needed,
+    add_eos_token_if_needed
 )
 
 
@@ -424,25 +426,22 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module
             )
 
             # add BOS token to head of prompt. Avoid adding if it's already there
-            bos_token_id = self.tokenizer.bos_token_id
-            if prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]:
-                prompt_tokens["prompt_input_ids"] = [bos_token_id] + prompt_tokens["prompt_input_ids"]
-                prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
-            if chosen_prompt_len_input_ids == 0 or bos_token_id != chosen_tokens["prompt_input_ids"][0]:
-                chosen_tokens["prompt_input_ids"] = [bos_token_id] + chosen_tokens["prompt_input_ids"]
-                chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
-            if rejected_prompt_len_input_ids == 0 or bos_token_id != rejected_tokens["prompt_input_ids"][0]:
-                rejected_tokens["prompt_input_ids"] = [bos_token_id] + rejected_tokens["prompt_input_ids"]
-                rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
+            prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
+                self.tokenizer.bos_token_id,
+                prompt_len_input_ids,
+                prompt_tokens,
+                chosen_prompt_len_input_ids,
+                chosen_tokens,
+                rejected_prompt_len_input_ids,
+                rejected_tokens
+            )
 
             # add EOS token to end of answer. Avoid adding if it's already there
-            eos_token_id = self.tokenizer.eos_token_id
-            if len(chosen_tokens["input_ids"]) == 0 or eos_token_id != chosen_tokens["input_ids"][-1]:
-                chosen_tokens["input_ids"].append(eos_token_id)
-                chosen_tokens["attention_mask"].append(1)
-            if len(rejected_tokens["input_ids"]) == 0 or eos_token_id != rejected_tokens["input_ids"][-1]:
-                rejected_tokens["input_ids"].append(eos_token_id)
-                rejected_tokens["attention_mask"].append(1)
+            chosen_tokens, rejected_tokens = add_eos_token_if_needed(
+                self.tokenizer.eos_token_id,
+                chosen_tokens,
+                rejected_tokens
+            )
 
             longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 61dfa0c4ca..decc8ed7af 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -53,6 +53,8 @@
     pad_to_length,
     peft_module_casting_to_bf16,
     trl_sanitze_kwargs_for_tagging,
+    add_bos_token_if_needed,
+    add_eos_token_if_needed
 )
 
 
@@ -857,26 +859,22 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module
             )
 
             # add BOS token to head of prompt. Avoid adding if it's already there
-            bos_token_id = self.tokenizer.bos_token_id
-            if bos_token_id is not None:
-                if prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]:
-                    prompt_tokens["prompt_input_ids"] = [bos_token_id] + prompt_tokens["prompt_input_ids"]
-                    prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
-                if chosen_prompt_len_input_ids == 0 or bos_token_id != chosen_tokens["prompt_input_ids"][0]:
-                    chosen_tokens["prompt_input_ids"] = [bos_token_id] + chosen_tokens["prompt_input_ids"]
-                    chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
-                if rejected_prompt_len_input_ids == 0 or bos_token_id != rejected_tokens["prompt_input_ids"][0]:
-                    rejected_tokens["prompt_input_ids"] = [bos_token_id] + rejected_tokens["prompt_input_ids"]
-                    rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
+            prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
+                self.tokenizer.bos_token_id,
+                prompt_len_input_ids,
+                prompt_tokens,
+                chosen_prompt_len_input_ids,
+                chosen_tokens,
+                rejected_prompt_len_input_ids,
+                rejected_tokens
+            )
 
             # add EOS token to end of answer. Avoid adding if it's already there
-            eos_token_id = self.tokenizer.eos_token_id
-            if len(chosen_tokens["input_ids"]) == 0 or eos_token_id != chosen_tokens["input_ids"][-1]:
-                chosen_tokens["input_ids"].append(eos_token_id)
-                chosen_tokens["attention_mask"].append(1)
-            if len(rejected_tokens["input_ids"]) == 0 or eos_token_id != rejected_tokens["input_ids"][-1]:
-                rejected_tokens["input_ids"].append(eos_token_id)
-                rejected_tokens["attention_mask"].append(1)
+            chosen_tokens, rejected_tokens = add_eos_token_if_needed(
+                self.tokenizer.eos_token_id,
+                chosen_tokens,
+                rejected_tokens
+            )
 
             longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py
index 58f8eccf8c..ec8d6ce04d 100644
--- a/trl/trainer/orpo_trainer.py
+++ b/trl/trainer/orpo_trainer.py
@@ -45,6 +45,8 @@
     pad_to_length,
     peft_module_casting_to_bf16,
     trl_sanitze_kwargs_for_tagging,
+    add_bos_token_if_needed,
+    add_eos_token_if_needed
 )
 
 
@@ -441,25 +443,22 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module
             )
 
             # add BOS token to head of prompt. Avoid adding if it's already there
-            bos_token_id = self.tokenizer.bos_token_id
-            if prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]:
-                prompt_tokens["prompt_input_ids"] = [bos_token_id] + prompt_tokens["prompt_input_ids"]
-                prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
-            if chosen_prompt_len_input_ids == 0 or bos_token_id != chosen_tokens["prompt_input_ids"][0]:
-                chosen_tokens["prompt_input_ids"] = [bos_token_id] + chosen_tokens["prompt_input_ids"]
-                chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
-            if rejected_prompt_len_input_ids == 0 or bos_token_id != rejected_tokens["prompt_input_ids"][0]:
-                rejected_tokens["prompt_input_ids"] = [bos_token_id] + rejected_tokens["prompt_input_ids"]
-                rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
+            prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
+                self.tokenizer.bos_token_id,
+                prompt_len_input_ids,
+                prompt_tokens,
+                chosen_prompt_len_input_ids,
+                chosen_tokens,
+                rejected_prompt_len_input_ids,
+                rejected_tokens
+            )
 
             # add EOS token to end of answer. Avoid adding if it's already there
-            eos_token_id = self.tokenizer.eos_token_id
-            if len(chosen_tokens["input_ids"]) == 0 or eos_token_id != chosen_tokens["input_ids"][-1]:
-                chosen_tokens["input_ids"].append(eos_token_id)
-                chosen_tokens["attention_mask"].append(1)
-            if len(rejected_tokens["input_ids"]) == 0 or eos_token_id != rejected_tokens["input_ids"][-1]:
-                rejected_tokens["input_ids"].append(eos_token_id)
-                rejected_tokens["attention_mask"].append(1)
+            chosen_tokens, rejected_tokens = add_eos_token_if_needed(
+                self.tokenizer.eos_token_id,
+                chosen_tokens,
+                rejected_tokens
+            )
 
             longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 15035ac0fe..4181f6ba66 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -1163,3 +1163,39 @@ def batch_generation(
         query_responses.append(query_response)
         logitss.append(logits)
     return torch.cat(query_responses, 0), torch.cat(logitss, 0)
+
+
+def add_bos_token_if_needed(
+    bos_token_id: int,
+    prompt_len_input_ids: int,
+    prompt_tokens: Dict[str, List[int]],
+    chosen_prompt_len_input_ids: int,
+    chosen_tokens: Dict[str, List[int]],
+    rejected_prompt_len_input_ids: int,
+    rejected_tokens: Dict[str, List[int]]
+):
+    if bos_token_id is not None:
+        if prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]:
+            prompt_tokens["prompt_input_ids"] = [bos_token_id] + prompt_tokens["prompt_input_ids"]
+            prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
+        if chosen_prompt_len_input_ids == 0 or bos_token_id != chosen_tokens["prompt_input_ids"][0]:
+            chosen_tokens["prompt_input_ids"] = [bos_token_id] + chosen_tokens["prompt_input_ids"]
+            chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
+        if rejected_prompt_len_input_ids == 0 or bos_token_id != rejected_tokens["prompt_input_ids"][0]:
+            rejected_tokens["prompt_input_ids"] = [bos_token_id] + rejected_tokens["prompt_input_ids"]
+            rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
+    return prompt_tokens, chosen_tokens, rejected_tokens
+
+
+def add_eos_token_if_needed(
+    eos_token_id: int,
+    chosen_tokens: Dict[str, List[int]],
+    rejected_tokens: Dict[str, List[int]]
+):
+    if len(chosen_tokens["input_ids"]) == 0 or eos_token_id != chosen_tokens["input_ids"][-1]:
+        chosen_tokens["input_ids"].append(eos_token_id)
+        chosen_tokens["attention_mask"].append(1)
+    if len(rejected_tokens["input_ids"]) == 0 or eos_token_id != rejected_tokens["input_ids"][-1]:
+        rejected_tokens["input_ids"].append(eos_token_id)
+        rejected_tokens["attention_mask"].append(1)
+    return chosen_tokens, rejected_tokens
\ No newline at end of file

From d8c887eae3d5aec2e5da893a34d43c366380519a Mon Sep 17 00:00:00 2001
From: DZ9
Date: Wed, 31 Jul 2024 15:52:15 +0800
Subject: [PATCH 3/4] formatting for fixing bos_token adding bug

---
 trl/trainer/cpo_trainer.py  | 10 ++++------
 trl/trainer/dpo_trainer.py  | 10 ++++------
 trl/trainer/orpo_trainer.py | 10 ++++------
 trl/trainer/utils.py        |  8 +++-----
 4 files changed, 15 insertions(+), 23 deletions(-)

diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py
index b41585e984..33e49b351c 100644
--- a/trl/trainer/cpo_trainer.py
+++ b/trl/trainer/cpo_trainer.py
@@ -37,12 +37,12 @@
 from .cpo_config import CPOConfig
 from .utils import (
     DPODataCollatorWithPadding,
+    add_bos_token_if_needed,
+    add_eos_token_if_needed,
     disable_dropout_in_model,
     pad_to_length,
     peft_module_casting_to_bf16,
     trl_sanitze_kwargs_for_tagging,
-    add_bos_token_if_needed,
-    add_eos_token_if_needed
 )
 
 
@@ -433,14 +433,12 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module
                 chosen_prompt_len_input_ids,
                 chosen_tokens,
                 rejected_prompt_len_input_ids,
-                rejected_tokens
+                rejected_tokens,
             )
 
             # add EOS token to end of answer. Avoid adding if it's already there
             chosen_tokens, rejected_tokens = add_eos_token_if_needed(
-                self.tokenizer.eos_token_id,
-                chosen_tokens,
-                rejected_tokens
+                self.tokenizer.eos_token_id, chosen_tokens, rejected_tokens
             )
 
             longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index decc8ed7af..ebd4e2c49f 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -48,13 +48,13 @@
 from .utils import (
     DPODataCollatorWithPadding,
     RunningMoments,
+    add_bos_token_if_needed,
+    add_eos_token_if_needed,
     cap_exp,
     disable_dropout_in_model,
     pad_to_length,
     peft_module_casting_to_bf16,
     trl_sanitze_kwargs_for_tagging,
-    add_bos_token_if_needed,
-    add_eos_token_if_needed
 )
 
 
@@ -866,14 +866,12 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module
                 chosen_prompt_len_input_ids,
                 chosen_tokens,
                 rejected_prompt_len_input_ids,
-                rejected_tokens
+                rejected_tokens,
             )
 
             # add EOS token to end of answer. Avoid adding if it's already there
             chosen_tokens, rejected_tokens = add_eos_token_if_needed(
-                self.tokenizer.eos_token_id,
-                chosen_tokens,
-                rejected_tokens
+                self.tokenizer.eos_token_id, chosen_tokens, rejected_tokens
             )
 
             longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py
index ec8d6ce04d..2eee27e4cc 100644
--- a/trl/trainer/orpo_trainer.py
+++ b/trl/trainer/orpo_trainer.py
@@ -41,12 +41,12 @@
 from .orpo_config import ORPOConfig
 from .utils import (
     DPODataCollatorWithPadding,
+    add_bos_token_if_needed,
+    add_eos_token_if_needed,
     disable_dropout_in_model,
     pad_to_length,
     peft_module_casting_to_bf16,
     trl_sanitze_kwargs_for_tagging,
-    add_bos_token_if_needed,
-    add_eos_token_if_needed
 )
 
 
@@ -450,14 +450,12 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module
                 chosen_prompt_len_input_ids,
                 chosen_tokens,
                 rejected_prompt_len_input_ids,
-                rejected_tokens
+                rejected_tokens,
             )
 
             # add EOS token to end of answer. Avoid adding if it's already there
             chosen_tokens, rejected_tokens = add_eos_token_if_needed(
-                self.tokenizer.eos_token_id,
-                chosen_tokens,
-                rejected_tokens
+                self.tokenizer.eos_token_id, chosen_tokens, rejected_tokens
             )
 
             longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 4181f6ba66..541bcd5926 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -1172,7 +1172,7 @@ def add_bos_token_if_needed(
     chosen_prompt_len_input_ids: int,
     chosen_tokens: Dict[str, List[int]],
     rejected_prompt_len_input_ids: int,
-    rejected_tokens: Dict[str, List[int]]
+    rejected_tokens: Dict[str, List[int]],
 ):
     if bos_token_id is not None:
         if prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]:
@@ -1188,9 +1188,7 @@ def add_bos_token_if_needed(
 
 
 def add_eos_token_if_needed(
-    eos_token_id: int,
-    chosen_tokens: Dict[str, List[int]],
-    rejected_tokens: Dict[str, List[int]]
+    eos_token_id: int, chosen_tokens: Dict[str, List[int]], rejected_tokens: Dict[str, List[int]]
 ):
     if len(chosen_tokens["input_ids"]) == 0 or eos_token_id != chosen_tokens["input_ids"][-1]:
         chosen_tokens["input_ids"].append(eos_token_id)
@@ -1198,4 +1196,4 @@ def add_eos_token_if_needed(
     if len(rejected_tokens["input_ids"]) == 0 or eos_token_id != rejected_tokens["input_ids"][-1]:
         rejected_tokens["input_ids"].append(eos_token_id)
         rejected_tokens["attention_mask"].append(1)
-    return chosen_tokens, rejected_tokens
\ No newline at end of file
+    return chosen_tokens, rejected_tokens

From 931e2b572fa3f5f8784f53e8aeeea1c0098e8b14 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Wed, 31 Jul 2024 10:19:10 +0200
Subject: [PATCH 4/4] Update trl/trainer/utils.py

---
 trl/trainer/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 541bcd5926..3eaddf0ad1 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -1166,7 +1166,7 @@ def batch_generation(
 
 
 def add_bos_token_if_needed(
-    bos_token_id: int,
+    bos_token_id: Optional[int],
     prompt_len_input_ids: int,
     prompt_tokens: Dict[str, List[int]],
     chosen_prompt_len_input_ids: int,
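---

Reviewer note (not part of the patch series): below is a minimal sanity check for the new helpers, assuming a TRL checkout with these patches applied. The token ids (1, 2, 10, 11, ...) are made up for illustration. The point of the series is that tokenizers whose config has no BOS token (bos_token_id is None) previously ended up with `[None] + input_ids`; with the guard, the helper simply leaves the prompt untouched.

    from trl.trainer.utils import add_bos_token_if_needed, add_eos_token_if_needed

    prompt = {"prompt_input_ids": [10, 11], "prompt_attention_mask": [1, 1]}
    chosen = {"prompt_input_ids": [10, 11], "prompt_attention_mask": [1, 1]}
    rejected = {"prompt_input_ids": [10, 11], "prompt_attention_mask": [1, 1]}

    # No BOS token in the tokenizer config: the helper is now a no-op instead
    # of prepending None to the id lists.
    p, c, r = add_bos_token_if_needed(None, 2, prompt, 2, chosen, 2, rejected)
    assert p["prompt_input_ids"] == [10, 11]

    # A real BOS id (1 here) is prepended once, and the attention mask is kept
    # in sync with the lengthened id list.
    p, c, r = add_bos_token_if_needed(1, 2, prompt, 2, chosen, 2, rejected)
    assert p["prompt_input_ids"] == [1, 10, 11]
    assert p["prompt_attention_mask"] == [1, 1, 1]

    # EOS (id 2 here) is appended to each answer only when not already present.
    chosen_ans = {"input_ids": [5, 6], "attention_mask": [1, 1]}
    rejected_ans = {"input_ids": [7, 2], "attention_mask": [1, 1]}
    chosen_ans, rejected_ans = add_eos_token_if_needed(2, chosen_ans, rejected_ans)
    assert chosen_ans["input_ids"] == [5, 6, 2]
    assert rejected_ans["input_ids"] == [7, 2]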