From 9694e551ea46b4d137945dfbea8bfdc0761156ca Mon Sep 17 00:00:00 2001
From: lkk <33276950+lkk12014402@users.noreply.github.com>
Date: Fri, 28 Jul 2023 15:12:51 +0800
Subject: [PATCH] support more chatbot finetuning scenarios. (#1215)

---
 workflows/chatbot/fine_tuning/README.md       |  81 +++++-
 .../instruction_tuning_pipeline/data_utils.py | 262 ++++++++++++++++++
 .../finetune_clm.py                           | 115 ++------
 3 files changed, 362 insertions(+), 96 deletions(-)
 create mode 100644 workflows/chatbot/fine_tuning/instruction_tuning_pipeline/data_utils.py

diff --git a/workflows/chatbot/fine_tuning/README.md b/workflows/chatbot/fine_tuning/README.md
index b758ca5a1ab..a9062aa1f7d 100644
--- a/workflows/chatbot/fine_tuning/README.md
+++ b/workflows/chatbot/fine_tuning/README.md
@@ -3,6 +3,13 @@ NeuralChat Fine-tuning
 
 This example demonstrates how to finetune the pretrained large language model (LLM) with the instruction-following dataset for creating the NeuralChat, a chatbot that can conduct the textual conversation. Giving NeuralChat the textual instruction, it will respond with the textual response. This example have been validated on the 4th Gen Intel® Xeon® Processors, Sapphire Rapids.
 
+## Validated Model List
+|Pretrained model| Text Generation (Instruction) | Text Generation (ChatBot) | Summarization Tuning
+|------------------------------------|---|---|---
+|LLaMA series| ✅| ✅| ✅
+|MPT series|✅ |✅ |✅
+|FLAN-T5 series| ✅ | NA | NA
+
 # Prerequisite​
 
 ## 1. Environment​
@@ -26,11 +33,15 @@ It should be noticed that the early version of LLama model's name in Transformer
 The user can obtain the [release model](https://huggingface.co/google/flan-t5-xl) from Huggingface.
 
 ## 3. Prepare Dataset
-The instruction-following dataset is needed for the finetuning. We select two kinds of Datasets to conduct the finetuning process: general domain dataset and domain specific dataset.
+We select four kinds of datasets to conduct the finetuning process for different tasks.
+
+1. Text Generation (General domain instruction): We use the [Alpaca dataset](https://github.com/tatsu-lab/stanford_alpaca) from Stanford University as the general domain dataset to fine-tune the model. This dataset is provided in the form of a JSON file, [alpaca_data.json](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json). In Alpaca, researchers have manually crafted 175 seed tasks to guide `text-davinci-003` in generating 52K instruction data for diverse tasks.
 
-1. General domain dataset: We use the [Alpaca dataset](https://github.com/tatsu-lab/stanford_alpaca) from Stanford University as the general domain dataset to fine-tune the model. This dataset is provided in the form of a JSON file, [alpaca_data.json](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json). In Alpaca, researchers have manually crafted 175 seed tasks to guide `text-davinci-003` in generating 52K instruction data for diverse tasks.
+2. Text Generation (Domain-specific instruction): Inspired by Alpaca, we constructed a domain-specific dataset focusing on Business and Intel-related issues. We made minor modifications to the [prompt template](https://github.com/tatsu-lab/stanford_alpaca/blob/main/prompt.txt) to proactively guide Alpaca in generating more Intel and Business related instruction data. The generated data can be found in `intel_domain.json`.
 
-2. Domain-specific dataset: Inspired by Alpaca, we constructed a domain-specific dataset focusing on Business and Intel-related issues.
We made minor modifications to the [prompt template](https://github.com/tatsu-lab/stanford_alpaca/blob/main/prompt.txt) to proactively guide Alpaca in generating more Intel and Business related instruction data. The generated data could be find in `intel_domain.json`.
+3. Text Generation (ChatBot): To finetune a chatbot, we use the chat-style dataset [OpenAssistant/oasst1](https://huggingface.co/datasets/OpenAssistant/oasst1).
+
+4. Summarization: We use [cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail), an English-language dataset containing just over 300k unique news articles written by journalists at CNN and the Daily Mail, for this task.
 
 # Finetune
 
@@ -38,7 +49,7 @@ We employ the [LoRA approach](https://arxiv.org/pdf/2106.09685.pdf) to finetune
 
 ## 1. Single Node Fine-tuning in Xeon SPR
 
-For FLAN-T5, use the below command line for finetuning on the Alpaca dataset.
+**For FLAN-T5**, use the command below to finetune on the Alpaca dataset.
 
 ```bash
 python finetune_seq2seq.py \
@@ -61,7 +72,9 @@ python finetune_seq2seq.py \
         --peft lora
 ```
 
-For LLaMA, use the below command line for finetuning on the Alpaca dataset.
+#### For LLaMA
+
+- Use the command below to finetune on the Alpaca dataset.
 
 ```bash
 python finetune_clm.py \
@@ -86,7 +99,61 @@ python finetune_clm.py \
         --no_cuda \
 ```
 
-For [MPT](https://huggingface.co/mosaicml/mpt-7b), use the below command line for finetuning on the Alpaca dataset. Only LORA supports MPT in PEFT perspective.it uses gpt-neox-20b tokenizer, so you need to define it in command line explicitly.This model also requires that trust_remote_code=True be passed to the from_pretrained method. This is because we use a custom MPT model architecture that is not yet part of the Hugging Face transformers package.
+- Use the command below to finetune a chatbot on [Intel/openassistant-preprocessed](https://huggingface.co/datasets/Intel/openassistant-preprocessed).
+
+```bash
+python finetune_clm.py \
+        --model_name_or_path "decapoda-research/llama-7b-hf" \
+        --bf16 True \
+        --dataset_name "Intel/openassistant-preprocessed" \
+        --per_device_train_batch_size 8 \
+        --per_device_eval_batch_size 8 \
+        --gradient_accumulation_steps 1 \
+        --do_train \
+        --learning_rate 1e-4 \
+        --num_train_epochs 3 \
+        --logging_steps 100 \
+        --save_total_limit 2 \
+        --overwrite_output_dir \
+        --log_level info \
+        --save_strategy epoch \
+        --output_dir ./llama_chatbot_peft_finetuned_model \
+        --peft lora \
+        --use_fast_tokenizer false \
+        --no_cuda \
+        --special_tokens "<|im_start|>" "<|im_end|>"
+
+# The script also supports other models, such as MPT.
+```
+
+- Use the command below for the summarization task on [cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail).
+
+```bash
+python finetune_clm.py \
+        --model_name_or_path "/models/llama-7b-hf" \
+        --bf16 True \
+        --dataset_name "cnn_dailymail" \
+        --dataset_config_name "3.0.0" \
+        --per_device_train_batch_size 8 \
+        --per_device_eval_batch_size 8 \
+        --gradient_accumulation_steps 1 \
+        --do_train \
+        --learning_rate 1e-4 \
+        --num_train_epochs 3 \
+        --logging_steps 100 \
+        --save_total_limit 2 \
+        --overwrite_output_dir \
+        --log_level info \
+        --save_strategy epoch \
+        --output_dir ./llama_peft_finetuned_model \
+        --peft lora \
+        --use_fast_tokenizer false \
+        --no_cuda
+
+# The script also supports other models, such as MPT.
+```
+
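+For reference, `data_utils.py` (added in this patch) turns each `cnn_dailymail` record into a prompt/response pair roughly like the sketch below, where the placeholder text stands in for a real article and its highlights; with `train_on_inputs` at its default of `False`, the loss is computed only on the highlight tokens.
+
+```
+{article text}
+Summarize the highlights of this article.
+{reference highlights}
+```
+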
+**For [MPT](https://huggingface.co/mosaicml/mpt-7b)**, use the command below to finetune on the Alpaca dataset. From the PEFT perspective, only LoRA supports MPT. MPT uses the gpt-neox-20b tokenizer, so you need to specify it explicitly on the command line. This model also requires that `trust_remote_code=True` be passed to the `from_pretrained` method, because we use a custom MPT model architecture that is not yet part of the Hugging Face transformers package.
 
 ```bash
 python finetune_clm.py \
@@ -382,4 +449,4 @@ For finetuning on SPR, add `--bf16` argument will speedup the finetuning process
 
 You could also indicate `--peft` to switch peft method in P-tuning, Prefix tuning, Prompt tuning, LLama Adapter, LoRA, see https://github.com/huggingface/peft. Note for MPT, only LoRA is supported.
 
-Add option **"--use_fast_tokenizer False"** when using latest transformers if you met failure in llama fast tokenizer for llama, The `tokenizer_class` in `tokenizer_config.json` should be changed from `LLaMATokenizer` to `LlamaTokenizer`
\ No newline at end of file
+Add the option **"--use_fast_tokenizer False"** when using the latest transformers if the LLaMA fast tokenizer fails; in that case, the `tokenizer_class` in `tokenizer_config.json` should be changed from `LLaMATokenizer` to `LlamaTokenizer`.
diff --git a/workflows/chatbot/fine_tuning/instruction_tuning_pipeline/data_utils.py b/workflows/chatbot/fine_tuning/instruction_tuning_pipeline/data_utils.py
new file mode 100644
index 00000000000..b88f101976e
--- /dev/null
+++ b/workflows/chatbot/fine_tuning/instruction_tuning_pipeline/data_utils.py
@@ -0,0 +1,262 @@
+import copy
+import datasets
+import re
+from itertools import chain
+
+IGNORE_INDEX = -100
+
+ALPACA_PROMPT_DICT = {
+    "prompt_with_input": (
+        "Below is an instruction that describes a task, paired with an input that provides further context. "
+        "Write a response that appropriately completes the request.\n\n"
+        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
+    ),
+    "prompt_without_input": (
+        "Below is an instruction that describes a task. "
+        "Write a response that appropriately completes the request.\n\n"
+        "### Instruction:\n{instruction}\n\n### Response:"
+    ),
+}
+
+conv_header = """<|im_start|>system
+- You are a helpful assistant chatbot trained by Intel.
+- You answer questions.
+- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
+- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>\n""" + +user = "<|im_start|>user\n" +assistant = "<|im_start|>assistant\n" +end = "<|im_end|>" + +summarization_suffix_template = "\nSummarize the highlights of this article.\n" + +def create_alpaca(examples): + prompts = {} + prompts["source"] = [] + prompts["target"] = [] + for example in examples: + prompt_template = ( + ALPACA_PROMPT_DICT["prompt_with_input"] + if example["input"] != "" + else ALPACA_PROMPT_DICT["prompt_without_input"] + ) + source = prompt_template.format_map(example) + prompts["source"].append(source) + prompts["target"].append(example["output"]) + return prompts + + +def tokenize_alpaca(tokenizer, data_args, finetune_args): + def tokenize(prompt, add_eos_token=True): + results = tokenizer( + prompt, + truncation=True, + max_length=data_args.max_seq_length, + padding=False, + return_tensors=None,) + for i in range(len(results["input_ids"])): + if (results["input_ids"][i][-1] != tokenizer.eos_token_id \ + and len(results["input_ids"][i]) < data_args.max_seq_length \ + and add_eos_token \ + ): + results["input_ids"][i].append(tokenizer.eos_token_id) + results["attention_mask"][i].append(1) + results["labels"] = copy.deepcopy(results["input_ids"]) + results["input_id_len"] = [len(result) for result in results["input_ids"]] + return results + + def preprocess_function(examples): + st = [s + t for s, t in zip(examples["prompt_sources"], examples["prompt_targets"])] + examples_tokenized = tokenize(st) + input_ids = examples_tokenized["input_ids"] + labels = examples_tokenized["labels"] + if not finetune_args.train_on_inputs: + sources_tokenized = tokenize(examples["prompt_sources"], add_eos_token=False) + for label, source_len in zip(labels, sources_tokenized["input_id_len"]): + label[:source_len] = [IGNORE_INDEX] * source_len + return dict( + input_ids=input_ids, + labels=labels, + attention_mask=examples_tokenized["attention_mask"], + ) + + return preprocess_function + + +def create_oasst(examples): + prompts = {} + prompts["prompt_sources"] = [] + prompts["prompt_targets"] = [] + + for conv in examples: + conv = conv["messages"] + prompt = conv_header + + for j in range(0, len(conv) - 1, 2): + u = conv[j]["content"] + ass = conv[j+1]["content"] + prompt = prompt + user + u + end + '\n' + assistant + response = ass + end + prompts["prompt_sources"].append(prompt) + prompts["prompt_targets"].append(response) + + prompt += response + '\n' + return prompts + +def truncate_sequences(sequences, max_length): + words_to_cut = sum(list(map(len, sequences))) - max_length + if words_to_cut <= 0: + return sequences + + while words_to_cut > 0 and len(sequences) > 0: + words_to_cut -= len(sequences[0]) + sequences = sequences[1:] + + return sequences + +def tokenize_oasst(tokenizer, data_args, finetune_args): + + # special tokens + assistant_tokens = tokenizer.tokenize(assistant) + + def preprocess_function(examples): + + instructions = [q.strip() for q in examples["prompt_sources"]] + responses = [q.strip() for q in examples["prompt_targets"]] + + examples["input_ids"] = [] + examples["labels"] = [] + examples["attention_mask"] = [] + + for instruction, response in zip(instructions, responses): + header = re.findall("\<\|im_start\|\>system.*?\<\|im_end\|\>", instruction, re.DOTALL)[0] + convs = re.findall("\<\|im_start\|\>.*?\<\|im_end\|\>", instruction, re.DOTALL)[1:] + + convs_tokens = [ + tokenizer.tokenize(conv) + tokenizer.tokenize("\n") + for conv in 
convs + ] + header_tokens = tokenizer.tokenize(header) + tokenizer.tokenize("\n") + + max_input = data_args.max_source_length - len(header_tokens) - len(assistant_tokens) + + truncated_convs = truncate_sequences(convs_tokens, + max_input) + + if len(truncated_convs) == 0: + truncated_convs = [convs_tokens[-1][:max_input - 1] + convs_tokens[-1][-1:]] + + prompt_tokens = [header_tokens] + truncated_convs + [assistant_tokens] + prompt_ids = [tokenizer.convert_tokens_to_ids(prompt_token) for prompt_token in prompt_tokens] + prompt_ids = list(chain(*prompt_ids)) + + resp_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(response.strip())) + # keep last and eos_id + max_resp = data_args.max_seq_length - len(prompt_ids) - 1 + if len(resp_ids) > max_resp: + resp_ids = resp_ids[:max_resp - 1] + resp_ids[-1:] + + input_ids = prompt_ids + resp_ids + [tokenizer.eos_token_id] + if not finetune_args.train_on_inputs: + labels = [-100] * len(prompt_ids) + resp_ids + [tokenizer.eos_token_id] + else: + labels = prompt_ids + resp_ids + [tokenizer.eos_token_id] + + # padding + input_len = len(input_ids) + pad_len = data_args.max_seq_length - input_len + input_ids = input_ids + [tokenizer.eos_token_id] * pad_len + labels = labels + [-100] * pad_len + attention_mask = [1] * input_len + [0] * pad_len + + assert len(input_ids) == data_args.max_seq_length + assert len(prompt_ids) <= data_args.max_source_length + assert len(labels) == len(input_ids) == len(attention_mask) + + examples["input_ids"].append(input_ids) + examples["labels"].append(labels) + examples["attention_mask"].append(attention_mask) + + return examples + + return preprocess_function + +def tokenize_cnn(tokenizer, data_args, finetune_args): + template_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summarization_suffix_template)) + + def preprocess_function(examples): + + articles = [q.strip() for q in examples["article"]] + highlights = [q.strip() for q in examples["highlights"]] + + examples["input_ids"] = [] + examples["labels"] = [] + examples["attention_mask"] = [] + + for article, highlight in zip(articles, highlights): + max_input = data_args.max_source_length - len(template_ids) + + article_tokens = tokenizer.tokenize(article)[:max_input] + prompt_ids = tokenizer.convert_tokens_to_ids(article_tokens) + template_ids + + max_resp = data_args.max_seq_length - len(prompt_ids) - 1 + resp_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(highlight))[:max_resp] + + input_ids = prompt_ids + resp_ids + [tokenizer.eos_token_id] + if not finetune_args.train_on_inputs: + labels = [-100] * len(prompt_ids) + resp_ids + [tokenizer.eos_token_id] + else: + labels = prompt_ids + resp_ids + [tokenizer.eos_token_id] + + # padding + input_len = len(input_ids) + pad_len = data_args.max_seq_length - input_len + input_ids = input_ids + [tokenizer.eos_token_id] * pad_len + labels = labels + [-100] * pad_len + attention_mask = [1] * input_len + [0] * pad_len + + assert len(input_ids) == data_args.max_seq_length + assert len(prompt_ids) <= data_args.max_source_length + assert len(labels) == len(input_ids) == len(attention_mask) + + examples["input_ids"].append(input_ids) + examples["labels"].append(labels) + examples["attention_mask"].append(attention_mask) + + return examples + + return preprocess_function + + +def preprocess_dataset(raw_datasets, tokenizer, data_args, finetune_args): + + dataset_name = data_args.dataset_name if data_args.dataset_name is not None else data_args.train_file + if "oasst" in dataset_name: + new_datasets = 
datasets.DatasetDict() + for key in ["train"]: + prompts = create_oasst(raw_datasets[key]) + new_datasets[key] = datasets.Dataset.from_dict(prompts) + + preprocess_fn = tokenize_oasst(tokenizer, data_args, finetune_args) + + return new_datasets, preprocess_fn + + elif "cnn" in dataset_name: + preprocess_fn = tokenize_cnn(tokenizer, data_args, finetune_args) + return raw_datasets, preprocess_fn + else: + # default use alpaca instruction template + for key in raw_datasets: + prompts = create_alpaca(raw_datasets[key]) + columns_to_be_removed = list(raw_datasets[key].features.keys()) + raw_datasets[key] = raw_datasets[key].add_column( + "prompt_sources", prompts["source"] + ) + raw_datasets[key] = raw_datasets[key].add_column( + "prompt_targets", prompts["target"] + ) + raw_datasets[key] = raw_datasets[key].remove_columns(columns_to_be_removed) + + preprocess_fn = tokenize_alpaca(tokenizer, data_args, finetune_args) + + return raw_datasets, preprocess_fn diff --git a/workflows/chatbot/fine_tuning/instruction_tuning_pipeline/finetune_clm.py b/workflows/chatbot/fine_tuning/instruction_tuning_pipeline/finetune_clm.py index 7fb3529b0a1..d1b3e953140 100644 --- a/workflows/chatbot/fine_tuning/instruction_tuning_pipeline/finetune_clm.py +++ b/workflows/chatbot/fine_tuning/instruction_tuning_pipeline/finetune_clm.py @@ -47,13 +47,11 @@ ) from transformers.trainer_utils import is_main_process from typing import Optional, List -import copy import re import torch import importlib.util from transformers.utils.import_utils import is_optimum_available - -IGNORE_INDEX = -100 +from data_utils import preprocess_dataset os.environ["WANDB_DISABLED"] = "true" @@ -209,6 +207,17 @@ class DataArguments: "help": "Whether to concatenate the sentence for more efficient training." }, ) + special_tokens: Optional[List[str]] = field( + default=None, + metadata={"help": "The list of special tokens to add in tokenizer."} + ) + max_source_length: Optional[int] = field( + default=384, + metadata={ + "help": "The maximum total source sequence length after tokenization. Sequences longer " + "than this will be truncated." + }, + ) @dataclass @@ -261,7 +270,7 @@ class FinetuneArguments: }, ) train_on_inputs: bool = field( - default=True, + default=False, metadata={"help": "if False, masks out inputs in loss"}, ) habana: bool = field( @@ -270,36 +279,6 @@ class FinetuneArguments: ) -PROMPT_DICT = { - "prompt_with_input": ( - "Below is an instruction that describes a task, paired with an input that provides further context. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" - ), - "prompt_without_input": ( - "Below is an instruction that describes a task. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Response:" - ), -} - - -def create_prompts(examples): - prompts = {} - prompts["source"] = [] - prompts["target"] = [] - for example in examples: - prompt_template = ( - PROMPT_DICT["prompt_with_input"] - if example["input"] != "" - else PROMPT_DICT["prompt_without_input"] - ) - source = prompt_template.format_map(example) - prompts["source"].append(source) - prompts["target"].append(example["output"]) - return prompts - - def main(): # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. 
@@ -415,7 +394,6 @@ def main(): data_args.dataset_config_name, cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, - streaming=data_args.streaming, ) if "validation" not in raw_datasets.keys() and training_args.do_eval: @@ -425,7 +403,6 @@ def main(): split=f"train[:{data_args.validation_split_percentage}%]", cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, - streaming=data_args.streaming, ) raw_datasets["train"] = load_dataset( data_args.dataset_name, @@ -433,7 +410,6 @@ def main(): split=f"train[{data_args.validation_split_percentage}%:]", cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, - streaming=data_args.streaming, ) else: data_files = {} @@ -477,18 +453,6 @@ def main(): **dataset_args, ) - # Preprocessing the datasets. - for key in raw_datasets: - prompts = create_prompts(raw_datasets[key]) - columns_to_be_removed = list(raw_datasets[key].features.keys()) - raw_datasets[key] = raw_datasets[key].add_column( - "prompt_sources", prompts["source"] - ) - raw_datasets[key] = raw_datasets[key].add_column( - "prompt_targets", prompts["target"] - ) - raw_datasets[key] = raw_datasets[key].remove_columns(columns_to_be_removed) - # Load model if model_args.model_name_or_path: model_dtype = torch.bfloat16 if training_args.bf16 else None @@ -523,6 +487,18 @@ def main(): "Must provide model_name_or_path to load a pretrained CausalLM model." ) + # add special tokens + if data_args.special_tokens: + additional_special_tokens = { + "additional_special_tokens": data_args.special_tokens} + tokenizer.add_special_tokens(additional_special_tokens) + + # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch + # on a small vocab and want a smaller embedding size, remove this test. 
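+    # (the special tokens added above enlarge the tokenizer, so len(tokenizer) can exceed the current number of embedding rows)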
+ embedding_size = model.get_input_embeddings().weight.shape[0] + if len(tokenizer) > embedding_size: + model.resize_token_embeddings(len(tokenizer)) + if re.search("llama", model.config.architectures[0], re.IGNORECASE): # unwind broken decapoda-research config model.generation_config.pad_token_id = 0 @@ -549,46 +525,7 @@ def main(): tokenizer.pad_token_id = tokenizer.eos_token_id tokenizer.padding_side = "left" # Allow batched inference - def tokenize(prompt, add_eos_token=True): - results = tokenizer( - prompt, - truncation=True, - max_length=data_args.max_seq_length, - padding=False, - return_tensors=None, - ) - for i in range(len(results["input_ids"])): - if ( - results["input_ids"][i][-1] != tokenizer.eos_token_id - and len(results["input_ids"][i]) < data_args.max_seq_length - and add_eos_token - ): - results["input_ids"][i].append(tokenizer.eos_token_id) - results["attention_mask"][i].append(1) - - results["labels"] = copy.deepcopy(results["input_ids"]) - results["input_id_len"] = [len(result) for result in results["input_ids"]] - return results - - def preprocess_function(examples): - st = [ - s + t - for s, t in zip(examples["prompt_sources"], examples["prompt_targets"]) - ] - examples_tokenized = tokenize(st) - input_ids = examples_tokenized["input_ids"] - labels = examples_tokenized["labels"] - if not finetune_args.train_on_inputs: - sources_tokenized = tokenize( - examples["prompt_sources"], add_eos_token=False - ) - for label, source_len in zip(labels, sources_tokenized["input_id_len"]): - label[:source_len] = [IGNORE_INDEX] * source_len - return dict( - input_ids=input_ids, - labels=labels, - attention_mask=examples_tokenized["attention_mask"], - ) + raw_datasets, preprocess_function = preprocess_dataset(raw_datasets, tokenizer, data_args, finetune_args) with training_args.main_process_first(desc="dataset map pre-processing"): tokenized_datasets = raw_datasets.map(
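Note on the chat-style (oasst) scenario: `create_oasst` in `data_utils.py` assembles each training example from `conv_header`, the `<|im_start|>`/`<|im_end|>` markers (registered via `--special_tokens`), and the conversation turns. A sketch of the resulting serialization, with placeholder turns, looks like this:

```
<|im_start|>system
- You are a helpful assistant chatbot trained by Intel.
- You answer questions.
- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>
<|im_start|>user
{user turn}<|im_end|>
<|im_start|>assistant
{assistant reply}<|im_end|>
```

Only the assistant reply (the `prompt_targets` side) is treated as the target sequence; earlier turns stay in the prompt.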
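The prompt-masking idea used throughout `data_utils.py` (and the right-padding used by `tokenize_oasst` and `tokenize_cnn`) can be illustrated with a small standalone sketch. This is not part of the patch; the `gpt2` tokenizer and the lengths below are placeholders chosen only to keep the example self-contained:

```python
# Standalone sketch of the labeling scheme in data_utils.py: prompt tokens are masked
# with -100 so that only the response (and EOS) contribute to the loss, and the
# sequence is right-padded to a fixed length.
from transformers import AutoTokenizer

IGNORE_INDEX = -100
MAX_SEQ_LENGTH = 64  # stands in for data_args.max_seq_length

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer

prompt = "### Instruction:\nName a primary color.\n\n### Response:"
response = " Blue."

prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
resp_ids = tokenizer(response, add_special_tokens=False)["input_ids"]

input_ids = prompt_ids + resp_ids + [tokenizer.eos_token_id]
# train_on_inputs=False: ignore the prompt portion when computing the loss
labels = [IGNORE_INDEX] * len(prompt_ids) + resp_ids + [tokenizer.eos_token_id]

# right-pad to the fixed length; padding is masked in both labels and attention_mask
input_len = len(input_ids)
pad_len = MAX_SEQ_LENGTH - input_len
input_ids = input_ids + [tokenizer.eos_token_id] * pad_len
labels = labels + [IGNORE_INDEX] * pad_len
attention_mask = [1] * input_len + [0] * pad_len

assert len(input_ids) == len(labels) == len(attention_mask) == MAX_SEQ_LENGTH
```

Masking the prompt with `-100`, the index ignored by PyTorch's cross-entropy loss, means gradients come only from the response tokens, which is what the `train_on_inputs=False` default selects.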