preferred templates #107

Closed
wants to merge 4 commits
12 changes: 12 additions & 0 deletions Pipfile
@@ -0,0 +1,12 @@
[[source]]
url = "https://pypi.org/simple"
verify_ssl = true
name = "pypi"

[packages]
lm-eval = {editable = true, file = "file:///Users/kumakura/tmp/lm-evaluation-harness"}

[dev-packages]

[requires]
python_version = "3.8"
1 change: 1 addition & 0 deletions lm_eval/evaluator.py
@@ -69,6 +69,7 @@ def simple_evaluate(
         assert isinstance(model, lm_eval.base.LM)
         lm = model

+    print(f"evaluator: no_cache={no_cache}")
     if not no_cache:
         lm = lm_eval.base.CachingLM(
             lm,
4 changes: 4 additions & 0 deletions lm_eval/models/gpt2.py
@@ -13,12 +13,14 @@ def __init__(
         low_cpu_mem_usage=None,
         torch_dtype=None,
         device_map=None,
+        offload_folder=None,
         subfolder=None,
         tokenizer=None,
         batch_size=1,
         load_in_8bit: Optional[bool] = False,
         trust_remote_code: Optional[bool] = False,
         use_fast: Optional[bool] = True,
+        additional_special_tokens: Optional[str] = None
     ):
         super().__init__()

@@ -49,6 +51,7 @@ def __init__(
                 low_cpu_mem_usage=low_cpu_mem_usage,
                 torch_dtype=torch_dtype,
                 device_map=device_map,
+                offload_folder=offload_folder,
                 revision=revision,
                 trust_remote_code=trust_remote_code,
             ).eval()
@@ -64,6 +67,7 @@ def __init__(
             revision=revision,
             trust_remote_code=trust_remote_code,
             use_fast=use_fast,
+            additional_special_tokens=additional_special_tokens
         )
         self.vocab_size = self.tokenizer.vocab_size
81 changes: 44 additions & 37 deletions lm_eval/tasks/ja/jcommonsenseqa.py
@@ -113,25 +113,23 @@ class JCommonsenseQAWithFintanPrompt(JCommonsenseQA):
     prompt template is taken from [ChatGPT vs BERT: どちらが日本語をより理解できるのか?](https://fintan.jp/page/9126/)
     """

-    VERSION = 1.1
+    VERSION = 1.2
     PROMPT_VERSION = 0.2
-    DESCRIPTION = (
-        "質問と回答の選択肢を入力として受け取り、選択肢から回答を選択してください。なお、回答は選択肢の番号(例:0)でするものとします。 \n\n"
-    )
+    DESCRIPTION = "\n\n"
     SEP = "\n"
     FEWSHOT_SEP = "\n\n"

+    # NOTE: we don't use the official template for stabilityai/japanese-stablelm-base-alpha-7b
+    # because it appears to contain two typos.
     def doc_to_text(self, doc):
         """
-        質問:question
-        選択肢:0.choice0,1.choice1, ...,4.choice4
-        回答:
-        """
-        choices = ",".join(
-            [f"{idx}.{choice}" for idx, choice in enumerate(doc["choices"])]
-        )
-        return f"質問:{doc['goal']}\n" f"選択肢:{choices}\n" "回答:"
-
-    def doc_to_target(self, doc):
-        return f"{doc['gold']}"
+        以下から解答を選択してください:{{option_0}}, {{option_1}}, {{option_2}}, {{option_3}}, {{option_4}}
+        質問:{{question}}
+        回答:
+        """
+        choices = ', '.join(doc['choices'])
+        return f"以下から解答を選択してください:{choices}{self.SEP}質問:{doc['goal']}{self.SEP}回答:"


class JCommonsenseQAWithJAAlpacaPrompt(JCommonsenseQA):
@@ -149,46 +147,55 @@ class JCommonsenseQAWithJAAlpacaPrompt(JCommonsenseQA):
     - code: https://github.com/Stability-AI/gpt-neox/blob/c130a4edc1120dccec8f02a34eb60d3e8f484cd3/finetune/finetune_base_ja.py#LL118C23-L127C11
     """

-    VERSION = 1.1
+    VERSION = 1.2
     PROMPT_VERSION = 0.3
     DESCRIPTION = "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。\n\n"
-    INSTRUCTION = "与えられた選択肢の中から、最適な答えを選んでください。"
+    INSTRUCTION = "正しい答えは何でしょう?"
+    SEP = "\n"
+    FEWSHOT_SEP = "\n\n"

     def doc_to_text(self, doc):
         """
-        以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。
-
-        ### 指示:
-        {instruction}
-
-        ### 入力:
-        {input}
-
-        ### 応答:
-        {response}
+        正しい答えは何でしょう?
+        0.{{option_0}}
+        1.{{option_1}}
+        2.{{option_2}}
+        3.{{option_3}}
+        4.{{option_4}}
+        問題:{{question}}
+        回答:
         """
-        choices = "\n".join([f"- {choice}" for choice in doc["choices"]])
-        instruction_text = self.INSTRUCTION + f"出力は以下から選択してください:\n{choices}"
-        input_text = f"{doc['goal']}"
-        return f"### 指示:\n{instruction_text}\n\n### 入力:\n{input_text}\n\n### 応答:\n"
+        choices = self.SEP.join(
+            [f"{idx}.{choice}" for idx, choice in enumerate(doc["choices"])]
+        )
+        # instruction_text = self.INSTRUCTION + f"\n{choices}"
+        # input_text = f"{doc['goal']}"
+        # return f"### 指示:\n{instruction_text}\n\n### 入力:\n{input_text}\n\n### 応答:\n"
+        return f"{self.INSTRUCTION}{self.SEP}{choices}{self.SEP}問題:{doc['goal']}{self.SEP}回答:"


 class JCommonsenseQAWithRinnaInstructionSFT(JCommonsenseQA):
     """
     Reference:
     - HF Hub: https://huggingface.co/rinna/japanese-gpt-neox-3.6b-instruction-sft
+    This prompt template is based on [日本語LLMベンチマークと自動プロンプトエンジニアリング](https://tech.preferred.jp/ja/blog/prompt-tuning/)
     """

-    VERSION = 1.1
+    VERSION = 1.2
     PROMPT_VERSION = 0.4
-    DESCRIPTION = "ユーザー: 与えられた選択肢の中から、最適な答えを選んでください。<NL>システム: 分かりました。<NL>"
-    SEP = "<NL>"
-    FEWSHOT_SEP = "<NL>"
+    DESCRIPTION = "\n\n"
+    SEP = "\n\n"
+    FEWSHOT_SEP = "\n\n"

     def doc_to_text(self, doc):
-        choices = self.SEP.join([f"- {choice}" for choice in doc["choices"]])
-        input_text = f"質問:{doc['goal']}{self.SEP}" + f"選択肢:{self.SEP}{choices}"
-        return f"ユーザー: {input_text}{self.SEP}システム: "
+        """
+        以下より選択してください:{{option_0}}, {{option_1}}, {{option_2}}, {{option_3}}, {{option_4}}
+
+        [質問]:{{question}}
+        """
+        choices = ', '.join(doc['choices'])
+        return f"以下より選択してください:{choices}{self.SEP}[質問]:{doc['goal']}"


class JCommonsenseQAWithRinnaBilingualInstructionSFT(
1 change: 1 addition & 0 deletions main.py
@@ -88,6 +88,7 @@ def main():
     if args.description_dict_path:
         with open(args.description_dict_path, "r") as f:
             description_dict = json.load(f)
+    print(f"main: no_cache={args.no_cache}")
     results = evaluator.simple_evaluate(
         model=args.model,
         model_args=args.model_args,
@@ -0,0 +1,3 @@
MODEL_ARGS="pretrained=cyberagent/open-calm-3b,device_map=auto,torch_dtype=auto,load_in_8bit=True,low_cpu_mem_usage=True"
TASK="jcommonsenseqa-1.2-0.2"
python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "3" --device "cuda" --output_path "models/cyberagent/cyberagent-open-calm-3b/result.jcqa-1.2.json"
24 changes: 24 additions & 0 deletions models/cyberagent/cyberagent-open-calm-3b/result.jcqa-1.2.json
@@ -0,0 +1,24 @@
{
  "results": {
    "jcommonsenseqa-1.2-0.2": {
      "acc": 0.7819481680071493,
      "acc_stderr": 0.012349459533393274,
      "acc_norm": 0.7184986595174263,
      "acc_norm_stderr": 0.01345031058785413
    }
  },
  "versions": {
    "jcommonsenseqa-1.2-0.2": 1.2
  },
  "config": {
    "model": "hf-causal",
    "model_args": "pretrained=cyberagent/open-calm-3b,device_map=auto,torch_dtype=auto,load_in_8bit=True",
    "num_fewshot": 3,
    "batch_size": null,
    "device": "cuda",
    "no_cache": false,
    "limit": null,
    "bootstrap_iters": 100000,
    "description_dict": {}
  }
}
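As a quick arithmetic check on these figures (mine, not part of the PR): assuming the JCommonsenseQA validation split of 1,119 examples, the reported acc corresponds to 875 correct answers, and the sample-based binomial standard error reproduces acc_stderr to rounding:

# Sanity check of the reported metrics; n = 1119 is an assumption about the
# size of the JCommonsenseQA validation split.
import math

n = 1119
acc = 875 / n                                  # 0.7819481680071493, matches "acc"
stderr = math.sqrt(acc * (1 - acc) / (n - 1))  # sample-based binomial standard error
print(acc, stderr)                             # stderr ≈ 0.0123495, matches "acc_stderr"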
@@ -0,0 +1,3 @@
MODEL_ARGS="pretrained=rinna/japanese-gpt-neox-3.6b-instruction-ppo,use_fast=False,device_map=auto,torch_dtype=auto"
TASK="jcommonsenseqa-1.2-0.4"
python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "3" --device "cuda" --output_path "models/rinna/rinna-japanese-gpt-neox-3.6b-instruction-ppo/result.jcqa-1.2.json"
@@ -0,0 +1,24 @@
{
  "results": {
    "jcommonsenseqa-1.2-0.4": {
      "acc": 0.7408400357462019,
      "acc_stderr": 0.0131046454633739,
      "acc_norm": 0.6863270777479893,
      "acc_norm_stderr": 0.013876603522158385
    }
  },
  "versions": {
    "jcommonsenseqa-1.2-0.4": 1.2
  },
  "config": {
    "model": "hf-causal",
    "model_args": "pretrained=rinna/japanese-gpt-neox-3.6b-instruction-ppo,use_fast=False,device_map=auto,torch_dtype=auto",
    "num_fewshot": 3,
    "batch_size": null,
    "device": "cuda",
    "no_cache": false,
    "limit": null,
    "bootstrap_iters": 100000,
    "description_dict": {}
  }
}
@@ -0,0 +1,13 @@
#!/bin/bash
set -eu
MODEL_ARGS="pretrained=stabilityai/japanese-stablelm-base-alpha-7b,use_fast=False,trust_remote_code=True,device_map=auto,torch_dtype=auto,load_in_8bit=True,offload_folder=/tmp,tokenizer=novelai/nerdstash-tokenizer-v1,additional_special_tokens=['▁▁']"
TASK="jcommonsenseqa-1.2-0.2"
NUM_FEW_SHOTS="3"
python main.py \
--model hf-causal \
--model_args $MODEL_ARGS \
--tasks $TASK \
--num_fewshot $NUM_FEW_SHOTS \
--device "cuda" \
--no_cache \
--output_path "models/stablelm/stablelm-ja-base-alpha-7b/result.jcqa-1.2.json"
@@ -0,0 +1,24 @@
{
  "results": {
    "jcommonsenseqa-1.2-0.2": {
      "acc": 0.8194816800714924,
      "acc_stderr": 0.011502953501835072,
      "acc_norm": 0.7864164432529044,
      "acc_norm_stderr": 0.012257144279902604
    }
  },
  "versions": {
    "jcommonsenseqa-1.2-0.2": 1.2
  },
  "config": {
    "model": "hf-causal",
    "model_args": "pretrained=stabilityai/japanese-stablelm-base-alpha-7b,use_fast=False,trust_remote_code=True,device_map=auto,torch_dtype=auto,load_in_8bit=True,offload_folder=/tmp,tokenizer=novelai/nerdstash-tokenizer-v1,additional_special_tokens=['▁▁']",
    "num_fewshot": 3,
    "batch_size": null,
    "device": "cuda",
    "no_cache": true,
    "limit": null,
    "bootstrap_iters": 100000,
    "description_dict": {}
  }
}