Commit 6d68516

Merge pull request #86 from 01-ai/jiangcheng_dev
fix sft loss problem
ZhaoFancy authored Nov 9, 2023
2 parents 2320d7d + 983b794 commit 6d68516
Showing 2 changed files with 17 additions and 11 deletions.
21 changes: 11 additions & 10 deletions finetune/utils/data/data_utils.py
@@ -130,7 +130,7 @@ def __getitem__(self, idx):
             return {
                 "input_ids": self.chosen_dataset[idx]["input_ids"],
                 "attention_mask": self.chosen_dataset[idx]["attention_mask"],
-                "labels": self.chosen_dataset[idx]["input_ids"],
+                "labels": self.chosen_dataset[idx]["labels"]
             }


@@ -148,10 +148,9 @@ def create_dataset_split(
     if train_phase == SFT:
         for i, tmp_data in enumerate(current_dataset):
             # tokenize the text
-            chosen_sentence = raw_dataset.get_prompt_and_chosen(
-                tmp_data
-            ) # the accept response
-            if chosen_sentence is not None:
+            chosen_sentence = raw_dataset.get_prompt_and_chosen(tmp_data) # the accept response
+            prompt_sentence = raw_dataset.get_prompt(tmp_data)
+            if chosen_sentence is not None and prompt_sentence is not None:
                 chosen_sentence += end_of_conversation_token
                 chosen_token = tokenizer(
                     chosen_sentence,
@@ -161,9 +160,11 @@
                     return_tensors="pt",
                 )
                 chosen_token["input_ids"] = chosen_token["input_ids"].squeeze(0)
-                chosen_token["attention_mask"] = chosen_token["attention_mask"].squeeze(
-                    0
-                )
+                chosen_token["attention_mask"] = chosen_token["attention_mask"].squeeze(0)
+                prompt_token = tokenizer(prompt_sentence, add_special_tokens=False)
+                prompt_token_len = min(max_seq_len, len(prompt_token["input_ids"]))
+                chosen_token["labels"] = chosen_token["input_ids"].clone()
+                chosen_token["labels"][:prompt_token_len] = -100
                 chosen_dataset.append(chosen_token)
 
     return PromptDataset(
@@ -452,7 +453,7 @@ def __init__(self, max_size, small_batch_size):
         self.max_size = max_size
         self.small_batch_size = small_batch_size
 
-    def separate(self):
+    def seperate(self):
         small_dataset = []
         for large_batch in self.dataset:
             if type(large_batch) == list or type(large_batch) == tuple:
@@ -483,7 +484,7 @@ def add(self, data):
         if len(self.dataset) < self.max_size:
             self.dataset.append(data)
             if len(self.dataset) == self.max_size:
-                return self.separate()
+                return self.seperate()
             else:
                 return None
         else:
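The substance of the fix is in the hunks above: previously labels was just a copy of input_ids, so the supervised fine-tuning loss was also computed over the prompt tokens; now the prompt positions of labels are set to -100, the value PyTorch's cross-entropy loss ignores by default, so only the response tokens contribute. A minimal, self-contained sketch of that effect (the tensors and prompt length below are invented for illustration, not repository code):

import torch
import torch.nn.functional as F

# Toy shapes: batch of 1, sequence of 6 tokens, tiny vocabulary.
vocab_size = 8
input_ids = torch.randint(0, vocab_size, (1, 6))
logits = torch.randn(1, 6, vocab_size)  # stand-in for model output

prompt_len = 3                 # hypothetical number of prompt tokens
labels = input_ids.clone()
labels[:, :prompt_len] = -100  # mask the prompt, as the patch does

# Standard causal-LM loss: each position predicts the next token.
shift_logits = logits[:, :-1, :].reshape(-1, vocab_size)
shift_labels = labels[:, 1:].reshape(-1)
loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)
print(loss)  # averaged only over the unmasked (response) positions
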
7 changes: 6 additions & 1 deletion finetune/utils/data/raw_datasets.py
@@ -134,7 +134,7 @@ def get_prompt_and_rejected(self, sample):
 class YiDataset(PromptRawDataset):
     def __init__(self, output_path, seed, local_rank, dataset_name, chat_path):
         super().__init__(output_path, seed, local_rank, dataset_name)
-        print("chat path is {}".format(chat_path))
+        print("data path is {}".format(chat_path))
         self.dataset_name = "yi"
         self.dataset_name_clean = "yi"
         self.raw_datasets = load_dataset(
@@ -154,6 +154,11 @@ def get_eval_data(self):
         if self.raw_datasets["eval"] is not None:
             return self.raw_datasets["eval"]
         return None
 
+    def get_prompt(self, sample):
+        if sample["prompt"] is not None:
+            return " " + sample["prompt"]
+        return None
+
     def get_prompt_and_chosen(self, sample):
         if sample["prompt"] is not None and sample["chosen"] is not None:
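The new YiDataset.get_prompt returns the prompt-only text that create_dataset_split (in the data_utils.py diff above) tokenizes to decide how many leading label positions to mask. A rough sketch of that flow with a toy whitespace tokenizer and an invented sample; the real code uses the model's Hugging Face tokenizer, and since get_prompt_and_chosen is truncated above, the prompt-plus-response concatenation shown here is an assumption:

def toy_tokenizer(text, add_special_tokens=False):
    # Mimics the {"input_ids": [...]} interface used in the diff.
    return {"input_ids": text.split()}

sample = {"prompt": "Human: What is the capital of France? Assistant:",
          "chosen": " Paris."}

prompt_sentence = " " + sample["prompt"]              # YiDataset.get_prompt
chosen_sentence = prompt_sentence + sample["chosen"]  # assumed composition; get_prompt_and_chosen
                                                      # is cut off in the diff above

prompt_token_len = len(toy_tokenizer(prompt_sentence)["input_ids"])
# The first prompt_token_len label positions are set to -100, so the SFT loss
# is computed only on the response tokens (" Paris." here).
print(prompt_token_len)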
