diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7f7960316..46d2d0c49 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Added support for Grouped Query Attention.
 - Added commonsense_qa and social_iqa downstream evaluation tasks
+- Added ce_loss metric, with TriviaQA and NaturalQuestions tasks
 - Makes it possible to read from http/https the same way we read from s3/r2.
 - Added MMLU multiple choice (A/B/C/D) 5-shot variant downstream tasks
 - Tokenizer patch
diff --git a/olmo/eval/downstream.py b/olmo/eval/downstream.py
index 09df95de0..12b5a68e9 100644
--- a/olmo/eval/downstream.py
+++ b/olmo/eval/downstream.py
@@ -19,7 +19,7 @@ class ICLMetric(Metric):
     full_state_update: bool = False
 
     def __init__(self, metric_type="acc") -> None:
-        """metric_type: f1, acc, len_norm, pmi_dc"""
+        """metric_type: f1, acc, len_norm, pmi_dc, ce_loss"""
         super().__init__(sync_on_compute=True)
 
         self.metric_type = metric_type
@@ -65,10 +65,12 @@ def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=No
             elif self.metric_type == "acc" or self.metric_type == "f1":
                 # gather log-probs at continuation token indices
                 log_likelihood = torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
-            elif self.metric_type == "len_norm":
+            elif self.metric_type == "len_norm" or self.metric_type == "ce_loss":
                 log_likelihood = (
                     torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum() / batch["cont_str_len"][idx]
                 )
+                if self.metric_type == "ce_loss":
+                    log_likelihood = -log_likelihood
             else:
                 raise ValueError(self.metric_type)
 
@@ -123,8 +125,10 @@ def compute(self) -> torch.Tensor:
 
             if skip_document:
                 continue
-
-            correct.append(1.0 if torch.argmax(loglikelihoods).item() == label_dict[doc_id] else 0.0)
+            if self.metric_type == "ce_loss":
+                correct.append(loglikelihoods[0])  # Only one answer is scored
+            else:
+                correct.append(1.0 if torch.argmax(loglikelihoods).item() == label_dict[doc_id] else 0.0)
 
             if self.metric_type == "f1":
                 assert preds is not None
@@ -754,6 +758,20 @@ def __init__(self, tokenizer, dataset_path="ai2_arc", dataset_name="ARC-Challeng
         )
 
 
+class ArcEasyCELoss(ArcEasy):
+    """ArcEasyCELoss is ARCEasy using an alternate ce_loss metric"""
+
+    metric_type = "ce_loss"
+
+    def doc_to_continuations(self, doc):
+        # We only consider the correct answer for this metric
+        answer = doc["choices"]["text"][super().doc_to_label(doc)]
+        return [" " + answer]
+
+    def doc_to_label(self, doc):
+        return 0
+
+
 class BasicArithmetic(ArcEasy):
     """This is a basic arithmetic task follows the same prompt format as ArcEasy.
     Example:
@@ -1250,6 +1268,77 @@ def doc_to_domain_conditional(self, doc):
         return "Answer:"
 
 
+class TriviaQACELoss(ICLMultiChoiceTaskDataset):
+    """Sample TriviaQA entity with some fields suppressed. For CE Loss we only consider the "value"
+    field as the answer to score.
+
+    {
+        'question': 'Which Lloyd Webber musical premiered in the US on 10th December 1993?',
+        'question_id': 'tc_33',
+        'answer': {
+            'aliases': ['Sunset Blvd', ...],
+            'normalized_aliases': ['sunset boulevard', ...],
+            'normalized_value': 'sunset boulevard',
+            'value': 'Sunset Boulevard'
+        }
+    }
+    """
+
+    metric_type = "ce_loss"
+
+    def __init__(self, tokenizer, dataset_path="trivia_qa", dataset_name="rc.wikipedia.nocontext"):
+        super().__init__(
+            tokenizer=tokenizer,
+            dataset_path=dataset_path,
+            dataset_name=dataset_name,
+        )
+
+    def doc_to_text(self, doc):
+        return "\nQuestion: " + doc["question"] + "\nAnswer:"
+
+    def doc_to_continuations(self, doc):
+        return [" " + doc["answer"]["value"]]
+
+    def doc_to_label(self, doc):
+        return 0
+
+    def doc_to_domain_conditional(self, doc):
+        del doc
+        return "Answer:"
+
+
+class NaturalQuestionsCELoss(ICLMultiChoiceTaskDataset):
+    """Sample NaturalQuestions entity. For CE Loss we only consider the first answer entry to score.
+
+    {
+        'question': 'when was the last time anyone was on the moon',
+        'answer': ['14 December 1972 UTC', 'December 1972']
+    }
+    """
+
+    metric_type = "ce_loss"
+
+    def __init__(self, tokenizer, dataset_path="nq_open", dataset_name=None):
+        super().__init__(
+            tokenizer=tokenizer,
+            dataset_path=dataset_path,
+            dataset_name=dataset_name,
+        )
+
+    def doc_to_text(self, doc):
+        return "\nQuestion: " + doc["question"] + "\nAnswer:"
+
+    def doc_to_continuations(self, doc):
+        return [" " + doc["answer"][0]]
+
+    def doc_to_label(self, doc):
+        return 0
+
+    def doc_to_domain_conditional(self, doc):
+        del doc
+        return "Answer:"
+
+
 label_to_task_map = {
     "piqa": PIQA,
     "hellaswag": HellaSwag,
@@ -1258,6 +1347,7 @@ def doc_to_domain_conditional(self, doc):
     "boolq": BoolQ,
     "sciq": SciQ,
     "arc_easy": ArcEasy,
+    "arc_easy_ppl": ArcEasyCELoss,
     "arc_challenge": ArcChallenge,
     "basic_arithmetic": BasicArithmetic,
     "copa": COPA,
@@ -1267,6 +1357,8 @@ def doc_to_domain_conditional(self, doc):
     "sst2": SST2,
     "commonsense_qa": CommonsenseQA,
     "social_iqa": SocialIQa,
+    "trivia_qa_wiki_ppl": TriviaQACELoss,
+    "natural_qs_open_ppl": NaturalQuestionsCELoss,
     "mmlu_stem_test": (MMLU, {"dataset_name": "stem", "split": "test"}),
     "mmlu_humanities_test": (MMLU, {"dataset_name": "humanities", "split": "test"}),
     "mmlu_social_sciences_test": (MMLU, {"dataset_name": "social_sciences", "split": "test"}),
diff --git a/olmo/eval/evaluator.py b/olmo/eval/evaluator.py
index ddc85a603..85a20da5d 100644
--- a/olmo/eval/evaluator.py
+++ b/olmo/eval/evaluator.py
@@ -29,9 +29,11 @@ def reset_metrics(self) -> None:
     def compute_metrics(self) -> Dict[str, float]:
         if self.type == EvaluatorType.downstream:
             assert isinstance(self.eval_metric, ICLMetric)
-            return {
-                f"eval/downstream/{self.label}_{self.eval_metric.metric_type}": self.eval_metric.compute().item(),
-            }
+            value = self.eval_metric.compute().item()
+            key = f"eval/downstream/{self.label}_{self.eval_metric.metric_type}"
+            if self.eval_metric.metric_type == "ce_loss":
+                key = key.replace("/downstream/", "/downstream_ce_loss/")
+            return {key: value}
         elif self.type == EvaluatorType.lm:
             # Metric(s) = cross entropy loss
             metrics: Dict[str, Metric]
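
The per-document number the new ce_loss metric reports can be reproduced with a few lines of standalone PyTorch. The sketch below is not part of the patch: the function name and toy tensors are illustrative, the logits are log-softmaxed locally to mirror the "gather log-probs" step visible in the update() hunk, and cont_str_len stands in for the batch["cont_str_len"] normalizer used there.

import torch
import torch.nn.functional as F


def ce_loss_for_gold_continuation(cont_logits: torch.Tensor, cont_tokens: torch.Tensor, cont_str_len: int) -> float:
    # Normalize to log-probs, then gather the log-prob of each gold continuation token,
    # mirroring the gather/sum in ICLMetric.update() above.
    log_probs = F.log_softmax(cont_logits, dim=-1)
    log_likelihood = torch.gather(log_probs, 1, cont_tokens.unsqueeze(-1)).sum() / cont_str_len
    # ce_loss negates the length-normalized log-likelihood, so lower is better.
    return (-log_likelihood).item()


# Toy example: a 3-token gold continuation over a 10-token vocabulary.
logits = torch.randn(3, 10)
gold_tokens = torch.tensor([4, 1, 7])
print(ce_loss_for_gold_continuation(logits, gold_tokens, cont_str_len=12))

Because only the gold continuation is scored (doc_to_label returns 0 in the new tasks), compute() appends loglikelihoods[0] directly, and compute_metrics() logs the result under eval/downstream_ce_loss/ rather than eval/downstream/.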