Skip to content

Commit

Permalink
WIP: reproduce reported jcommonsenseqa scores with preferred templates
Browse files Browse the repository at this point in the history
  • Loading branch information
kumapo committed Oct 16, 2023
1 parent c10a2b8 commit da7be97
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 37 deletions.
81 changes: 44 additions & 37 deletions lm_eval/tasks/ja/jcommonsenseqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,25 +113,23 @@ class JCommonsenseQAWithFintanPrompt(JCommonsenseQA):
prompt template is taken from [ChatGPT vs BERT: どちらが日本語をより理解できるのか?](https://fintan.jp/page/9126/)
"""

VERSION = 1.1
VERSION = 1.2
PROMPT_VERSION = 0.2
DESCRIPTION = (
"質問と回答の選択肢を入力として受け取り、選択肢から回答を選択してください。なお、回答は選択肢の番号(例:0)でするものとします。 \n\n"
)
DESCRIPTION = "\n\n"
SEP = "\n"
FEWSHOT_SEP = "\n\n"

# NOTE: we don't use the template for stabilityai/japanese-stablelm-base-alpha-7b
# because that template appears to contain two typos.
def doc_to_text(self, doc):
"""
質問:question
選択肢:0.choice0,1.choice1, ...,4.choice4
回答:
"""
choices = ",".join(
[f"{idx}.{choice}" for idx, choice in enumerate(doc["choices"])]
)
return f"質問:{doc['goal']}\n" f"選択肢:{choices}\n" "回答:"
def doc_to_target(self, doc):
return f"{doc['gold']}"
以下から解答を選択してください:{{option_0}}, {{option_1}}, {{option_2}}, {{option_3}}, {{option_4}}
質問:{{question}}
回答:
"""
choices = ', '.join(doc['choices'])
return f"以下から解答を選択してください:{choices}{self.SEP}質問:{doc['goal']}{self.SEP}回答:"


class JCommonsenseQAWithJAAlpacaPrompt(JCommonsenseQA):
Expand All @@ -149,46 +147,55 @@ class JCommonsenseQAWithJAAlpacaPrompt(JCommonsenseQA):
- code: https://github.com/Stability-AI/gpt-neox/blob/c130a4edc1120dccec8f02a34eb60d3e8f484cd3/finetune/finetune_base_ja.py#LL118C23-L127C11
"""

VERSION = 1.1
VERSION = 1.2
PROMPT_VERSION = 0.3
DESCRIPTION = "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。\n\n"
INSTRUCTION = "与えられた選択肢の中から、最適な答えを選んでください。"
INSTRUCTION = "正しい答えは何でしょう?"
SEP = "\n"
FEWSHOT_SEP = "\n\n"

def doc_to_text(self, doc):
"""
以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。
### 指示:
{instruction}
### 入力:
{input}
### 応答:
{response}
正しい答えは何でしょう?
0.{{option_0}}
1.{{option_1}}
2.{{option_2}}
3.{{option_3}}
4.{{option_4}}
問題:{{question}}
回答:
"""
choices = "\n".join([f"- {choice}" for choice in doc["choices"]])
instruction_text = self.INSTRUCTION + f"出力は以下から選択してください:\n{choices}"
input_text = f"{doc['goal']}"
return f"### 指示:\n{instruction_text}\n\n### 入力:\n{input_text}\n\n### 応答:\n"
choices = self.SEP.join(
[f"{idx}.{choice}" for idx, choice in enumerate(doc["choices"])]
)
# instruction_text = self.INSTRUCTION + f"\n{choices}"
# input_text = f"{doc['goal']}"
# return f"### 指示:\n{instruction_text}\n\n### 入力:\n{input_text}\n\n### 応答:\n"
return f"{self.INSTRUCTION}{self.SEP}{choices}{self.SEP}問題:{doc['goal']}{self.SEP}回答:"


class JCommonsenseQAWithRinnaInstructionSFT(JCommonsenseQA):
"""
Reference:
- HF Hub: https://huggingface.co/rinna/japanese-gpt-neox-3.6b-instruction-sft
This prompt template is based on [日本語LLMベンチマークと自動プロンプトエンジニアリング](https://tech.preferred.jp/ja/blog/prompt-tuning/)
"""

VERSION = 1.1
VERSION = 1.2
PROMPT_VERSION = 0.4
DESCRIPTION = "ユーザー: 与えられた選択肢の中から、最適な答えを選んでください。<NL>システム: 分かりました。<NL>"
SEP = "<NL>"
FEWSHOT_SEP = "<NL>"
DESCRIPTION = "\n\n"
SEP = "\n\n"
FEWSHOT_SEP = "\n\n"

def doc_to_text(self, doc):
choices = self.SEP.join([f"- {choice}" for choice in doc["choices"]])
input_text = f"質問:{doc['goal']}{self.SEP}" + f"選択肢:{self.SEP}{choices}"
return f"ユーザー: {input_text}{self.SEP}システム: "
"""
以下より選択してください:{{option_0}}, {{option_1}}, {{option_2}}, {{option_3}}, {{option_4}}
[質問]:{{question}}
"""
choices = ', '.join(doc['choices'])
return f"以下より選択してください:{choices}{self.SEP}[質問]:{doc['goal']}"


class JCommonsenseQAWithRinnaBilingualInstructionSFT(
Expand Down
3 changes: 3 additions & 0 deletions models/cyberagent/cyberagent-open-calm-3b/harness.jcqa-1.2.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
MODEL_ARGS="pretrained=cyberagent/open-calm-3b,device_map=auto,torch_dtype=auto,load_in_8bit=True,low_cpu_mem_usage=True"
TASK="jcommonsenseqa-1.2-0.2"
python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "3" --device "cuda" --output_path "models/cyberagent/cyberagent-open-calm-3b/result.jcqa-1.2.json"
24 changes: 24 additions & 0 deletions models/cyberagent/cyberagent-open-calm-3b/result.jcqa-1.2.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
{
"results": {
"jcommonsenseqa-1.2-0.2": {
"acc": 0.7819481680071493,
"acc_stderr": 0.012349459533393274,
"acc_norm": 0.7184986595174263,
"acc_norm_stderr": 0.01345031058785413
}
},
"versions": {
"jcommonsenseqa-1.2-0.2": 1.2
},
"config": {
"model": "hf-causal",
"model_args": "pretrained=cyberagent/open-calm-3b,device_map=auto,torch_dtype=auto,load_in_8bit=True",
"num_fewshot": 3,
"batch_size": null,
"device": "cuda",
"no_cache": false,
"limit": null,
"bootstrap_iters": 100000,
"description_dict": {}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
MODEL_ARGS="pretrained=rinna/japanese-gpt-neox-3.6b-instruction-ppo,use_fast=False,device_map=auto,torch_dtype=auto"
TASK="jcommonsenseqa-1.2-0.4"
python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "3" --device "cuda" --output_path "models/rinna/rinna-japanese-gpt-neox-3.6b-instruction-ppo/result.jcqa-1.2.json"
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
{
"results": {
"jcommonsenseqa-1.2-0.4": {
"acc": 0.7408400357462019,
"acc_stderr": 0.0131046454633739,
"acc_norm": 0.6863270777479893,
"acc_norm_stderr": 0.013876603522158385
}
},
"versions": {
"jcommonsenseqa-1.2-0.4": 1.2
},
"config": {
"model": "hf-causal",
"model_args": "pretrained=rinna/japanese-gpt-neox-3.6b-instruction-ppo,use_fast=False,device_map=auto,torch_dtype=auto",
"num_fewshot": 3,
"batch_size": null,
"device": "cuda",
"no_cache": false,
"limit": null,
"bootstrap_iters": 100000,
"description_dict": {}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/bin/bash
set -eu
MODEL_ARGS="pretrained=stabilityai/japanese-stablelm-base-alpha-7b,use_fast=False,trust_remote_code=True,device_map=auto,torch_dtype=auto,load_in_8bit=True,low_cpu_mem_usage=True"
TASK="jcommonsenseqa-1.2-0.2"
NUM_FEW_SHOTS="3"
python main.py \
--model hf-causal \
--model_args $MODEL_ARGS \
--tasks $TASK \
--num_fewshot $NUM_FEW_SHOTS \
--device "cuda" \
--output_path "models/stablelm/stablelm-ja-base-alpha-7b/result.jcqa-1.2.json"

0 comments on commit da7be97

Please sign in to comment.