Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix dpo_trainer bug for LLMs without bos_token in config #1885

Merged
merged 4 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 16 additions & 17 deletions trl/trainer/cpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
pad_to_length,
peft_module_casting_to_bf16,
trl_sanitze_kwargs_for_tagging,
add_bos_token_if_needed,
add_eos_token_if_needed
)


Expand Down Expand Up @@ -424,25 +426,22 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module
)

# add BOS token to head of prompt. Avoid adding if it's already there
bos_token_id = self.tokenizer.bos_token_id
if prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]:
prompt_tokens["prompt_input_ids"] = [bos_token_id] + prompt_tokens["prompt_input_ids"]
prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
if chosen_prompt_len_input_ids == 0 or bos_token_id != chosen_tokens["prompt_input_ids"][0]:
chosen_tokens["prompt_input_ids"] = [bos_token_id] + chosen_tokens["prompt_input_ids"]
chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
if rejected_prompt_len_input_ids == 0 or bos_token_id != rejected_tokens["prompt_input_ids"][0]:
rejected_tokens["prompt_input_ids"] = [bos_token_id] + rejected_tokens["prompt_input_ids"]
rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
self.tokenizer.bos_token_id,
prompt_len_input_ids,
prompt_tokens,
chosen_prompt_len_input_ids,
chosen_tokens,
rejected_prompt_len_input_ids,
rejected_tokens
)

# add EOS token to end of answer. Avoid adding if it's already there
eos_token_id = self.tokenizer.eos_token_id
if len(chosen_tokens["input_ids"]) == 0 or eos_token_id != chosen_tokens["input_ids"][-1]:
chosen_tokens["input_ids"].append(eos_token_id)
chosen_tokens["attention_mask"].append(1)
if len(rejected_tokens["input_ids"]) == 0 or eos_token_id != rejected_tokens["input_ids"][-1]:
rejected_tokens["input_ids"].append(eos_token_id)
rejected_tokens["attention_mask"].append(1)
chosen_tokens, rejected_tokens = add_eos_token_if_needed(
self.tokenizer.eos_token_id,
chosen_tokens,
rejected_tokens
)

longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

Expand Down
33 changes: 16 additions & 17 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
pad_to_length,
peft_module_casting_to_bf16,
trl_sanitze_kwargs_for_tagging,
add_bos_token_if_needed,
add_eos_token_if_needed
)


Expand Down Expand Up @@ -857,25 +859,22 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module
)

# add BOS token to head of prompt. Avoid adding if it's already there
bos_token_id = self.tokenizer.bos_token_id
if prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]:
prompt_tokens["prompt_input_ids"] = [bos_token_id] + prompt_tokens["prompt_input_ids"]
prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
if chosen_prompt_len_input_ids == 0 or bos_token_id != chosen_tokens["prompt_input_ids"][0]:
chosen_tokens["prompt_input_ids"] = [bos_token_id] + chosen_tokens["prompt_input_ids"]
chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
if rejected_prompt_len_input_ids == 0 or bos_token_id != rejected_tokens["prompt_input_ids"][0]:
rejected_tokens["prompt_input_ids"] = [bos_token_id] + rejected_tokens["prompt_input_ids"]
rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
self.tokenizer.bos_token_id,
prompt_len_input_ids,
prompt_tokens,
chosen_prompt_len_input_ids,
chosen_tokens,
rejected_prompt_len_input_ids,
rejected_tokens
)

# add EOS token to end of answer. Avoid adding if it's already there
eos_token_id = self.tokenizer.eos_token_id
if len(chosen_tokens["input_ids"]) == 0 or eos_token_id != chosen_tokens["input_ids"][-1]:
chosen_tokens["input_ids"].append(eos_token_id)
chosen_tokens["attention_mask"].append(1)
if len(rejected_tokens["input_ids"]) == 0 or eos_token_id != rejected_tokens["input_ids"][-1]:
rejected_tokens["input_ids"].append(eos_token_id)
rejected_tokens["attention_mask"].append(1)
chosen_tokens, rejected_tokens = add_eos_token_if_needed(
self.tokenizer.eos_token_id,
chosen_tokens,
rejected_tokens
)

longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

Expand Down
33 changes: 16 additions & 17 deletions trl/trainer/orpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
pad_to_length,
peft_module_casting_to_bf16,
trl_sanitze_kwargs_for_tagging,
add_bos_token_if_needed,
add_eos_token_if_needed
)


Expand Down Expand Up @@ -441,25 +443,22 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module
)

# add BOS token to head of prompt. Avoid adding if it's already there
bos_token_id = self.tokenizer.bos_token_id
if prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]:
prompt_tokens["prompt_input_ids"] = [bos_token_id] + prompt_tokens["prompt_input_ids"]
prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
if chosen_prompt_len_input_ids == 0 or bos_token_id != chosen_tokens["prompt_input_ids"][0]:
chosen_tokens["prompt_input_ids"] = [bos_token_id] + chosen_tokens["prompt_input_ids"]
chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
if rejected_prompt_len_input_ids == 0 or bos_token_id != rejected_tokens["prompt_input_ids"][0]:
rejected_tokens["prompt_input_ids"] = [bos_token_id] + rejected_tokens["prompt_input_ids"]
rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
self.tokenizer.bos_token_id,
prompt_len_input_ids,
prompt_tokens,
chosen_prompt_len_input_ids,
chosen_tokens,
rejected_prompt_len_input_ids,
rejected_tokens
)

# add EOS token to end of answer. Avoid adding if it's already there
eos_token_id = self.tokenizer.eos_token_id
if len(chosen_tokens["input_ids"]) == 0 or eos_token_id != chosen_tokens["input_ids"][-1]:
chosen_tokens["input_ids"].append(eos_token_id)
chosen_tokens["attention_mask"].append(1)
if len(rejected_tokens["input_ids"]) == 0 or eos_token_id != rejected_tokens["input_ids"][-1]:
rejected_tokens["input_ids"].append(eos_token_id)
rejected_tokens["attention_mask"].append(1)
chosen_tokens, rejected_tokens = add_eos_token_if_needed(
self.tokenizer.eos_token_id,
chosen_tokens,
rejected_tokens
)

longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

Expand Down
36 changes: 36 additions & 0 deletions trl/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1163,3 +1163,39 @@ def batch_generation(
query_responses.append(query_response)
logitss.append(logits)
return torch.cat(query_responses, 0), torch.cat(logitss, 0)


def add_bos_token_if_needed(
bos_token_id: int,
kashif marked this conversation as resolved.
Show resolved Hide resolved
prompt_len_input_ids: int,
prompt_tokens: Dict[str, List[int]],
chosen_prompt_len_input_ids: int,
chosen_tokens: Dict[str, List[int]],
rejected_prompt_len_input_ids: int,
rejected_tokens: Dict[str, List[int]]
):
if bos_token_id is not None:
if prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]:
prompt_tokens["prompt_input_ids"] = [bos_token_id] + prompt_tokens["prompt_input_ids"]
prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
if chosen_prompt_len_input_ids == 0 or bos_token_id != chosen_tokens["prompt_input_ids"][0]:
chosen_tokens["prompt_input_ids"] = [bos_token_id] + chosen_tokens["prompt_input_ids"]
chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
if rejected_prompt_len_input_ids == 0 or bos_token_id != rejected_tokens["prompt_input_ids"][0]:
rejected_tokens["prompt_input_ids"] = [bos_token_id] + rejected_tokens["prompt_input_ids"]
rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
return prompt_tokens, chosen_tokens, rejected_tokens


def add_eos_token_if_needed(
eos_token_id: int,
chosen_tokens: Dict[str, List[int]],
rejected_tokens: Dict[str, List[int]]
):
if len(chosen_tokens["input_ids"]) == 0 or eos_token_id != chosen_tokens["input_ids"][-1]:
chosen_tokens["input_ids"].append(eos_token_id)
chosen_tokens["attention_mask"].append(1)
if len(rejected_tokens["input_ids"]) == 0 or eos_token_id != rejected_tokens["input_ids"][-1]:
rejected_tokens["input_ids"].append(eos_token_id)
rejected_tokens["attention_mask"].append(1)
return chosen_tokens, rejected_tokens
Loading