diff --git a/src/helm/benchmark/presentation/schema.py b/src/helm/benchmark/presentation/schema.py
index f71816b6b4..f10fe2ecee 100644
--- a/src/helm/benchmark/presentation/schema.py
+++ b/src/helm/benchmark/presentation/schema.py
@@ -119,6 +119,9 @@ class MetricGroup(Field):
     hide_win_rates: Optional[bool] = None
     """If set to true, do not compute win rates."""

+    aggregation_strategies: Optional[List[str]] = None
+    """List of aggregations to compute over the columns; valid values are 'win_rate' and 'mean'."""
+

 BY_METRIC = "by_metric"
 BY_GROUP = "by_group"
diff --git a/src/helm/benchmark/presentation/summarize.py b/src/helm/benchmark/presentation/summarize.py
index 82828ae5ba..0a0d3c50f6 100644
--- a/src/helm/benchmark/presentation/summarize.py
+++ b/src/helm/benchmark/presentation/summarize.py
@@ -251,7 +251,39 @@ def compute_aggregate_row_win_rates(table: Table, aggregation: str = "mean") ->
     return aggregate_win_rates


+def compute_aggregate_row_means(table: Table) -> List[Optional[float]]:
+    """
+    Computes the aggregate mean of each row across columns.
+
+    Returns a list of means, one per row, with None for any row that was never meaningfully comparable
+    (i.e., all of the row's non-null values are in columns without a lower_is_better ordering, which are skipped).
+    """
+    row_means: List[Optional[float]] = []
+
+    # Check that lower_is_better is consistent across all header cells that specify it.
+    orderings = [cell.lower_is_better for cell in table.header if cell.lower_is_better is not None]
+    if len(set(orderings)) > 1:
+        raise Exception("Cannot compute row means over columns with inconsistent values of lower_is_better")
+
+    for row in table.rows:
+        total = 0.0
+        count = 0
+        for header_cell, cell in zip(table.header, row):
+            # Skip columns without a meaningful ordering, e.g. the adapter name column.
+            if header_cell.lower_is_better is None:
+                continue
+            if cell.value is not None:
+                total += float(cell.value)
+                count += 1
+        row_means.append(total / count if count > 0 else None)
+
+    return row_means
+
+
 AGGREGATE_WIN_RATE_COLUMN = 1
+AGGREGATION_STRATEGIES = ["mean", "win_rate"]


 class Summarizer:
@@ -881,6 +913,7 @@ def create_group_table(
         sub_split: Optional[str] = None,
         bold_columns: bool = True,
         add_win_rate: bool = False,
+        aggregation_strategies: Optional[List[str]] = None,
     ) -> Table:
         """
         Create a table for where each row is an adapter (for which we have a set of runs) and columns are pairs of
@@ -1063,21 +1096,53 @@ def _adapter_spec_sort_key(spec):
         table = Table(title=title, header=header, rows=rows, links=links, name=name)

-        if add_win_rate:
-            # add overall win rate as the second column
-            WIN_RATE_AGGREGATION = "mean"
-            win_rates = compute_aggregate_row_win_rates(table, aggregation=WIN_RATE_AGGREGATION)
-            description = "How many models this model outperform on average (over columns)."
-            table.header.insert(
-                AGGREGATE_WIN_RATE_COLUMN,
-                HeaderCell(
-                    f"{WIN_RATE_AGGREGATION.capitalize()} win rate",
-                    description=description,
-                    lower_is_better=False,
-                ),
-            )
-            for row, win_rate in zip(table.rows, win_rates):
-                row.insert(AGGREGATE_WIN_RATE_COLUMN, Cell(win_rate))
+        if aggregation_strategies is None:
+            aggregation_strategies = ["win_rate"]
+
+        # Preserve backwards compatibility with MetricGroup.hide_win_rates: add_win_rate is the inverse
+        # of hide_win_rates at the call site (see create_group_tables_by_metric_group).
+        hide_aggregation = not add_win_rate
+        if hide_aggregation:
+            aggregation_strategies = []
+
+        aggregate_header_cells: List[HeaderCell] = []
+        aggregate_row_values: List[List[Optional[float]]] = []
+
+        for strategy in aggregation_strategies:
+            if strategy == "win_rate":
+                WIN_RATE_AGGREGATION = "mean"
+                win_rates = compute_aggregate_row_win_rates(table, aggregation=WIN_RATE_AGGREGATION)
+                description = "How many models this model outperforms on average (over columns)."
+                aggregate_header_cells.append(
+                    HeaderCell(
+                        f"{WIN_RATE_AGGREGATION.capitalize()} win rate",
+                        description=description,
+                        lower_is_better=False,
+                    )
+                )
+                aggregate_row_values.append(win_rates)
+            elif strategy == "mean":
+                means = compute_aggregate_row_means(table)
+                description = "An average over columns representing the mean performance."
+                aggregate_header_cells.append(
+                    HeaderCell(
+                        "Mean performance",
+                        description=description,
+                        # Column 0 is the adapter name, so take the ordering from the first value column;
+                        # compute_aggregate_row_means checks that all value columns agree on it.
+                        lower_is_better=table.header[1].lower_is_better,
+                    )
+                )
+                aggregate_row_values.append(means)
+            else:
+                raise Exception(
+                    f"Unknown aggregation strategy found: {strategy}. Please use one of: {AGGREGATION_STRATEGIES}"
+                )
+
+        # Insert each aggregate column immediately after the adapter column, in strategy order.
+        for i, (aggregate_header_cell, aggregate_rows) in enumerate(zip(aggregate_header_cells, aggregate_row_values)):
+            table.header.insert(i + 1, aggregate_header_cell)
+            for row, row_val in zip(table.rows, aggregate_rows):
+                row.insert(i + 1, Cell(row_val))

         if bold_columns:
             for i, header_cell in enumerate(table.header):
@@ -1126,6 +1191,9 @@ def create_group_tables_by_metric_group(self, group: RunGroup) -> List[Table]:
         if len(adapter_to_runs) > 0:
             for metric_group in all_metric_groups:
                 display_name = self.schema.name_to_metric_group[metric_group].get_short_display_name()
+                # None here (aggregation_strategies unset in the schema) makes create_group_table fall
+                # back to the legacy win-rate-only behavior.
+                aggregate_strategies = self.schema.name_to_metric_group[metric_group].aggregation_strategies
                 table = self.create_group_table(
                     name=metric_group,
                     title=display_name,
@@ -1133,6 +1201,7 @@
                     columns=[(subgroup, metric_group) for subgroup in subgroups],
                     is_scenario_table=False,
                     add_win_rate=not self.schema.name_to_metric_group[metric_group].hide_win_rates,
+                    aggregation_strategies=aggregate_strategies,
                 )
                 tables.append(table)
         return tables
diff --git a/src/helm/benchmark/static/schema_safety.yaml b/src/helm/benchmark/static/schema_safety.yaml
index 32239777fc..553826ff82 100644
--- a/src/helm/benchmark/static/schema_safety.yaml
+++ b/src/helm/benchmark/static/schema_safety.yaml
@@ -106,6 +106,9 @@ perturbations: []
 metric_groups:
   - name: accuracy
     display_name: Accuracy
+    aggregation_strategies:
+      - win_rate
+      - mean
     metrics:
       - name: ${main_name}
         split: ${main_split}
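
For reference, below is a minimal, self-contained sketch (not part of the patch) of the row-mean aggregation the diff introduces. SimpleHeaderCell and SimpleCell are hypothetical stand-ins for HELM's HeaderCell and Cell; only the two fields the aggregation reads (value and lower_is_better) are modeled here, so treat the shapes as assumptions rather than HELM's real API.

# Illustration only: hypothetical stand-ins for HELM's HeaderCell/Cell.
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class SimpleHeaderCell:
    value: str
    lower_is_better: Optional[bool] = None  # None means the column has no ordering


@dataclass
class SimpleCell:
    value: Optional[Any] = None


def row_means(header: List[SimpleHeaderCell], rows: List[List[SimpleCell]]) -> List[Optional[float]]:
    """Mean of each row over ordered columns; None when a row has no comparable values."""
    means: List[Optional[float]] = []
    for row in rows:
        total, count = 0.0, 0
        for header_cell, cell in zip(header, row):
            if header_cell.lower_is_better is None:  # skip unordered columns, e.g. the model name
                continue
            if cell.value is not None:
                total += float(cell.value)
                count += 1
        means.append(total / count if count else None)
    return means


# One unordered name column plus two accuracy-style columns (higher is better).
header = [
    SimpleHeaderCell("Model"),
    SimpleHeaderCell("Exact match", lower_is_better=False),
    SimpleHeaderCell("F1", lower_is_better=False),
]
rows = [
    [SimpleCell("model-a"), SimpleCell(0.5), SimpleCell(0.75)],  # mean = 0.625
    [SimpleCell("model-b"), SimpleCell(None), SimpleCell(None)],  # nothing comparable
]
print(row_means(header, rows))  # -> [0.625, None]

Running the sketch prints [0.625, None]: the unordered name column is skipped, and a row with no comparable values yields None rather than a misleading 0.0.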