From f843cbf100a92a36d9f3f1eaf780c39295e6d0c1 Mon Sep 17 00:00:00 2001 From: kumapo Date: Sat, 26 Aug 2023 19:14:36 +0900 Subject: [PATCH] re-evaluate models with jsquad prompt with title --- lm_eval/tasks/ja/jsquad.py | 81 +++++++++++++++++++ .../llama2-2.7b/harness.jsquad-1.1-0.31.sh | 4 + .../harness.jsquad-1.1-0.21.sh | 3 + .../result.jsquad-1.1-0.21.json | 22 +++++ .../harness.jsquad-1.1-0.41.sh | 3 + .../result.jsquad-1.1-0.41.json | 22 +++++ 6 files changed, 135 insertions(+) create mode 100644 models/llama2/llama2-2.7b/harness.jsquad-1.1-0.31.sh create mode 100644 models/rinna/rinna-japanese-gpt-1b/harness.jsquad-1.1-0.21.sh create mode 100644 models/rinna/rinna-japanese-gpt-1b/result.jsquad-1.1-0.21.json create mode 100644 models/rinna/rinna-japanese-gpt-neox-3.6b-instruction-sft-v2/harness.jsquad-1.1-0.41.sh create mode 100644 models/rinna/rinna-japanese-gpt-neox-3.6b-instruction-sft-v2/result.jsquad-1.1-0.41.json diff --git a/lm_eval/tasks/ja/jsquad.py b/lm_eval/tasks/ja/jsquad.py index 7b643819d5..4e7957e71b 100644 --- a/lm_eval/tasks/ja/jsquad.py +++ b/lm_eval/tasks/ja/jsquad.py @@ -178,6 +178,7 @@ def _squad_agg(self, key, item): predictions, references = zip(*item) return self._squad_metric(predictions=predictions, references=references)[key] + class JSQuADWithFintanPrompt(JSQuAD): """ prompt template is taken from [ChatGPT vs BERT: どちらが日本語をより理解できるのか?](https://fintan.jp/page/9126/) @@ -195,6 +196,27 @@ def doc_to_text(self, doc): + f"{self.SEP}" + "回答:" ) + + +class JSQuADWithFintanPromptHavingTitle(JSQuADWithFintanPrompt): + """ + prompt template is based on [ChatGPT vs BERT: どちらが日本語をより理解できるのか?](https://fintan.jp/page/9126/) + """ + PROMPT_VERSION = 0.21 + DESCRIPTION = "質問に対する回答を題名と文章から一言で抽出してください。回答は名詞で答えてください。\n\n" + def doc_to_text(self, doc): + return ( + "題名:" + + doc["title"] + + f"{self.SEP}" + + "文章:" + + doc["context"].split("[SEP]")[-1].strip() + + f"{self.SEP}" + + "質問:" + + doc["question"] + + f"{self.SEP}" + + "回答:" + ) class JSQuADWithJAAlpacaPrompt(JSQuAD): @@ -231,6 +253,38 @@ def doc_to_text(self, doc): return f"### 指示:\n{self.INSTRUCTION}\n\n### 入力:\n{input_text}\n\n### 応答:\n" +class JSQuADWithJAAlpacaPromptHavingTitle(JSQuADWithJAAlpacaPrompt): + """ + This prompt format was inspired by the below data in fujiki/japanese_alpaca_data. + ``` + { + 'instruction': '与えられた文脈に最も適した文を選択してください。', + 'input': '文脈:あなたは親友と現在の仕事の状況について話しています。\nA)私にはあまり選択肢がありません。\nB)他に選択肢がありません。\nC)私には本当に決断する必要がありません。', + 'output': 'A) 私には多くの選択肢がありません。' + } + ``` + Reference: + - data: https://huggingface.co/datasets/fujiki/japanese_alpaca_data + - code: https://github.com/Stability-AI/gpt-neox/blob/c130a4edc1120dccec8f02a34eb60d3e8f484cd3/finetune/finetune_base_ja.py#LL118C23-L127C11 + """ + PROMPT_VERSION = 0.31 + def doc_to_text(self, doc): + """ + 以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。 + + ### 指示: + {instruction} + + ### 入力: + {input} + + ### 応答: + {response} + """ + input_text = f"文脈:{doc['title']}\n{doc['context'].split('[SEP]')[-1].strip()}\n質問:{doc['question']}" + return f"### 指示:\n{self.INSTRUCTION}\n\n### 入力:\n{input_text}\n\n### 応答:\n" + + class JSQuADWithRinnaInstructionSFT(JSQuAD): """ Reference: @@ -247,6 +301,18 @@ def doc_to_text(self, doc): return f"ユーザー: {input_text}{self.SEP}システム: " +class JSQuADWithRinnaInstructionSFTHavingTitle(JSQuADWithRinnaInstructionSFT): + """ + Reference: + - HF Hub: https://huggingface.co/rinna/japanese-gpt-neox-3.6b-instruction-sft + """ + PROMPT_VERSION = 0.41 + + def doc_to_text(self, doc): + input_text = f"文脈:{doc['title']}\n{doc['context'].split('[SEP]')[-1].strip()}{self.SEP}質問:{doc['question']}" + return f"ユーザー: {input_text}{self.SEP}システム: " + + class JSQuADWithRinnaBilingualInstructionSFT(JSQuADWithRinnaInstructionSFT): """ Reference: @@ -257,13 +323,28 @@ class JSQuADWithRinnaBilingualInstructionSFT(JSQuADWithRinnaInstructionSFT): SEP = "\n" FEWSHOT_SEP = "\n" + +class JSQuADWithRinnaBilingualInstructionSFTHavingTitle(JSQuADWithRinnaInstructionSFTHavingTitle): + """ + Reference: + - HF Hub: https://huggingface.co/rinna/bilingual-gpt-neox-4b-instruction-sft + """ + PROMPT_VERSION = 0.51 + DESCRIPTION = "ユーザー: 与えられた文脈から、質問に対する答えを抜き出してください。\nシステム: 分かりました。\n" + SEP = "\n" + FEWSHOT_SEP = "\n" + VERSIONS = [ JSQuAD, JSQuADWithFintanPrompt, + JSQuADWithFintanPromptHavingTitle, JSQuADWithJAAlpacaPrompt, + JSQuADWithJAAlpacaPromptHavingTitle, JSQuADWithRinnaInstructionSFT, + JSQuADWithRinnaInstructionSFTHavingTitle, JSQuADWithRinnaBilingualInstructionSFT, + JSQuADWithRinnaBilingualInstructionSFTHavingTitle ] diff --git a/models/llama2/llama2-2.7b/harness.jsquad-1.1-0.31.sh b/models/llama2/llama2-2.7b/harness.jsquad-1.1-0.31.sh new file mode 100644 index 0000000000..fdfa73aea4 --- /dev/null +++ b/models/llama2/llama2-2.7b/harness.jsquad-1.1-0.31.sh @@ -0,0 +1,4 @@ +MODEL_ARGS="pretrained=meta-llama/Llama-2-7b-hf,use_accelerate=True,dtype=auto" +TASK="jsquad-1.1-0.31" +python main.py --model hf-causal-experimental --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "2" --device "cuda" --output_path "models/llama2/llama2-2.7b/result.jsquad-1.1-0.31.json" --batch_size 2 + diff --git a/models/rinna/rinna-japanese-gpt-1b/harness.jsquad-1.1-0.21.sh b/models/rinna/rinna-japanese-gpt-1b/harness.jsquad-1.1-0.21.sh new file mode 100644 index 0000000000..32b65c0b54 --- /dev/null +++ b/models/rinna/rinna-japanese-gpt-1b/harness.jsquad-1.1-0.21.sh @@ -0,0 +1,3 @@ +MODEL_ARGS="pretrained=rinna/japanese-gpt-1b,use_fast=False" +TASK="jsquad-1.1-0.21" +python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "2" --device "cuda" --output_path "models/rinna/rinna-japanese-gpt-1b/result.jsquad-1.1-0.21.json" diff --git a/models/rinna/rinna-japanese-gpt-1b/result.jsquad-1.1-0.21.json b/models/rinna/rinna-japanese-gpt-1b/result.jsquad-1.1-0.21.json new file mode 100644 index 0000000000..ce2d366360 --- /dev/null +++ b/models/rinna/rinna-japanese-gpt-1b/result.jsquad-1.1-0.21.json @@ -0,0 +1,22 @@ +{ + "results": { + "jsquad-1.1-0.21": { + "exact_match": 30.189104007203962, + "f1": 47.12467642283419 + } + }, + "versions": { + "jsquad-1.1-0.21": 1.1 + }, + "config": { + "model": "hf-causal", + "model_args": "pretrained=rinna/japanese-gpt-1b,use_fast=False", + "num_fewshot": 2, + "batch_size": null, + "device": "cuda", + "no_cache": false, + "limit": null, + "bootstrap_iters": 100000, + "description_dict": {} + } +} \ No newline at end of file diff --git a/models/rinna/rinna-japanese-gpt-neox-3.6b-instruction-sft-v2/harness.jsquad-1.1-0.41.sh b/models/rinna/rinna-japanese-gpt-neox-3.6b-instruction-sft-v2/harness.jsquad-1.1-0.41.sh new file mode 100644 index 0000000000..bc3d00877f --- /dev/null +++ b/models/rinna/rinna-japanese-gpt-neox-3.6b-instruction-sft-v2/harness.jsquad-1.1-0.41.sh @@ -0,0 +1,3 @@ +MODEL_ARGS="pretrained=rinna/japanese-gpt-neox-3.6b-instruction-sft-v2,use_fast=False,device_map=auto,torch_dtype=auto" +TASK="jsquad-1.1-0.41" +python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "2" --device "cuda" --output_path "models/rinna/rinna-japanese-gpt-neox-3.6b-instruction-sft-v2/result.jsquad-1.1-0.41.json" diff --git a/models/rinna/rinna-japanese-gpt-neox-3.6b-instruction-sft-v2/result.jsquad-1.1-0.41.json b/models/rinna/rinna-japanese-gpt-neox-3.6b-instruction-sft-v2/result.jsquad-1.1-0.41.json new file mode 100644 index 0000000000..b7824df68b --- /dev/null +++ b/models/rinna/rinna-japanese-gpt-neox-3.6b-instruction-sft-v2/result.jsquad-1.1-0.41.json @@ -0,0 +1,22 @@ +{ + "results": { + "jsquad-1.1-0.41": { + "exact_match": 47.90634849167042, + "f1": 62.1059309037734 + } + }, + "versions": { + "jsquad-1.1-0.41": 1.1 + }, + "config": { + "model": "hf-causal", + "model_args": "pretrained=rinna/japanese-gpt-neox-3.6b-instruction-sft-v2,use_fast=False,device_map=auto,torch_dtype=auto", + "num_fewshot": 2, + "batch_size": null, + "device": "cuda", + "no_cache": false, + "limit": null, + "bootstrap_iters": 100000, + "description_dict": {} + } +} \ No newline at end of file