From 02982a47edd6e14d3e4c2b4ab222a61d24187c60 Mon Sep 17 00:00:00 2001 From: "Brian W. Goldman" <2237679+brianwgoldman@users.noreply.github.com> Date: Tue, 9 Jan 2024 11:29:00 -0700 Subject: [PATCH] Fix breakage caused by #2172 (#2194) --- src/helm/benchmark/metrics/basic_metrics.py | 11 +++++++---- src/helm/benchmark/metrics/metric.py | 4 +++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/helm/benchmark/metrics/basic_metrics.py b/src/helm/benchmark/metrics/basic_metrics.py index 26e3b7ac5f..f4a17dc065 100644 --- a/src/helm/benchmark/metrics/basic_metrics.py +++ b/src/helm/benchmark/metrics/basic_metrics.py @@ -243,11 +243,14 @@ def compute_logprob_and_length(request_state: RequestState, window_service: Wind raise ValueError(f"Unknown adapter method: {adapter_spec.method}") stats: List[Stat] = [] - for request_state in reference_request_states: - stats.extend( - compute_request_state_metrics(self.efficiency_metric, adapter_spec, request_state, metric_service) - ) + general_metrics: Dict[MetricName, Stat] = {} + for request_state in reference_request_states: + for stat in compute_request_state_metrics( + self.efficiency_metric, adapter_spec, request_state, metric_service + ): + merge_stat(general_metrics, stat) + stats.extend(general_metrics.values()) max_prob = np.max(scipy.special.softmax(reference_scores)) # Multiple references may attain the same maximal score; in such cases, diff --git a/src/helm/benchmark/metrics/metric.py b/src/helm/benchmark/metrics/metric.py index 464c31fd8c..7d41eab707 100644 --- a/src/helm/benchmark/metrics/metric.py +++ b/src/helm/benchmark/metrics/metric.py @@ -288,7 +288,9 @@ def compute_worst_case_metrics(self, per_instance_stats: Dict[Instance, List[Sta for stat in stats: # go through all the perturbations of the instance and merge relevant stats perturbation = stat.name.perturbation if perturbation is None: - assert original_stat is None # we should only have one original stat + assert ( + original_stat is None + ), f"For {metric_name} got both {original_stat} and {stat}" # we should only have one original stat original_stat = stat else: if perturbation.robustness: