Compare results between JSQuAD prompt with title and without title #84

Merged · 8 commits · Sep 30, 2023
Changes from 3 commits
17 changes: 10 additions & 7 deletions lm_eval/tasks/ja/jsquad.py
@@ -40,7 +40,7 @@ class JSQuAD(Task):
     """
     prompt template is taken from [日本語に特化した60億パラメータ規模のGPTモデルの構築と評価](https://www.anlp.jp/proceedings/annual_meeting/2023/pdf_dir/H9-4.pdf)
     """
-    VERSION = 1.1
+    VERSION = 1.2
     PROMPT_VERSION = 0.1
     DATASET_PATH = "shunk031/JGLUE"
     DATASET_NAME = "JSQuAD"
@@ -178,24 +178,28 @@ def _squad_agg(self, key, item):
         predictions, references = zip(*item)
         return self._squad_metric(predictions=predictions, references=references)[key]
 
 
 class JSQuADWithFintanPrompt(JSQuAD):
     """
     prompt template is taken from [ChatGPT vs BERT: どちらが日本語をより理解できるのか?](https://fintan.jp/page/9126/)
     """
     PROMPT_VERSION = 0.2
-    DESCRIPTION = "質問に対する回答を文章から一言で抽出してください。回答は名詞で答えてください。\n\n"
+    DESCRIPTION = "質問に対する回答を題名と文章から一言で抽出してください。回答は名詞で答えてください。\n\n"
     SEP = "\n"
     def doc_to_text(self, doc):
         return (
-            "文章:"
+            "題名:"
+            + doc["title"]
+            + f"{self.SEP}"
+            + "文章:"
             + doc["context"].split("[SEP]")[-1].strip()
             + f"{self.SEP}"
             + "質問:"
             + doc["question"]
             + f"{self.SEP}"
             + "回答:"
         )
+
 
 class JSQuADWithJAAlpacaPrompt(JSQuAD):
     """
@@ -227,7 +231,7 @@ def doc_to_text(self, doc):
         ### 応答:
         {response}
         """
-        input_text = f"文脈:{doc['context'].split('[SEP]')[-1].strip()}\n質問:{doc['question']}"
+        input_text = f"文脈:{doc['title']}\n{doc['context'].split('[SEP]')[-1].strip()}\n質問:{doc['question']}"
         return f"### 指示:\n{self.INSTRUCTION}\n\n### 入力:\n{input_text}\n\n### 応答:\n"
 
 
@@ -242,8 +246,7 @@ class JSQuADWithRinnaInstructionSFT(JSQuAD):
     FEWSHOT_SEP = "<NL>"
 
     def doc_to_text(self, doc):
-        input_text = f"文脈:{doc['context'].split('[SEP]')[-1].strip()}{self.SEP}質問:{doc['question']}"
-        # input_text = f"質問:{doc['question']}<NL>文脈:{doc['context'].split('[SEP]')[-1].strip()}"
+        input_text = f"文脈:{doc['title']}\n{doc['context'].split('[SEP]')[-1].strip()}{self.SEP}質問:{doc['question']}"
         return f"ユーザー: {input_text}{self.SEP}システム: "
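For illustration, with this change the Fintan-style prompt (task jsquad-1.2-0.2) now renders the title on its own line. Using the JSQuAD validation example quoted in jsquad_v11.py below, a zero-shot request would look roughly like this (a sketch with the description header included, not actual harness output):

```
質問に対する回答を題名と文章から一言で抽出してください。回答は名詞で答えてください。

題名:ポルトガル
文章:正式名称はポルトガル語で、。通称、 。
質問:ポルトガルね正式名称は何語であるか
回答:
```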
275 changes: 275 additions & 0 deletions lm_eval/tasks/ja/jsquad_v11.py
@@ -0,0 +1,275 @@
"""
JGLUE: Japanese General Language Understanding Evaluation
https://aclanthology.org/2022.lrec-1.317/

JGLUE, Japanese General Language Understanding Evaluation, is built to measure the general NLU ability in Japanese.
JGLUE has been constructed from scratch without translation.

Homepage: https://github.com/yahoojapan/JGLUE
"""
import os
import inspect
import datasets
from math import exp
from lm_eval.base import rf, Task
from functools import partial
from lm_eval.jasquad import jasquad

_CITATION = """
@inproceedings{kurihara-etal-2022-jglue,
    title = "{JGLUE}: {J}apanese General Language Understanding Evaluation",
    author = "Kurihara, Kentaro and
      Kawahara, Daisuke and
      Shibata, Tomohide",
    booktitle = "Proceedings of the Thirteenth Language Resources and Evaluation Conference",
    month = jun,
    year = "2022",
    address = "Marseille, France",
    publisher = "European Language Resources Association",
    url = "https://aclanthology.org/2022.lrec-1.317",
    pages = "2957--2966",
    abstract = "To develop high-performance natural language understanding (NLU) models, it is necessary to have a benchmark to evaluate and analyze NLU ability from various perspectives. While the English NLU benchmark, GLUE, has been the forerunner, benchmarks are now being released for languages other than English, such as CLUE for Chinese and FLUE for French; but there is no such benchmark for Japanese. We build a Japanese NLU benchmark, JGLUE, from scratch without translation to measure the general NLU ability in Japanese. We hope that JGLUE will facilitate NLU research in Japanese.",
}
"""


DYNAMIC_MAX_LENGTH = os.getenv("DYNAMIC_MAX_LENGTH", "true").lower()
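# Note added for illustration: setting the environment variable
# DYNAMIC_MAX_LENGTH=false makes construct_requests() below skip the
# per-document token cap and fall back to plain greedy_until generation.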


class JSQuADV11(Task):
[Review comment] Is it possible to put this inside jsquad.py and inherit from the v1 class? Then, we can see the diffs easily.
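A minimal sketch of that suggestion (hypothetical, not part of this PR): since the 0.1 prompt itself did not change between v1.1 and v1.2, a snapshot class inside jsquad.py could override only the version metadata.

```python
# Hypothetical sketch of the reviewer's suggestion, for illustration only.
# Placed in jsquad.py so the v1.1/v1.2 differences stay visible as overrides.
class JSQuADV11(JSQuAD):
    VERSION = 1.1
    # prompt 0.1 is unchanged, so DESCRIPTION and doc_to_text are inherited
```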

"""
prompt template is taken from [日本語に特化した60億パラメータ規模のGPTモデルの構築と評価](https://www.anlp.jp/proceedings/annual_meeting/2023/pdf_dir/H9-4.pdf)
"""
VERSION = 1.1
PROMPT_VERSION = 0.1
DATASET_PATH = "shunk031/JGLUE"
DATASET_NAME = "JSQuAD"
LOAD_TOKENIZER = True
DESCRIPTION = "[題名]と[問題]から[質問]に対する[答え]を抜き出しなさい\n\n"
SEP = "\n"
REMOVE_IDS = []
# REMOVE_IDS = ['a10743p19q0', 'a10743p19q1', 'a10743p19q2', 'a10743p19q3', 'a13221p1q0', 'a13221p1q1', 'a13221p1q2', 'a13221p1q3', 'a14985p1q0', 'a14985p1q1', 'a14985p1q2', 'a14985p1q3', 'a14985p1q4', 'a14985p93q0', 'a14985p93q1', 'a14985p93q2', 'a14985p93q3', 'a14985p93q4', 'a1540503p36q0', 'a1540503p36q1', 'a1540503p36q2', 'a1540503p36q3', 'a1540503p36q4', 'a18783p1q0', 'a18783p3q0', 'a18783p3q1', 'a18783p3q2', 'a18783p8q0', 'a18873p25q0', 'a18873p25q1', 'a18873p25q2', 'a18873p25q3', 'a18873p26q0', 'a18873p26q1', 'a18873p26q2', 'a20898p10q0', 'a20898p15q0', 'a20898p15q1', 'a20898p15q2', 'a20898p15q3', 'a2164640p22q0', 'a2164640p22q1', 'a2164640p22q2', 'a2164640p22q3', 'a2164640p22q4', 'a22392p20q0', 'a22392p20q1', 'a22392p20q2', 'a22392p20q3', 'a3011628p3q0', 'a3011628p3q1', 'a3011628p3q2', 'a3011628p3q3', 'a3189p4q0', 'a3189p4q1', 'a3189p4q2', 'a369953p0q0', 'a369953p0q1', 'a369953p0q2', 'a369953p0q3', 'a3949p1q0', 'a3949p1q1', 'a4596p0q0', 'a4596p0q1', 'a4596p0q2', 'a4596p0q3', 'a4596p1q0', 'a4596p1q1', 'a4596p1q2', 'a4596p1q3', 'a4596p1q4', 'a4596p38q0', 'a4596p38q1', 'a4596p38q2', 'a4596p38q3', 'a4596p38q4', 'a4768p13q0', 'a4768p13q1', 'a4768p13q2', 'a4768p3q0', 'a4768p3q1', 'a4768p3q2', 'a4768p3q3', 'a4768p8q0', 'a4768p8q1', 'a4768p8q2', 'a51481p0q0', 'a51481p0q1', 'a51481p0q2', 'a51481p10q0', 'a51481p10q1', 'a51481p10q2', 'a51481p10q3', 'a51481p6q0', 'a51481p6q1', 'a51481p6q2', 'a51481p6q3', 'a51481p7q0', 'a51481p7q1', 'a67892p11q0', 'a67892p11q1', 'a67892p11q2', 'a67892p11q3', 'a67892p2q0', 'a8874p6q0', 'a8874p6q1', 'a916079p3q0', 'a916079p3q1', 'a95156p4q0', 'a95156p4q1', 'a95156p4q2', 'a95156p4q3', 'a95156p6q0', 'a95156p6q1', 'a95156p6q2', 'a95156p6q3']
"""
@mkshing's comment
I found that JSQuAD contains errors inside contexts such as below.
```
{'id': 'a4596p0q0', 'title': 'ポルトガル', 'context': 'ポルトガル [SEP] 正式名称はポルトガル語で、。通称、 。', 'question': 'ポルトガルね正式名称は何語であるか', 'answers': {'text': ['正式名称はポルトガル語', 'ポルトガル語', 'ポルトガル語'], 'answer_start': [12, 17, 17]}, 'is_impossible': False}
```
So, I tried to identify all of them and found that the following processing can be okay to detect the ids
```python
from datasets import load_dataset
from transformers import T5Tokenizer
dataset = load_dataset("shunk031/JGLUE", name="JSQuAD", split="validation")
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt-1b")
remove_ids = []
for item in dataset:
ctx = item["context"].split("[SEP]")[-1].strip()
input_ids = tokenizer.encode(ctx, add_special_tokens=False)
if len(input_ids) < 25:
print(item)
remove_ids.append(item["id"])
```
"""
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.jasquad_metric = datasets.load_metric(jasquad.__file__)

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        return self.dataset["train"]

    def validation_docs(self):
        dataset = self.dataset["validation"]
        if len(self.REMOVE_IDS) > 0:
            dataset = [item for item in dataset if item["id"] not in self.REMOVE_IDS]
        return dataset

    def doc_to_text(self, doc):
        return (
            "[題名]:"
            + doc["title"]
            + f"{self.SEP}"
            + "[問題]:"
            + doc["context"].split("[SEP]")[-1].strip()
            + f"{self.SEP}"
            + "[質問]:"
            + doc["question"]
            + f"{self.SEP}"
            + "[答え]:"
        )

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["context"]

    def doc_to_target(self, doc):
        answer_list = doc["answers"]["text"]
        answer = answer_list[0]
        return answer

    def construct_requests(self, doc, ctx):
        if DYNAMIC_MAX_LENGTH == "false" or not hasattr(self.tokenizer, "encode"):
            continuation = rf.greedy_until(ctx, [self.SEP])
        else:
            # cap generation length at the longest gold answer (in tokens)
            encode_fn = self.tokenizer.encode
            if "add_special_tokens" in inspect.getfullargspec(encode_fn).args:
                encode_params = dict(add_special_tokens=False)
            else:
                encode_params = {}
            max_num_tokens = max([len(encode_fn(answer, **encode_params)) for answer in doc["answers"]["text"]])
            continuation = rf.greedy_until(ctx, [self.SEP], max_num_tokens)
        return continuation

    def process_results(self, doc, results):
        assert len(results) == 1, f"results should be a list with 1 str element, but is {results}"
        continuation = results[0]
        predictions = {
            "id": doc["id"],
            "prediction_text": continuation,
        }

        references = {
            "id": doc["id"],
            "answers": doc["answers"],
        }
        return {
            "exact_match": (
                predictions,
                references,
            ),  # Exact match (the normalized answer exactly matches the gold answer)
            "f1": (
                predictions,
                references,
            ),  # The F-score of predicted tokens versus the gold answer
        }

    def aggregation(self):
        return {
            "exact_match": partial(
                self._squad_agg, "exact_match"
            ),  # Exact match (the normalized answer exactly matches the gold answer)
            "f1": partial(
                self._squad_agg, "f1"
            ),  # The F-score of predicted tokens versus the gold answer
        }

    def higher_is_better(self):
        return {
            "exact_match": True,  # Exact match (the normalized answer exactly matches the gold answer)
            "f1": True,  # The F-score of predicted tokens versus the gold answer
        }

    def _squad_metric(self, predictions, references):
        return self.jasquad_metric.compute(predictions=predictions, references=references)

    def _squad_agg(self, key, item):
        predictions, references = zip(*item)
        return self._squad_metric(predictions=predictions, references=references)[key]


class JSQuADV11WithFintanPrompt(JSQuADV11):
[Review comment] Is it possible to inherit from JSQuADV11 and JSQuADWithFintanPrompt to reduce duplicates?
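A rough sketch of that suggestion (hypothetical, not part of this PR). Because the updated JSQuADWithFintanPrompt now embeds the title, the title-less v1.1 prompt would still need an override; the multiple inheritance mainly saves the shared metadata:

```python
# Hypothetical sketch of the reviewer's suggestion, for illustration only.
from lm_eval.tasks.ja.jsquad import JSQuADWithFintanPrompt
from lm_eval.tasks.ja.jsquad_v11 import JSQuADV11

class JSQuADV11WithFintanPrompt(JSQuADWithFintanPrompt, JSQuADV11):
    VERSION = 1.1
    # restore the pre-#84 description that does not mention the title (題名)
    DESCRIPTION = "質問に対する回答を文章から一言で抽出してください。回答は名詞で答えてください。\n\n"

    def doc_to_text(self, doc):
        # title-less v1.1 prompt
        return (
            "文章:"
            + doc["context"].split("[SEP]")[-1].strip()
            + f"{self.SEP}"
            + "質問:"
            + doc["question"]
            + f"{self.SEP}"
            + "回答:"
        )
```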

"""
prompt template is taken from [ChatGPT vs BERT: どちらが日本語をより理解できるのか?](https://fintan.jp/page/9126/)
"""
PROMPT_VERSION = 0.2
DESCRIPTION = "質問に対する回答を文章から一言で抽出してください。回答は名詞で答えてください。\n\n"
SEP = "\n"
def doc_to_text(self, doc):
return (
"文章:"
+ doc["context"].split("[SEP]")[-1].strip()
+ f"{self.SEP}"
+ "質問:"
+ doc["question"]
+ f"{self.SEP}"
+ "回答:"
)


class JSQuADV11WithJAAlpacaPrompt(JSQuADV11):
[Review comment] Same as JSQuADV11WithFintanPrompt.

"""
This prompt format was inspired by the below data in fujiki/japanese_alpaca_data.
```
{
'instruction': '与えられた文脈に最も適した文を選択してください。',
'input': '文脈:あなたは親友と現在の仕事の状況について話しています。\nA)私にはあまり選択肢がありません。\nB)他に選択肢がありません。\nC)私には本当に決断する必要がありません。',
'output': 'A) 私には多くの選択肢がありません。'
}
```
Reference:
- data: https://huggingface.co/datasets/fujiki/japanese_alpaca_data
- code: https://github.com/Stability-AI/gpt-neox/blob/c130a4edc1120dccec8f02a34eb60d3e8f484cd3/finetune/finetune_base_ja.py#LL118C23-L127C11
"""
PROMPT_VERSION = 0.3
DESCRIPTION = "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。\n\n"
INSTRUCTION = "与えられた文脈から、質問に対する答えを抜き出してください。"
def doc_to_text(self, doc):
"""
以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。

### 指示:
{instruction}

### 入力:
{input}

### 応答:
{response}
"""
input_text = f"文脈:{doc['context'].split('[SEP]')[-1].strip()}\n質問:{doc['question']}"
return f"### 指示:\n{self.INSTRUCTION}\n\n### 入力:\n{input_text}\n\n### 応答:\n"


class JSQuADV11WithRinnaInstructionSFT(JSQuADV11):
[Review comment] Same as JSQuADV11WithFintanPrompt.

"""
Reference:
- HF Hub: https://huggingface.co/rinna/japanese-gpt-neox-3.6b-instruction-sft
"""
PROMPT_VERSION = 0.4
DESCRIPTION = "ユーザー: 与えられた文脈から、質問に対する答えを抜き出してください。<NL>システム: 分かりました。<NL>"
SEP = "<NL>"
FEWSHOT_SEP = "<NL>"

def doc_to_text(self, doc):
input_text = f"文脈:{doc['context'].split('[SEP]')[-1].strip()}{self.SEP}質問:{doc['question']}"
# input_text = f"質問:{doc['question']}<NL>文脈:{doc['context'].split('[SEP]')[-1].strip()}"
return f"ユーザー: {input_text}{self.SEP}システム: "


class JSQuADV11WithRinnaBilingualInstructionSFT(JSQuADV11WithRinnaInstructionSFT):
[Review comment] Same as JSQuADV11WithFintanPrompt.

"""
Reference:
- HF Hub: https://huggingface.co/rinna/bilingual-gpt-neox-4b-instruction-sft
"""
PROMPT_VERSION = 0.5
DESCRIPTION = "ユーザー: 与えられた文脈から、質問に対する答えを抜き出してください。\nシステム: 分かりました。\n"
SEP = "\n"
FEWSHOT_SEP = "\n"


VERSIONS = [
    JSQuADV11,
    JSQuADV11WithFintanPrompt,
    JSQuADV11WithJAAlpacaPrompt,
    JSQuADV11WithRinnaInstructionSFT,
    JSQuADV11WithRinnaBilingualInstructionSFT,
]


def construct_tasks():
    tasks = {}
    for version_class in VERSIONS:
        tasks[f"jsquad-{version_class.VERSION}-{version_class.PROMPT_VERSION}"] = version_class
    return tasks
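For reference, a quick sketch of what this registers (module path assumed from the file location):

```python
# Sketch: enumerate the task names this file registers.
from lm_eval.tasks.ja import jsquad_v11

tasks = jsquad_v11.construct_tasks()
print(sorted(tasks))
# Expected: ['jsquad-1.1-0.1', 'jsquad-1.1-0.2', 'jsquad-1.1-0.3',
#            'jsquad-1.1-0.4', 'jsquad-1.1-0.5']
# The title-aware prompts from the updated jsquad.py register as jsquad-1.2-* instead.
```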
3 changes: 3 additions & 0 deletions models/abeja-gpt-neox-japanese-2.7b/harness.jsquad-1.2.sh
@@ -0,0 +1,3 @@
MODEL_ARGS="pretrained=abeja/gpt-neox-japanese-2.7b,device_map=auto,torch_dtype=auto"
TASK="jsquad-1.2-0.2"
python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "3" --device "cuda" --output_path "models/abeja-gpt-neox-japanese-2.7b/result.jsquad-1.2.json"
2 changes: 1 addition & 1 deletion models/abeja-gpt-neox-japanese-2.7b/harness.sh
@@ -1,3 +1,3 @@
 MODEL_ARGS="pretrained=abeja/gpt-neox-japanese-2.7b"
-TASK="jcommonsenseqa-1.1-0.2,jnli-1.1-0.2,marc_ja-1.1-0.2,jsquad-1.1-0.2,xlsum_ja"
+TASK="jcommonsenseqa-1.1-0.2,jnli-1.1-0.2,marc_ja-1.1-0.2,jsquad-1.2-0.2,xlsum_ja"
 python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "2,3,3,3,1" --device "cuda" --output_path "models/abeja-gpt-neox-japanese-2.7b/result.json"
8 changes: 4 additions & 4 deletions models/abeja-gpt-neox-japanese-2.7b/result.json
@@ -18,9 +18,9 @@
             "acc_norm": 0.749912800837112,
             "acc_norm_stderr": 0.005719527388015089
         },
-        "jsquad-1.1-0.2": {
-            "exact_match": 13.665015758667266,
-            "f1": 22.909453892411364
+        "jsquad-1.2-0.2": {
+            "exact_match": 15.803692030616839,
+            "f1": 25.18326978234071
         },
         "xlsum_ja": {
             "rouge2": 6.149952794206885
@@ -33,7 +33,7 @@
     "versions": {
         "jcommonsenseqa-1.1-0.2": 1.1,
         "jnli-1.1-0.2": 1.1,
-        "jsquad-1.1-0.2": 1.1,
+        "jsquad-1.2-0.2": 1.2,
         "marc_ja-1.1-0.2": 1.1,
         "xlsum_ja": 1.0,
         "xwinograd_ja": 1.0
22 changes: 22 additions & 0 deletions models/abeja-gpt-neox-japanese-2.7b/result.jsquad-1.2.json
@@ -0,0 +1,22 @@
{
    "results": {
        "jsquad-1.2-0.2": {
            "exact_match": 15.803692030616839,
            "f1": 25.18326978234071
        }
    },
    "versions": {
        "jsquad-1.2-0.2": 1.2
    },
    "config": {
        "model": "hf-causal",
        "model_args": "pretrained=abeja/gpt-neox-japanese-2.7b,device_map=auto,torch_dtype=auto",
        "num_fewshot": 3,
        "batch_size": null,
        "device": "cuda",
        "no_cache": false,
        "limit": null,
        "bootstrap_iters": 100000,
        "description_dict": {}
    }
}
3 changes: 3 additions & 0 deletions models/cyberagent-open-calm-1b/harness.jsquad-1.2.sh
@@ -0,0 +1,3 @@
MODEL_ARGS="pretrained=cyberagent/open-calm-1b,device_map=auto,torch_dtype=auto"
TASK="jsquad-1.2-0.2"
python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "3" --device "cuda" --output_path "models/cyberagent-open-calm-1b/result.jsquad-1.2.json"