Compare results between JSQuAD prompt with title and without title #84
Changes from 3 commits
Commits: 8fa63b7, 11d8c89, 7e9e1ab, e920f6d, 6b2fb51, e1444bf, 359b436, bfb439e
@@ -0,0 +1,275 @@
"""
JGLUE: Japanese General Language Understanding Evaluation
https://aclanthology.org/2022.lrec-1.317/

JGLUE, Japanese General Language Understanding Evaluation, is built to measure the general NLU ability in Japanese.
JGLUE has been constructed from scratch without translation.

Homepage: https://github.com/yahoojapan/JGLUE
"""
import os
import inspect
import datasets
from math import exp
from lm_eval.base import rf, Task
from functools import partial
from lm_eval.jasquad import jasquad

_CITATION = """
@inproceedings{kurihara-etal-2022-jglue,
    title = "{JGLUE}: {J}apanese General Language Understanding Evaluation",
    author = "Kurihara, Kentaro and
      Kawahara, Daisuke and
      Shibata, Tomohide",
    booktitle = "Proceedings of the Thirteenth Language Resources and Evaluation Conference",
    month = jun,
    year = "2022",
    address = "Marseille, France",
    publisher = "European Language Resources Association",
    url = "https://aclanthology.org/2022.lrec-1.317",
    pages = "2957--2966",
    abstract = "To develop high-performance natural language understanding (NLU) models, it is necessary to have a benchmark to evaluate and analyze NLU ability from various perspectives. While the English NLU benchmark, GLUE, has been the forerunner, benchmarks are now being released for languages other than English, such as CLUE for Chinese and FLUE for French; but there is no such benchmark for Japanese. We build a Japanese NLU benchmark, JGLUE, from scratch without translation to measure the general NLU ability in Japanese. We hope that JGLUE will facilitate NLU research in Japanese.",
}
"""


DYNAMIC_MAX_LENGTH = os.getenv("DYNAMIC_MAX_LENGTH", "true").lower()


class JSQuADV11(Task):
    """
    prompt template is taken from [日本語に特化した60億パラメータ規模のGPTモデルの構築と評価](https://www.anlp.jp/proceedings/annual_meeting/2023/pdf_dir/H9-4.pdf)
    """
    VERSION = 1.1
    PROMPT_VERSION = 0.1
    DATASET_PATH = "shunk031/JGLUE"
    DATASET_NAME = "JSQuAD"
    LOAD_TOKENIZER = True
    DESCRIPTION = "[題名]と[問題]から[質問]に対する[答え]を抜き出しなさい\n\n"
    SEP = "\n"
    REMOVE_IDS = []
    # REMOVE_IDS = ['a10743p19q0', 'a10743p19q1', 'a10743p19q2', 'a10743p19q3', 'a13221p1q0', 'a13221p1q1', 'a13221p1q2', 'a13221p1q3', 'a14985p1q0', 'a14985p1q1', 'a14985p1q2', 'a14985p1q3', 'a14985p1q4', 'a14985p93q0', 'a14985p93q1', 'a14985p93q2', 'a14985p93q3', 'a14985p93q4', 'a1540503p36q0', 'a1540503p36q1', 'a1540503p36q2', 'a1540503p36q3', 'a1540503p36q4', 'a18783p1q0', 'a18783p3q0', 'a18783p3q1', 'a18783p3q2', 'a18783p8q0', 'a18873p25q0', 'a18873p25q1', 'a18873p25q2', 'a18873p25q3', 'a18873p26q0', 'a18873p26q1', 'a18873p26q2', 'a20898p10q0', 'a20898p15q0', 'a20898p15q1', 'a20898p15q2', 'a20898p15q3', 'a2164640p22q0', 'a2164640p22q1', 'a2164640p22q2', 'a2164640p22q3', 'a2164640p22q4', 'a22392p20q0', 'a22392p20q1', 'a22392p20q2', 'a22392p20q3', 'a3011628p3q0', 'a3011628p3q1', 'a3011628p3q2', 'a3011628p3q3', 'a3189p4q0', 'a3189p4q1', 'a3189p4q2', 'a369953p0q0', 'a369953p0q1', 'a369953p0q2', 'a369953p0q3', 'a3949p1q0', 'a3949p1q1', 'a4596p0q0', 'a4596p0q1', 'a4596p0q2', 'a4596p0q3', 'a4596p1q0', 'a4596p1q1', 'a4596p1q2', 'a4596p1q3', 'a4596p1q4', 'a4596p38q0', 'a4596p38q1', 'a4596p38q2', 'a4596p38q3', 'a4596p38q4', 'a4768p13q0', 'a4768p13q1', 'a4768p13q2', 'a4768p3q0', 'a4768p3q1', 'a4768p3q2', 'a4768p3q3', 'a4768p8q0', 'a4768p8q1', 'a4768p8q2', 'a51481p0q0', 'a51481p0q1', 'a51481p0q2', 'a51481p10q0', 'a51481p10q1', 'a51481p10q2', 'a51481p10q3', 'a51481p6q0', 'a51481p6q1', 'a51481p6q2', 'a51481p6q3', 'a51481p7q0', 'a51481p7q1', 'a67892p11q0', 'a67892p11q1', 'a67892p11q2', 'a67892p11q3', 'a67892p2q0', 'a8874p6q0', 'a8874p6q1', 'a916079p3q0', 'a916079p3q1', 'a95156p4q0', 'a95156p4q1', 'a95156p4q2', 'a95156p4q3', 'a95156p6q0', 'a95156p6q1', 'a95156p6q2', 'a95156p6q3']
""" | ||
@mkshing's comment | ||
I found that JSQuAD contains errors inside contexts such as below. | ||
``` | ||
{'id': 'a4596p0q0', 'title': 'ポルトガル', 'context': 'ポルトガル [SEP] 正式名称はポルトガル語で、。通称、 。', 'question': 'ポルトガルね正式名称は何語であるか', 'answers': {'text': ['正式名称はポルトガル語', 'ポルトガル語', 'ポルトガル語'], 'answer_start': [12, 17, 17]}, 'is_impossible': False} | ||
``` | ||
So, I tried to identify all of them and found that the following processing can be okay to detect the ids | ||
```python | ||
from datasets import load_dataset | ||
from transformers import T5Tokenizer | ||
dataset = load_dataset("shunk031/JGLUE", name="JSQuAD", split="validation") | ||
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt-1b") | ||
remove_ids = [] | ||
for item in dataset: | ||
ctx = item["context"].split("[SEP]")[-1].strip() | ||
input_ids = tokenizer.encode(ctx, add_special_tokens=False) | ||
if len(input_ids) < 25: | ||
print(item) | ||
remove_ids.append(item["id"]) | ||
``` | ||
""" | ||
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.jasquad_metric = datasets.load_metric(jasquad.__file__)

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        return self.dataset["train"]

    def validation_docs(self):
        dataset = self.dataset["validation"]
        if len(self.REMOVE_IDS) > 0:
            dataset = [item for item in dataset if item["id"] not in self.REMOVE_IDS]
        return dataset

    def doc_to_text(self, doc):
        return (
            "[題名]:"
            + doc["title"]
            + f"{self.SEP}"
            + "[問題]:"
            + doc["context"].split("[SEP]")[-1].strip()
            + f"{self.SEP}"
            + "[質問]:"
            + doc["question"]
            + f"{self.SEP}"
            + "[答え]:"
        )

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["context"]

    def doc_to_target(self, doc):
        answer_list = doc["answers"]["text"]
        answer = answer_list[0]
        return answer

    def construct_requests(self, doc, ctx):
        if DYNAMIC_MAX_LENGTH == "false" or not hasattr(self.tokenizer, "encode"):
            continuation = rf.greedy_until(ctx, [self.SEP])
        else:
            # Cap generation at the token length of the longest gold answer.
            encode_fn = self.tokenizer.encode
            if "add_special_tokens" in inspect.getfullargspec(encode_fn).args:
                encode_params = dict(add_special_tokens=False)
            else:
                encode_params = {}
            max_num_tokens = max([len(encode_fn(answer, **encode_params)) for answer in doc["answers"]["text"]])
            continuation = rf.greedy_until(ctx, [self.SEP], max_num_tokens)
        return continuation

    def process_results(self, doc, results):
        assert len(results) == 1, f"results should be a list with 1 str element, but is {results}"
        continuation = results[0]
        predictions = {
            "id": doc["id"],
            "prediction_text": continuation,
        }

        references = {
            "id": doc["id"],
            "answers": doc["answers"],
        }
        return {
            "exact_match": (
                predictions,
                references,
            ),  # Exact match (the normalized answer exactly matches the gold answer)
            "f1": (
                predictions,
                references,
            ),  # The F-score of predicted tokens versus the gold answer
        }

    def aggregation(self):
        return {
            "exact_match": partial(
                self._squad_agg, "exact_match"
            ),  # Exact match (the normalized answer exactly matches the gold answer)
            "f1": partial(
                self._squad_agg, "f1"
            ),  # The F-score of predicted tokens versus the gold answer
        }

    def higher_is_better(self):
        return {
            "exact_match": True,  # Exact match (the normalized answer exactly matches the gold answer)
            "f1": True,  # The F-score of predicted tokens versus the gold answer
        }

    def _squad_metric(self, predictions, references):
        return self.jasquad_metric.compute(predictions=predictions, references=references)

    def _squad_agg(self, key, item):
        predictions, references = zip(*item)
        return self._squad_metric(predictions=predictions, references=references)[key]


class JSQuADV11WithFintanPrompt(JSQuADV11):
Review comment: Is it possible to inherit from …
""" | ||
prompt template is taken from [ChatGPT vs BERT: どちらが日本語をより理解できるのか?](https://fintan.jp/page/9126/) | ||
""" | ||
PROMPT_VERSION = 0.2 | ||
DESCRIPTION = "質問に対する回答を文章から一言で抽出してください。回答は名詞で答えてください。\n\n" | ||
SEP = "\n" | ||
def doc_to_text(self, doc): | ||
return ( | ||
"文章:" | ||
+ doc["context"].split("[SEP]")[-1].strip() | ||
+ f"{self.SEP}" | ||
+ "質問:" | ||
+ doc["question"] | ||
+ f"{self.SEP}" | ||
+ "回答:" | ||
) | ||
|
||
|
||
class JSQuADV11WithJAAlpacaPrompt(JSQuADV11): | ||
Review comment: Same as …
""" | ||
This prompt format was inspired by the below data in fujiki/japanese_alpaca_data. | ||
``` | ||
{ | ||
'instruction': '与えられた文脈に最も適した文を選択してください。', | ||
'input': '文脈:あなたは親友と現在の仕事の状況について話しています。\nA)私にはあまり選択肢がありません。\nB)他に選択肢がありません。\nC)私には本当に決断する必要がありません。', | ||
'output': 'A) 私には多くの選択肢がありません。' | ||
} | ||
``` | ||
Reference: | ||
- data: https://huggingface.co/datasets/fujiki/japanese_alpaca_data | ||
- code: https://github.com/Stability-AI/gpt-neox/blob/c130a4edc1120dccec8f02a34eb60d3e8f484cd3/finetune/finetune_base_ja.py#LL118C23-L127C11 | ||
""" | ||
PROMPT_VERSION = 0.3 | ||
DESCRIPTION = "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。\n\n" | ||
INSTRUCTION = "与えられた文脈から、質問に対する答えを抜き出してください。" | ||
def doc_to_text(self, doc): | ||
""" | ||
以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。 | ||
|
||
### 指示: | ||
{instruction} | ||
|
||
### 入力: | ||
{input} | ||
|
||
### 応答: | ||
{response} | ||
""" | ||
input_text = f"文脈:{doc['context'].split('[SEP]')[-1].strip()}\n質問:{doc['question']}" | ||
return f"### 指示:\n{self.INSTRUCTION}\n\n### 入力:\n{input_text}\n\n### 応答:\n" | ||
|
||
|
||
class JSQuADV11WithRinnaInstructionSFT(JSQuADV11): | ||
Review comment: Same as …
""" | ||
Reference: | ||
- HF Hub: https://huggingface.co/rinna/japanese-gpt-neox-3.6b-instruction-sft | ||
""" | ||
PROMPT_VERSION = 0.4 | ||
DESCRIPTION = "ユーザー: 与えられた文脈から、質問に対する答えを抜き出してください。<NL>システム: 分かりました。<NL>" | ||
SEP = "<NL>" | ||
FEWSHOT_SEP = "<NL>" | ||
|
||
def doc_to_text(self, doc): | ||
input_text = f"文脈:{doc['context'].split('[SEP]')[-1].strip()}{self.SEP}質問:{doc['question']}" | ||
# input_text = f"質問:{doc['question']}<NL>文脈:{doc['context'].split('[SEP]')[-1].strip()}" | ||
return f"ユーザー: {input_text}{self.SEP}システム: " | ||
|
||
|
||
class JSQuADV11WithRinnaBilingualInstructionSFT(JSQuADV11WithRinnaInstructionSFT): | ||
Review comment: Same as …
""" | ||
Reference: | ||
- HF Hub: https://huggingface.co/rinna/bilingual-gpt-neox-4b-instruction-sft | ||
""" | ||
PROMPT_VERSION = 0.5 | ||
DESCRIPTION = "ユーザー: 与えられた文脈から、質問に対する答えを抜き出してください。\nシステム: 分かりました。\n" | ||
SEP = "\n" | ||
FEWSHOT_SEP = "\n" | ||
|
||
|
||
VERSIONS = [ | ||
JSQuADV11, | ||
JSQuADV11WithFintanPrompt, | ||
JSQuADV11WithJAAlpacaPrompt, | ||
JSQuADV11WithRinnaInstructionSFT, | ||
JSQuADV11WithRinnaBilingualInstructionSFT, | ||
] | ||
|
||
|
||
def construct_tasks(): | ||
tasks = {} | ||
for version_class in VERSIONS: | ||
tasks[f"jsquad-{version_class.VERSION}-{version_class.PROMPT_VERSION}"] = version_class | ||
return tasks |
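A note on naming: each class is registered under `jsquad-{VERSION}-{PROMPT_VERSION}`, which is where task strings such as `jsquad-1.1-0.2` (and, presumably after the version bump in later commits of this PR, `jsquad-1.2-0.2`) in the run scripts below come from. A minimal sketch, assuming the module above has been imported:

```python
# Sketch (not part of this diff): how the registered task names are formed.
# With VERSION = 1.1 and the PROMPT_VERSIONs 0.1-0.5 defined above, this prints:
# ['jsquad-1.1-0.1', 'jsquad-1.1-0.2', 'jsquad-1.1-0.3', 'jsquad-1.1-0.4', 'jsquad-1.1-0.5']
print(sorted(construct_tasks().keys()))
```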
@@ -0,0 +1,3 @@
MODEL_ARGS="pretrained=abeja/gpt-neox-japanese-2.7b,device_map=auto,torch_dtype=auto" | ||
TASK="jsquad-1.2-0.2" | ||
python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "3" --device "cuda" --output_path "models/abeja-gpt-neox-japanese-2.7b/result.jsquad-1.2.json" |
@@ -1,3 +1,3 @@
MODEL_ARGS="pretrained=abeja/gpt-neox-japanese-2.7b" | ||
TASK="jcommonsenseqa-1.1-0.2,jnli-1.1-0.2,marc_ja-1.1-0.2,jsquad-1.1-0.2,xlsum_ja" | ||
TASK="jcommonsenseqa-1.1-0.2,jnli-1.1-0.2,marc_ja-1.1-0.2,jsquad-1.2-0.2,xlsum_ja" | ||
python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "2,3,3,3,1" --device "cuda" --output_path "models/abeja-gpt-neox-japanese-2.7b/result.json" |
@@ -0,0 +1,22 @@
{
    "results": {
        "jsquad-1.2-0.2": {
            "exact_match": 15.803692030616839,
            "f1": 25.18326978234071
        }
    },
    "versions": {
        "jsquad-1.2-0.2": 1.2
    },
    "config": {
        "model": "hf-causal",
        "model_args": "pretrained=abeja/gpt-neox-japanese-2.7b,device_map=auto,torch_dtype=auto",
        "num_fewshot": 3,
        "batch_size": null,
        "device": "cuda",
        "no_cache": false,
        "limit": null,
        "bootstrap_iters": 100000,
        "description_dict": {}
    }
}
@@ -0,0 +1,3 @@
MODEL_ARGS="pretrained=cyberagent/open-calm-1b,device_map=auto,torch_dtype=auto" | ||
TASK="jsquad-1.2-0.2" | ||
python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "3" --device "cuda" --output_path "models/cyberagent-open-calm-1b/result.jsquad-1.2.json" |
Review comment: Is it possible to put this inside jsquad.py and inherit from the v1 class? Then we can see the diffs easily.
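To illustrate the suggestion, a minimal sketch; the v1.2 class name, the import path, and what it overrides are assumptions for illustration, not code from this PR:

```python
# Hypothetical sketch of the reviewer's suggestion: keep the v1.1 classes in
# jsquad.py and express a version bump as a small subclass, so the diff
# between versions stays easy to read. JSQuADV12 is an illustrative name.
from lm_eval.tasks.ja.jsquad import JSQuADV11  # assumed import path

class JSQuADV12(JSQuADV11):
    VERSION = 1.2
    # Override only what changes in v1.2, e.g. dropping the broken
    # validation contexts found with the tokenizer-length heuristic above.
    REMOVE_IDS = ['a4596p0q0', 'a4596p0q1']  # truncated example; full list above
```

Registered via construct_tasks, such a subclass would surface as jsquad-1.2-*, matching the task names already used in the run scripts above.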