
Commit

Merge pull request EleutherAI#6 from cjlovering/master
Update with new PR
StellaAthena authored Apr 27, 2022
2 parents e06040c + 0b64f89 commit a4a472c
Showing 4 changed files with 219 additions and 62 deletions.
158 changes: 143 additions & 15 deletions lm_eval/base.py
@@ -654,11 +654,21 @@ class PromptSourceTask(Task):
*and* add additional custom processing, override `process_results`, `higher_is_better`, and `aggregation`.
"""

CONFIGURED_PS_METRICS = set(["Accuracy", "BLEU", "ROUGE"])

def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=None):
CONFIGURED_RANKED_CHOICE_PS_METRICS = set(["Accuracy"])
CONFIGURED_GENERATION_PS_METRICS = set(["BLEU", "ROUGE"])
SPLIT = None

def __init__(
self,
data_dir=None,
cache_dir=None,
download_mode=None,
prompt=None,
save_examples=True,
):
super().__init__(data_dir, cache_dir, download_mode)
self.prompt = prompt
self.save_examples = save_examples

def stopping_criteria(self) -> Optional[str]:
"""Denote where the generation should end.
@@ -752,24 +762,23 @@ def process_results(self, doc, results):

for metric in self.prompt.metadata.metrics:
assert (
metric in self.CONFIGURED_PS_METRICS
metric in self.CONFIGURED_RANKED_CHOICE_PS_METRICS
), "Unexpected metric. Add it, or use a task-specific solution."
if metric == "Accuracy":
out["acc"] = pred == target
# TODO: Add metrics here.
return out
else:
# If not, then this is a generation prompt.
# NOTE: In the future, target will be a list of strings.
pred = results[0].strip()
out = {}
for metric in self.prompt.metadata.metrics:
assert (
metric in self.CONFIGURED_PS_METRICS
metric in self.CONFIGURED_GENERATION_PS_METRICS
), "Unexpected metric. Add it, or use a task-specific solution."
if metric == "BLEU":
out["bleu"] = (target, pred)
if metric == "ROUGE":
elif metric == "ROUGE":
# TODO: This computes all rouge sub-metrics. Find a generic
# way to handle user specified rouge sub-metrics to avoid extra
# compute.
@@ -778,15 +787,21 @@ def process_results(self, doc, results):
rouge_scores = utils.flatten(rouge_scores)
# Merge all the rouge-type scores into the `out` dict.
out = {**out, **rouge_scores}
print(out)
return out

        # TODO: Wrap process_results so that overriding implementations
        # do not bypass the example saving below.
if self.save_examples:
example = {
"pred": pred,
"target": target,
"answer_choices_list": answer_choices_list,
}
return out, example
return out

def higher_is_better(self):
out = {}
for metric in self.prompt.metadata.metrics:
assert (
metric in self.CONFIGURED_PS_METRICS
), "Unexpected metric. Add it, or use a task-specific solution."
if metric == "Accuracy":
out["acc"] = True
if metric == "BLEU":
@@ -813,9 +828,6 @@ def higher_is_better(self):
def aggregation(self):
out = {}
for metric in self.prompt.metadata.metrics:
assert (
metric in self.CONFIGURED_PS_METRICS
), "Unexpected metric. Add it, or use a task-specific solution."
if metric == "Accuracy":
out["acc"] = mean
if metric == "BLEU":
@@ -839,6 +851,122 @@ def aggregation(self):
out["rougeLsum_fmeasure"] = mean
return out

def fewshot_examples(self, k, rnd):
if self._training_docs is None:
self._training_docs = list(self.training_docs())
return self._get_fewshot_examples(self._training_docs, k, rnd)

def _get_fewshot_examples(self, docs, k, rnd):
fewshot_idx = rnd.sample(list(np.arange(len(docs))), k)
return [docs[idx] for idx in fewshot_idx], [int(idx) for idx in fewshot_idx]

@utils.positional_deprecated
def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
:param doc: str
The document as returned from training_docs, validation_docs, or test_docs.
:param num_fewshot: int
The number of fewshot examples to provide in the returned context string.
:param provide_description: bool
            Not implemented. Deprecated and will be removed in a future version in favor of a different way of providing a description.
:param rnd: random.Random
The pseudo-random number generator used to randomly sample examples.
            WARNING: This is currently a required arg even though it defaults to `None`.
:param description: str
The task's description that will be prepended to the fewshot examples.
        :returns: Tuple[str, dict]
            The fewshot context string, plus a dict of logging info about the
            sampled fewshot examples (indices, source split, shot count, and the context).
"""
assert (
rnd is not None
), "A `random.Random` generator argument must be provided to `rnd`"
assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the "
"`description` arg."
)
if provide_description is not None:
# nudge people to not specify it at all
print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)

description = description + "\n\n" if description else ""

if num_fewshot == 0:
labeled_examples = ""
fewshotex, fewshotidx, fewshotsource = [], [], None
else:
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
if self.has_training_docs():
fewshotex, fewshotidx = self.fewshot_examples(k=num_fewshot, rnd=rnd)
fewshotsource = "train"
else:
if self._fewshot_docs is None:
self._fewshot_docs = list(
self.validation_docs()
if self.has_validation_docs()
else self.test_docs()
)
if self.has_validation_docs():
fewshotsource = "val"
                elif self.has_test_docs():
fewshotsource = "test"

fewshotex, fewshotidx = self._get_fewshot_examples(
self._fewshot_docs, k=num_fewshot + 1, rnd=rnd
)
                # Drop the doc we're currently evaluating if it was sampled
                # into the fewshot examples.
                kept = [
                    (shot, idx)
                    for shot, idx in zip(fewshotex, fewshotidx)
                    if shot != doc
                ]
                fewshotex = [shot for shot, _ in kept]
                fewshotidx = [idx for _, idx in kept]
fewshotex, fewshotidx = (
fewshotex[:num_fewshot],
fewshotidx[:num_fewshot],
)

labeled_examples = (
"\n\n".join(
[
self.doc_to_text(doc) + self.doc_to_target(doc)
for doc in fewshotex
]
)
+ "\n\n"
)

example = self.doc_to_text(doc)
ctx = description + labeled_examples + example
return (
ctx,
{
"fewshot_idx": fewshotidx,
"fewshot_source": fewshotsource,
"fewshot_num": num_fewshot,
"ctx": ctx,
},
)

def get_logging_info(self):
return {
"fixed_answer_choice_list": self.prompt.get_fixed_answer_choices_list(),
"dataset_path": self.DATASET_PATH,
"dataset_name": self.DATASET_NAME,
"subset": self.SPLIT,
"prompt_name": self.prompt.get_name(),
"prompt_id": self.prompt.get_id(),
"prompt_jinja": self.prompt.jinja,
"prompt_original_task": self.prompt.metadata.original_task,
# Placeholder for comment in post-processing.
"comment": "",
}


class MultipleChoiceTask(Task):
def doc_to_target(self, doc):
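A minimal usage sketch of the updated PromptSourceTask contract introduced above (not part of the diff): with save_examples=True, process_results returns a (metrics, example) pair, and fewshot_context returns the context string together with a logging dict. MyPromptTask, my_prompt, and the stand-in model output below are hypothetical placeholders.

import random

task = MyPromptTask(prompt=my_prompt, save_examples=True)  # hypothetical task subclass and prompt
rnd = random.Random(42)

doc = next(iter(task.validation_docs()))
ctx, fewshot_info = task.fewshot_context(doc=doc, num_fewshot=2, rnd=rnd)
# fewshot_info carries "fewshot_idx", "fewshot_source", "fewshot_num", and "ctx".

results = ["model output"]  # stand-in for the LM responses for this doc
output = task.process_results(doc, results)
if task.save_examples:
    metrics, example = output
    example.update(fewshot_info)
    example.update(task.get_logging_info())  # prompt name/id/jinja, dataset path, etc.
else:
    metrics = output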
67 changes: 50 additions & 17 deletions lm_eval/evaluator.py
@@ -173,10 +173,6 @@ def evaluate(

# get lists of each type of request
for task_prompt_name, task in task_dict_items:
# if task.is_generation_task():
# print(f"WARNING: Skipping generation prompt {task.prompt.name}.")
# continue

versions[task_prompt_name] = task.VERSION
# default to test doc, fall back to val doc if validation unavailable
# TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
@@ -188,7 +184,7 @@ def evaluate(
raise RuntimeError("Task has neither test_docs nor validation_docs")

# deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
task_docs = list(task_doc_func())
task_docs = list(enumerate(list(task_doc_func())))
rnd = random.Random()
rnd.seed(42)
rnd.shuffle(task_docs)
@@ -199,14 +195,17 @@ def evaluate(
else ""
)

for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
for doc_id, (original_doc_id, doc) in enumerate(
itertools.islice(task_docs, 0, limit)
):
if task.invalid_doc_for_prompt(doc):
continue

docs[(task_prompt_name, doc_id)] = doc
ctx = task.fewshot_context(
ctx, fewshotex_logging_info = task.fewshot_context(
doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
)
fewshotex_logging_info["doc_id"] = original_doc_id
reqs = task.construct_requests(doc, ctx)
if not isinstance(reqs, (list, tuple)):
reqs = [reqs]
@@ -215,7 +214,7 @@ def evaluate(
# i: index in requests for a single task instance
# doc_id: unique id that we can get back to a doc using `docs`
requests_origin[req.request_type].append(
(i, task_prompt_name, doc, doc_id)
(i, task_prompt_name, doc, doc_id, fewshotex_logging_info)
)

# all responses for each (task, doc)
@@ -234,33 +233,57 @@ def evaluate(
x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
]

for resp, (i, task_prompt_name, doc, doc_id) in zip(
for resp, (i, task_prompt_name, doc, doc_id, fewshotex_logging_info) in zip(
resps, requests_origin[reqtype]
):
process_res_queue[(task_prompt_name, doc_id)].append((i, resp))
process_res_queue[(task_prompt_name, doc_id)].append(
(i, resp, fewshotex_logging_info)
)

vals = collections.defaultdict(list)

# unpack results and sort back in order and return control to Task
for (task_prompt_name, doc_id), requests in process_res_queue.items():
requests.sort(key=lambda x: x[0])
requests = [x[1] for x in requests]
examples = []
for (task_prompt_name, doc_id), per_doc_requests in process_res_queue.items():
per_doc_requests.sort(key=lambda x: x[0])
per_doc_results = [x[1] for x in per_doc_requests]
        # All requests for a given doc share the same fewshot logging info.
        fewshot_logging_info = per_doc_requests[0][2]

task = task_dict[task_prompt_name]
doc = docs[(task_prompt_name, doc_id)]

metrics = task.process_results(doc, requests)
output = task.process_results(doc, per_doc_results)
if task.save_examples:
metrics, example = output
example.update(fewshot_logging_info)
example.update(task.get_logging_info())
examples.append(example)
else:
metrics = output
example = fewshot_logging_info
example.update(task.get_logging_info())
examples.append(example)

for metric, value in metrics.items():
vals[(task_prompt_name, metric)].append(value)

# aggregate results
metric_results = []
for (task_prompt_name, metric), items in vals.items():
task_name, prompt_name = task_prompt_name.split("+")

results[task_prompt_name]["task_name"] = task_name
results[task_prompt_name]["prompt_name"] = prompt_name
task = task_dict[task_prompt_name]
results[task_prompt_name][metric] = task.aggregation()[metric](items)

_metric_results = {
"task_name": task_name,
"prompt_name": prompt_name,
metric: task.aggregation()[metric](items),
**task.get_logging_info(),
}

# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them less iterations. still looking for a cleaner way to do this
stderr = lm_eval.metrics.stderr_for_metric(
@@ -271,8 +294,18 @@ def evaluate(
)
if stderr is not None:
results[task_prompt_name][metric + "_stderr"] = stderr(items)

return {"results": dict(results), "versions": dict(versions)}
_metric_results[metric + "_stderr"] = stderr(items)
metric_results.append(_metric_results)

return {
# List of results that tracks the averages per model and prompt.
"results": metric_results,
"versions": dict(versions),
# List of all prompt x doc examples with additional information in it.
"examples": examples,
# Original results used for generating the table when running this file.
"table_results": dict(results),
}


def make_table(result_dict):
@@ -293,7 +326,7 @@ def make_table(result_dict):
]

values = []
for k, dic in result_dict["results"].items():
for k, dic in result_dict["table_results"].items():
version = result_dict["versions"][k]
for m, v in dic.items():
if m.endswith("_stderr"):
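A minimal sketch of consuming the new evaluate() return value (not part of the diff; the evaluate(...) arguments are elided placeholders): the dict now exposes a flat "results" list, per-document "examples", and the original nested dict under "table_results" for make_table.

import json

output = evaluate(lm=lm, task_dict=task_dict, num_fewshot=0, limit=None)  # placeholder args

# Flat list of per-(task, prompt, metric) aggregates, annotated with prompt metadata.
for row in output["results"]:
    print(row["task_name"], row["prompt_name"])

# One record per (prompt, doc): prediction, target, fewshot indices, prompt info.
with open("examples.jsonl", "w") as f:
    for example in output["examples"]:
        f.write(json.dumps(example) + "\n")

# The original nested results dict still drives the summary table.
print(make_table(output))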
19 changes: 6 additions & 13 deletions lm_eval/tasks/coqa.py
@@ -118,25 +118,18 @@ def process_results(self, doc, results):
"""
target = self.doc_to_target(doc).strip()
pred = results[0].strip().split("\n")[0]
print("*" * 80)
print(f"DOC: {doc}")
# print(f"PS: {self.prompt.apply(doc)}")
print(f"TEXT: {self.doc_to_text(doc)}")
print(f"TARGET: {target} END TARGET")
print(f"PRED: {pred} END PRED")
print("*" * 80)

# turn_id = len(doc["questions"]["input_text"])
# gold_list = self.get_answers(doc, turn_id)

# TODO: Add HF metrics mapped from promptsource metadata.
scores = self.compute_scores([target], pred)

return {
out = {
"f1": scores["f1"],
"em": scores["em"],
}

if self.save_examples:
example = {"target": target, "pred": pred}
return out, example
return out

def higher_is_better(self):
return {
"f1": True,
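The CoQA task now follows the same save_examples contract as PromptSourceTask; a minimal sketch of the unpacking a caller needs (coqa_task, doc, and results are placeholders, not part of the diff):

output = coqa_task.process_results(doc, results)
if coqa_task.save_examples:
    metrics, example = output  # example == {"target": ..., "pred": ...}
else:
    metrics = output           # {"f1": ..., "em": ...}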
