compare results between JSQuAD prompt with title and without title (#84)
* re-evaluate models with jsquad prompt with title

* update jsquad to include titles in the prompt

* re-evaluate models with jsquad prompt with title

* inherit JSQuAD v1.2 tasks from v1.1 for readability

* re-evaluate models with jsquad prompt with title

* won't need jsquad_v11

* revert result.json and harness.sh in models

* fix format
kumapo authored Sep 30, 2023
1 parent 413593a commit bca52e7
Showing 37 changed files with 582 additions and 26 deletions.
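The substance of the change is that the v1.2 JSQuAD prompts prepend the document title. A minimal sketch of the difference (illustrative only; the real logic lives in `lm_eval/tasks/ja/jsquad.py`, and the `doc` record here is made up, using the JGLUE JSQuAD fields the task classes rely on):

```python
# Illustrative sketch of the v1.1 vs v1.2 Fintan-style prompts.
doc = {
    "title": "ポルトガル",
    "context": "ポルトガル [SEP] ポルトガルの正式名称はポルトガル共和国である。",
    "question": "ポルトガルの正式名称は何か。",
}

SEP = "\n"
# The task classes keep only the passage text after the "[SEP]" marker.
context = doc["context"].split("[SEP]")[-1].strip()

# v1.1-style prompt (jsquad-1.1-0.2): no title.
prompt_v11 = f"文章:{context}{SEP}質問:{doc['question']}{SEP}回答:"

# v1.2-style prompt (jsquad-1.2-0.2): title prepended, as in
# JSQuADWithFintanPromptV12.doc_to_text in the diff below.
prompt_v12 = f"題名:{doc['title']}{SEP}文章:{context}{SEP}質問:{doc['question']}{SEP}回答:"
print(prompt_v12)
```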
158 changes: 132 additions & 26 deletions lm_eval/tasks/ja/jsquad.py
@@ -2,8 +2,8 @@
JGLUE: Japanese General Language Understanding Evaluation
https://aclanthology.org/2022.lrec-1.317/
JGLUE, Japanese General Language Understanding Evaluation, is built to measure the general NLU ability in Japanese.
JGLUE has been constructed from scratch without translation.
Homepage: https://github.com/yahoojapan/JGLUE
"""
@@ -40,6 +40,7 @@ class JSQuAD(Task):
"""
prompt template is taken from [日本語に特化した60億パラメータ規模のGPTモデルの構築と評価](https://www.anlp.jp/proceedings/annual_meeting/2023/pdf_dir/H9-4.pdf)
"""

VERSION = 1.1
PROMPT_VERSION = 0.1
DATASET_PATH = "shunk031/JGLUE"
@@ -51,11 +52,11 @@ class JSQuAD(Task):
# REMOVE_IDS = ['a10743p19q0', 'a10743p19q1', 'a10743p19q2', 'a10743p19q3', 'a13221p1q0', 'a13221p1q1', 'a13221p1q2', 'a13221p1q3', 'a14985p1q0', 'a14985p1q1', 'a14985p1q2', 'a14985p1q3', 'a14985p1q4', 'a14985p93q0', 'a14985p93q1', 'a14985p93q2', 'a14985p93q3', 'a14985p93q4', 'a1540503p36q0', 'a1540503p36q1', 'a1540503p36q2', 'a1540503p36q3', 'a1540503p36q4', 'a18783p1q0', 'a18783p3q0', 'a18783p3q1', 'a18783p3q2', 'a18783p8q0', 'a18873p25q0', 'a18873p25q1', 'a18873p25q2', 'a18873p25q3', 'a18873p26q0', 'a18873p26q1', 'a18873p26q2', 'a20898p10q0', 'a20898p15q0', 'a20898p15q1', 'a20898p15q2', 'a20898p15q3', 'a2164640p22q0', 'a2164640p22q1', 'a2164640p22q2', 'a2164640p22q3', 'a2164640p22q4', 'a22392p20q0', 'a22392p20q1', 'a22392p20q2', 'a22392p20q3', 'a3011628p3q0', 'a3011628p3q1', 'a3011628p3q2', 'a3011628p3q3', 'a3189p4q0', 'a3189p4q1', 'a3189p4q2', 'a369953p0q0', 'a369953p0q1', 'a369953p0q2', 'a369953p0q3', 'a3949p1q0', 'a3949p1q1', 'a4596p0q0', 'a4596p0q1', 'a4596p0q2', 'a4596p0q3', 'a4596p1q0', 'a4596p1q1', 'a4596p1q2', 'a4596p1q3', 'a4596p1q4', 'a4596p38q0', 'a4596p38q1', 'a4596p38q2', 'a4596p38q3', 'a4596p38q4', 'a4768p13q0', 'a4768p13q1', 'a4768p13q2', 'a4768p3q0', 'a4768p3q1', 'a4768p3q2', 'a4768p3q3', 'a4768p8q0', 'a4768p8q1', 'a4768p8q2', 'a51481p0q0', 'a51481p0q1', 'a51481p0q2', 'a51481p10q0', 'a51481p10q1', 'a51481p10q2', 'a51481p10q3', 'a51481p6q0', 'a51481p6q1', 'a51481p6q2', 'a51481p6q3', 'a51481p7q0', 'a51481p7q1', 'a67892p11q0', 'a67892p11q1', 'a67892p11q2', 'a67892p11q3', 'a67892p2q0', 'a8874p6q0', 'a8874p6q1', 'a916079p3q0', 'a916079p3q1', 'a95156p4q0', 'a95156p4q1', 'a95156p4q2', 'a95156p4q3', 'a95156p6q0', 'a95156p6q1', 'a95156p6q2', 'a95156p6q3']
"""
@mkshing's comment
I found that JSQuAD contains errors inside contexts such as below.
```
{'id': 'a4596p0q0', 'title': 'ポルトガル', 'context': 'ポルトガル [SEP] 正式名称はポルトガル語で、。通称、 。', 'question': 'ポルトガルね正式名称は何語であるか', 'answers': {'text': ['正式名称はポルトガル語', 'ポルトガル語', 'ポルトガル語'], 'answer_start': [12, 17, 17]}, 'is_impossible': False}
```
So, I tried to identify all of them and found that the following processing can detect the ids
```python
from datasets import load_dataset
from transformers import T5Tokenizer
@@ -70,19 +71,20 @@ class JSQuAD(Task):
remove_ids.append(item["id"])
```
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.jasquad_metric = datasets.load_metric(jasquad.__file__)

def has_training_docs(self):
return True

def has_validation_docs(self):
return True

def has_test_docs(self):
return False

def training_docs(self):
return self.dataset["train"]

@@ -91,7 +93,7 @@ def validation_docs(self):
if len(self.REMOVE_IDS) > 0:
dataset = [item for item in dataset if item["id"] not in self.REMOVE_IDS]
return dataset

def doc_to_text(self, doc):
return (
"[題名]:"
@@ -126,12 +128,19 @@ def construct_requests(self, doc, ctx):
encode_params = dict(add_special_tokens=False)
else:
encode_params = {}
max_num_tokens = max([len(encode_fn(answer, **encode_params)) for answer in doc["answers"]["text"]])
max_num_tokens = max(
[
len(encode_fn(answer, **encode_params))
for answer in doc["answers"]["text"]
]
)
continuation = rf.greedy_until(ctx, [self.SEP], max_num_tokens)
return continuation

def process_results(self, doc, results):
assert len(results) == 1, f"results should be a list with 1 str element, but is {results}"
assert (
len(results) == 1
), f"results should be a list with 1 str element, but is {results}"
continuation = results[0]
predictions = {
"id": doc["id"],
@@ -153,7 +162,6 @@ def process_results(self, doc, results):
), # The F-score of predicted tokens versus the gold answer
}


def aggregation(self):
return {
"exact_match": partial(
@@ -163,28 +171,32 @@ def aggregation(self):
self._squad_agg, "f1"
), # The F-score of predicted tokens versus the gold answer
}

def higher_is_better(self):
return {
"exact_match": True, # Exact match (the normalized answer exactly match the gold answer)
"f1": True, # The F-score of predicted tokens versus the gold answer
}

def _squad_metric(self, predictions, references):
return self.jasquad_metric.compute(predictions=predictions, references=references)

return self.jasquad_metric.compute(
predictions=predictions, references=references
)

def _squad_agg(self, key, item):
predictions, references = zip(*item)
return self._squad_metric(predictions=predictions, references=references)[key]


class JSQuADWithFintanPrompt(JSQuAD):
"""
prompt template is taken from [ChatGPT vs BERT: どちらが日本語をより理解できるのか?](https://fintan.jp/page/9126/)
"""

PROMPT_VERSION = 0.2
DESCRIPTION = "質問に対する回答を文章から一言で抽出してください。回答は名詞で答えてください。\n\n"
SEP = "\n"

def doc_to_text(self, doc):
return (
"文章:"
@@ -195,39 +207,100 @@ def doc_to_text(self, doc):
+ f"{self.SEP}"
+ "回答:"
)



class JSQuADWithFintanPromptV12(JSQuADWithFintanPrompt):
"""
prompt template is taken from [ChatGPT vs BERT: どちらが日本語をより理解できるのか?](https://fintan.jp/page/9126/)
"""

VERSION = 1.2
DESCRIPTION = "質問に対する回答を題名と文章から一言で抽出してください。回答は名詞で答えてください。\n\n"

def doc_to_text(self, doc):
return (
"題名:"
+ doc["title"]
+ f"{self.SEP}"
+ "文章:"
+ doc["context"].split("[SEP]")[-1].strip()
+ f"{self.SEP}"
+ "質問:"
+ doc["question"]
+ f"{self.SEP}"
+ "回答:"
)


class JSQuADWithJAAlpacaPrompt(JSQuAD):
"""
This prompt format was inspired by the below data in fujiki/japanese_alpaca_data.
```
{
'instruction': '与えられた文脈に最も適した文を選択してください。',
'input': '文脈:あなたは親友と現在の仕事の状況について話しています。\nA)私にはあまり選択肢がありません。\nB)他に選択肢がありません。\nC)私には本当に決断する必要がありません。',
'output': 'A) 私には多くの選択肢がありません。'
}
```
Reference:
- data: https://huggingface.co/datasets/fujiki/japanese_alpaca_data
- code: https://github.com/Stability-AI/gpt-neox/blob/c130a4edc1120dccec8f02a34eb60d3e8f484cd3/finetune/finetune_base_ja.py#LL118C23-L127C11
"""

PROMPT_VERSION = 0.3
DESCRIPTION = "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。\n\n"
INSTRUCTION = "与えられた文脈から、質問に対する答えを抜き出してください。"

def doc_to_text(self, doc):
"""
以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。
### 指示:
{instruction}
### 入力:
{input}
### 応答:
{response}
"""
input_text = (
f"文脈:{doc['context'].split('[SEP]')[-1].strip()}\n質問:{doc['question']}"
)
return f"### 指示:\n{self.INSTRUCTION}\n\n### 入力:\n{input_text}\n\n### 応答:\n"


class JSQuADWithJAAlpacaPromptV12(JSQuADWithJAAlpacaPrompt):
"""
This prompt format was inspired by the below data in fujiki/japanese_alpaca_data.
```
{
'instruction': '与えられた文脈に最も適した文を選択してください。',
'input': '文脈:あなたは親友と現在の仕事の状況について話しています。\nA)私にはあまり選択肢がありません。\nB)他に選択肢がありません。\nC)私には本当に決断する必要がありません。',
'output': 'A) 私には多くの選択肢がありません。'
}
```
Reference:
- data: https://huggingface.co/datasets/fujiki/japanese_alpaca_data
- code: https://github.com/Stability-AI/gpt-neox/blob/c130a4edc1120dccec8f02a34eb60d3e8f484cd3/finetune/finetune_base_ja.py#LL118C23-L127C11
"""

VERSION = 1.2

def doc_to_text(self, doc):
"""
以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。
### 指示:
{instruction}
### 入力:
{input}
### 応答:
{response}
"""
input_text = f"文脈:{doc['context'].split('[SEP]')[-1].strip()}\n質問:{doc['question']}"
input_text = f"文脈:{doc['title']}\n{doc['context'].split('[SEP]')[-1].strip()}\n質問:{doc['question']}"
return f"### 指示:\n{self.INSTRUCTION}\n\n### 入力:\n{input_text}\n\n### 応答:\n"


Expand All @@ -236,14 +309,27 @@ class JSQuADWithRinnaInstructionSFT(JSQuAD):
Reference:
- HF Hub: https://huggingface.co/rinna/japanese-gpt-neox-3.6b-instruction-sft
"""

PROMPT_VERSION = 0.4
DESCRIPTION = "ユーザー: 与えられた文脈から、質問に対する答えを抜き出してください。<NL>システム: 分かりました。<NL>"
SEP = "<NL>"
FEWSHOT_SEP = "<NL>"

def doc_to_text(self, doc):
input_text = f"文脈:{doc['context'].split('[SEP]')[-1].strip()}{self.SEP}質問:{doc['question']}"
# input_text = f"質問:{doc['question']}<NL>文脈:{doc['context'].split('[SEP]')[-1].strip()}"
return f"ユーザー: {input_text}{self.SEP}システム: "


class JSQuADWithRinnaInstructionSFTV12(JSQuADWithRinnaInstructionSFT):
"""
Reference:
- HF Hub: https://huggingface.co/rinna/japanese-gpt-neox-3.6b-instruction-sft
"""

VERSION = 1.2

def doc_to_text(self, doc):
input_text = f"文脈:{doc['title']}{self.SEP}{doc['context'].split('[SEP]')[-1].strip()}{self.SEP}質問:{doc['question']}"
return f"ユーザー: {input_text}{self.SEP}システム: "


@@ -252,23 +338,43 @@ class JSQuADWithRinnaBilingualInstructionSFT(JSQuADWithRinnaInstructionSFT):
Reference:
- HF Hub: https://huggingface.co/rinna/bilingual-gpt-neox-4b-instruction-sft
"""

PROMPT_VERSION = 0.5
DESCRIPTION = "ユーザー: 与えられた文脈から、質問に対する答えを抜き出してください。\nシステム: 分かりました。\n"
SEP = "\n"
FEWSHOT_SEP = "\n"



class JSQuADWithRinnaBilingualInstructionSFTV12(JSQuADWithRinnaBilingualInstructionSFT):
"""
Reference:
- HF Hub: https://huggingface.co/rinna/bilingual-gpt-neox-4b-instruction-sft
"""

VERSION = 1.2

def doc_to_text(self, doc):
input_text = f"文脈:{doc['title']}{self.SEP}{doc['context'].split('[SEP]')[-1].strip()}{self.SEP}質問:{doc['question']}"
return f"ユーザー: {input_text}{self.SEP}システム: "


VERSIONS = [
JSQuAD,
JSQuADWithFintanPrompt,
JSQuADWithFintanPromptV12,
JSQuADWithJAAlpacaPrompt,
JSQuADWithJAAlpacaPromptV12,
JSQuADWithRinnaInstructionSFT,
JSQuADWithRinnaInstructionSFTV12,
JSQuADWithRinnaBilingualInstructionSFT,
JSQuADWithRinnaBilingualInstructionSFTV12,
]


def construct_tasks():
tasks = {}
for version_class in VERSIONS:
tasks[f"jsquad-{version_class.VERSION}-{version_class.PROMPT_VERSION}"] = version_class
tasks[
f"jsquad-{version_class.VERSION}-{version_class.PROMPT_VERSION}"
] = version_class
return tasks
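As a quick sanity check on the naming scheme (a sketch, assuming the repository is importable as `lm_eval`): the v1.2 classes keep their parents' PROMPT_VERSION, so the Fintan-style v1.2 task registers as `jsquad-1.2-0.2`, which is the TASK value used by the harness scripts below.

```python
# Sketch only: verifies the task-name convention f"jsquad-{VERSION}-{PROMPT_VERSION}".
from lm_eval.tasks.ja.jsquad import construct_tasks, JSQuADWithFintanPromptV12

tasks = construct_tasks()
# JSQuADWithFintanPromptV12 sets VERSION = 1.2 and inherits PROMPT_VERSION = 0.2,
# so it is registered under "jsquad-1.2-0.2".
assert tasks["jsquad-1.2-0.2"] is JSQuADWithFintanPromptV12
```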
3 changes: 3 additions & 0 deletions models/abeja-gpt-neox-japanese-2.7b/harness.jsquad-1.2.sh
@@ -0,0 +1,3 @@
MODEL_ARGS="pretrained=abeja/gpt-neox-japanese-2.7b,device_map=auto,torch_dtype=auto"
TASK="jsquad-1.2-0.2"
python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "3" --device "cuda" --output_path "models/abeja-gpt-neox-japanese-2.7b/result.jsquad-1.2.json"
22 changes: 22 additions & 0 deletions models/abeja-gpt-neox-japanese-2.7b/result.jsquad-1.2.json
@@ -0,0 +1,22 @@
{
"results": {
"jsquad-1.2-0.2": {
"exact_match": 15.803692030616839,
"f1": 25.18326978234071
}
},
"versions": {
"jsquad-1.2-0.2": 1.2
},
"config": {
"model": "hf-causal",
"model_args": "pretrained=abeja/gpt-neox-japanese-2.7b,device_map=auto,torch_dtype=auto",
"num_fewshot": 3,
"batch_size": null,
"device": "cuda",
"no_cache": false,
"limit": null,
"bootstrap_iters": 100000,
"description_dict": {}
}
}
@@ -0,0 +1,3 @@
MODEL_ARGS="pretrained=cyberagent/open-calm-1b,device_map=auto,torch_dtype=auto"
TASK="jsquad-1.2-0.2"
python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "3" --device "cuda" --output_path "models/cyberagent-open-calm-1b/result.jsquad-1.2.json"
22 changes: 22 additions & 0 deletions models/cyberagent/cyberagent-open-calm-1b/result.jsquad-1.2.json
@@ -0,0 +1,22 @@
{
"results": {
"jsquad-1.2-0.2": {
"exact_match": 39.53174245835209,
"f1": 49.49399460234075
}
},
"versions": {
"jsquad-1.2-0.2": 1.2
},
"config": {
"model": "hf-causal",
"model_args": "pretrained=cyberagent/open-calm-1b",
"num_fewshot": 3,
"batch_size": null,
"device": "cuda",
"no_cache": false,
"limit": null,
"bootstrap_iters": 100000,
"description_dict": {}
}
}
@@ -0,0 +1,3 @@
MODEL_ARGS="pretrained=cyberagent/open-calm-3b,device_map=auto,torch_dtype=auto"
TASK="jsquad-1.2-0.2"
python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "2" --device "cuda" --output_path "models/cyberagent/cyberagent-open-calm-3b/result.jsquad-1.2.json"
22 changes: 22 additions & 0 deletions models/cyberagent/cyberagent-open-calm-3b/result.jsquad-1.2.json
@@ -0,0 +1,22 @@
{
"results": {
"jsquad-1.2-0.2": {
"exact_match": 44.529491220171096,
"f1": 56.02141036867636
}
},
"versions": {
"jsquad-1.2-0.2": 1.2
},
"config": {
"model": "hf-causal",
"model_args": "pretrained=cyberagent/open-calm-3b,device_map=auto,torch_dtype=auto",
"num_fewshot": 2,
"batch_size": null,
"device": "cuda",
"no_cache": false,
"limit": null,
"bootstrap_iters": 100000,
"description_dict": {}
}
}