Skip to content

Commit

Permalink
Add DPO support in finetuning microservice (#857)
Browse files Browse the repository at this point in the history
* added dpo support.

Signed-off-by: Ye, Xinyu <xinyu.ye@intel.com>

* make dpo trainer compatible with newest transformers.

Signed-off-by: Ye, Xinyu <xinyu.ye@intel.com>

* added ut for dpo.

Signed-off-by: Ye, Xinyu <xinyu.ye@intel.com>

* added training successfulness check in finetuning ut.

Signed-off-by: Ye, Xinyu <xinyu.ye@intel.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* updated broken link.

Signed-off-by: Ye, Xinyu <xinyu.ye@intel.com>

---------

Signed-off-by: Ye, Xinyu <xinyu.ye@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: ZePan110 <ze.pan@intel.com>
  • Loading branch information
3 people authored Nov 12, 2024
1 parent 9a50131 commit 37f3514
Show file tree
Hide file tree
Showing 6 changed files with 656 additions and 37 deletions.
27 changes: 26 additions & 1 deletion comps/finetuning/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ Download a training file, such as `alpaca_data.json` for instruction tuning and
curl http://${your_ip}:8015/v1/files -X POST -H "Content-Type: multipart/form-data" -F "file=@./alpaca_data.json" -F purpose="fine-tune"
```

For reranking and embedding models finetuning, the training file [toy_finetune_data.jsonl](https://github.com/FlagOpen/FlagEmbedding/blob/master/examples/finetune/toy_finetune_data.jsonl) is an toy example.
For reranking and embedding models finetuning, the training file [toy_finetune_data.jsonl](https://github.com/FlagOpen/FlagEmbedding/blob/1.1/examples/finetune/toy_finetune_data.jsonl) is an toy example.

### 3.2 Create fine-tuning job

Expand Down Expand Up @@ -199,6 +199,31 @@ Below is an example for the format of the pretraining dataset:
{"text": "A boy with a blue tank top sitting watching three dogs."}
```

#### 3.2.5 Direct Preference Optimization (DPO)

Use the following command to launch a job for LLM Direct Preference Optimization, such as `meta-llama/Llama-2-7b-hf`:

```bash
# create a finetuning job
curl http://${your_ip}:8015/v1/fine_tuning/jobs \
-X POST \
-H "Content-Type: application/json" \
-d '{
"training_file": "test_data.jsonl",
"model": "meta-llama/Llama-2-7b-hf",
"General":{
"task":"dpo"
}
}'
```

Below is an example for the format of the DPO dataset:

```json
{"system": "", "question": "You will be given a definition of a task first, then some input of the task.\nThis task is about using the specified sentence and converting the sentence to Resource Description Framework (RDF) triplets of the form (subject, predicate object). The RDF triplets generated must be such that the triplets accurately capture the structure and semantics of the input sentence. The input is a sentence and the output is a list of triplets of the form [subject, predicate, object] that capture the relationships present in the sentence. When a sentence has more than 1 RDF triplet possible, the output must contain all of them.\n\nAFC Ajax (amateurs)'s ground is Sportpark De Toekomst where Ajax Youth Academy also play.\nOutput:", "chosen": "[\n [\"AFC Ajax (amateurs)\", \"has ground\", \"Sportpark De Toekomst\"],\n [\"Ajax Youth Academy\", \"plays at\", \"Sportpark De Toekomst\"]\n]", "rejected": " Sure, I'd be happy to help! Here are the RDF triplets for the input sentence:\n\n[AFC Ajax (amateurs), hasGround, Sportpark De Toekomst]\n[Ajax Youth Academy, playsAt, Sportpark De Toekomst]\n\nExplanation:\n\n* AFC Ajax (amateurs) is the subject of the first triplet, and hasGround is the predicate that describes the relationship between AFC Ajax (amateurs) and Sportpark De Toekomst.\n* Ajax Youth Academy is the subject of the second triplet, and playsAt is the predicate that describes the relationship between Ajax Youth Academy and Sportpark De Toekomst.\n\nNote that there may be other possible RDF triplets that could be derived from the input sentence, but the above triplets capture the main relationships present in the sentence."}
{"system": "You are an AI assistant. You will be given a task. You must generate a detailed and long answer.", "question": "Generate an approximately fifteen-word sentence that describes all this data: Midsummer House eatType restaurant; Midsummer House food Chinese; Midsummer House priceRange moderate; Midsummer House customer rating 3 out of 5; Midsummer House near All Bar One", "chosen": "Midsummer House is a moderately priced Chinese restaurant with a 3/5 customer rating, located near All Bar One.", "rejected": " Sure! Here's a sentence that describes all the data you provided:\n\n\"Midsummer House is a moderately priced Chinese restaurant with a customer rating of 3 out of 5, located near All Bar One, offering a variety of delicious dishes.\""}
```

### 3.3 Manage fine-tuning job

Below commands show how to list finetuning jobs, retrieve a finetuning job, cancel a finetuning job and list checkpoints of a finetuning job.
Expand Down
5 changes: 4 additions & 1 deletion comps/finetuning/finetune_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def check_report_to(cls, v: str):

@validator("task")
def check_task(cls, v: str):
assert v in ["instruction_tuning", "pretraining", "rerank", "embedding"]
assert v in ["instruction_tuning", "pretraining", "dpo", "rerank", "embedding"]
return v


Expand All @@ -71,11 +71,13 @@ class DatasetConfig(BaseModel):
block_size: int = 512
shuffle: bool = False
max_source_length: int = 384
max_prompt_length: int = 512
padding_side: str = "right"
truncation_side: str = "right"
max_seq_length: int = 512
truncation: bool = True
padding: Union[bool, str] = True
pad_to_max: bool = False
mask_input: bool = True
mask_response: bool = True
data_preprocess_type: str = "neural_chat"
Expand Down Expand Up @@ -132,6 +134,7 @@ class TrainingConfig(BaseModel):
logging_steps: int = 10
deepspeed_config_file: str = ""
embedding_training_config: Optional[EmbeddingTrainingConfig] = EmbeddingTrainingConfig()
dpo_beta: float = Field(default=0.1, description="the beta parameter for DPO loss")

@validator("device")
def check_device(cls, v: str):
Expand Down
121 changes: 121 additions & 0 deletions comps/finetuning/llm_on_ray/finetune/data_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,102 @@ def tokenize(self, examples):
return examples


class DPODataProcessor:
def __init__(self, config, tokenizer):
self.tokenizer = tokenizer
self.max_length = config["Dataset"].get("max_length", 1024)
self.max_prompt_length = config["Dataset"].get("max_prompt_length", 512)
self.pad_to_max = config["Dataset"].get("pad_to_max", False)

def tokenize(self, examples):
prompts = {(system + question).strip() for system, question in zip(examples["system"], examples["question"])}
chosens = {c.strip() for c in examples["chosen"]}
rejects = {r.strip() for r in examples["rejected"]}

examples = {
"prompt": [],
"chosen": [],
"rejected": [],
"chosen_response_only": [],
"rejected_response_only": [],
"chosen_input_ids": [],
"chosen_attention_mask": [],
"chosen_labels": [],
"rejected_input_ids": [],
"rejected_attention_mask": [],
"rejected_labels": [],
"prompt_input_ids": [],
"prompt_attention_mask": [],
}

for prompt, chosen, reject in zip(prompts, chosens, rejects):

prompt_tokens = self.tokenizer.tokenize(prompt)

if len(prompt_tokens) > self.max_prompt_length:
prompt_tokens = prompt_tokens[: self.max_prompt_length]

prompt_ids = self.tokenizer.convert_tokens_to_ids(prompt_tokens)
prompt_mask = [1] * len(prompt_ids)

max_resp = self.max_length - len(prompt_ids)
chosen_tokens = self.tokenizer.tokenize(chosen)
chosen_tokens = chosen_tokens[: max_resp - 1]
chosen_tokens.append(self.tokenizer.eos_token)
chosen_ids = self.tokenizer.convert_tokens_to_ids(chosen_tokens)
chosen_mask = [1] * len(chosen_ids)

reject_tokens = self.tokenizer.tokenize(reject)
reject_tokens = reject_tokens[: max_resp - 1]
reject_tokens.append(self.tokenizer.eos_token)
reject_ids = self.tokenizer.convert_tokens_to_ids(reject_tokens)
reject_mask = [1] * len(reject_ids)

chosen_input_ids = prompt_ids + chosen_ids
chosen_attention_mask = prompt_mask + chosen_mask
chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids

reject_input_ids = prompt_ids + reject_ids
reject_attention_mask = prompt_mask + reject_mask
reject_labels = [IGNORE_INDEX] * len(prompt_ids) + reject_ids

# padding
input_len = len(chosen_input_ids)
if self.pad_to_max:
pad_len = self.max_length - input_len
chosen_input_ids = chosen_input_ids + [0] * pad_len
chosen_labels = chosen_labels + [-100] * pad_len
chosen_attention_mask = chosen_attention_mask + [0] * pad_len
assert len(chosen_input_ids) == self.max_length

input_len = len(reject_input_ids)
if self.pad_to_max:
pad_len = self.max_length - input_len
reject_input_ids = reject_input_ids + [0] * pad_len
reject_labels = reject_labels + [-100] * pad_len
reject_attention_mask = reject_attention_mask + [0] * pad_len
assert len(reject_input_ids) == self.max_length

examples["prompt"].append(prompt)
examples["chosen"].append(prompt + chosen)
examples["rejected"].append(prompt + reject)
examples["chosen_response_only"].append(chosen)
examples["rejected_response_only"].append(reject)

examples["chosen_input_ids"].append(chosen_input_ids)
examples["chosen_attention_mask"].append(chosen_attention_mask)
examples["chosen_labels"].append(chosen_labels)

examples["rejected_input_ids"].append(reject_input_ids)
examples["rejected_attention_mask"].append(reject_attention_mask)
examples["rejected_labels"].append(reject_labels)

examples["prompt_input_ids"].append(prompt_ids)
examples["prompt_attention_mask"].append(prompt_mask)

return examples


class TrainDatasetForCE(Dataset):
def __init__(self, dataset, args, tokenizer):
self.dataset = dataset
Expand Down Expand Up @@ -350,3 +446,28 @@ def __call__(self, features):
return_tensors="pt",
)
return {"query": q_collated, "passage": d_collated}


@dataclass
class DPOCollator(DataCollatorWithPadding):
def __call__(self, features) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
input_ids = [torch.tensor(ins["chosen_input_ids"]) for ins in features] + [
torch.tensor(ins["rejected_input_ids"]) for ins in features
]
labels = [torch.tensor(ins["chosen_labels"]) for ins in features] + [
torch.tensor(ins["rejected_labels"]) for ins in features
]
attention_mask = [torch.tensor(ins["chosen_attention_mask"]) for ins in features] + [
torch.tensor(ins["rejected_attention_mask"]) for ins in features
]

input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.eos_token_id
)
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
attention_mask = torch.nn.utils.rnn.pad_sequence(attention_mask, batch_first=True, padding_value=0)
return dict(
input_ids=input_ids,
labels=labels,
attention_mask=attention_mask,
)
Loading

0 comments on commit 37f3514

Please sign in to comment.