diff --git a/src/helm/benchmark/presentation/summarize.py b/src/helm/benchmark/presentation/summarize.py
index 82828ae5ba..d6cda11fdb 100644
--- a/src/helm/benchmark/presentation/summarize.py
+++ b/src/helm/benchmark/presentation/summarize.py
@@ -226,17 +226,31 @@ def compute_aggregate_row_win_rates(table: Table, aggregation: str = "mean") ->
     """
     assert aggregation in ["mean", "median"]
     win_rates_per_row: List[List[float]] = [[] for _ in table.rows]
-    for i, header_cell in enumerate(table.header):
+    for column_index, header_cell in enumerate(table.header):
         lower_is_better = header_cell.lower_is_better
         if lower_is_better is None:  # column does not have a meaningful ordering
             continue
-
-        values = [(row[i].value, j) for j, row in enumerate(table.rows) if row[i].value is not None]
-        if len(values) < 2:  # don't rank a single model
+        # Count how many rows share each value so that ties can be handled.
+        value_to_count: Dict[float, int] = defaultdict(int)
+        for row in table.rows:
+            value = row[column_index].value
+            if value is not None:
+                value_to_count[value] += 1
+        # Walk the values from worst to best, crediting each distinct value with
+        # one win per strictly worse row and half a win per tied row.
+        value_to_wins: Dict[float, float] = {}
+        acc_count = 0
+        for value, value_count in sorted(value_to_count.items(), reverse=lower_is_better):
+            value_to_wins[value] = acc_count + ((value_count - 1) / 2)
+            acc_count += value_count
+        total_count = acc_count
+        if total_count < 2:  # don't rank a single model
             continue
-        for wins, (v, j) in enumerate(sorted(values, reverse=lower_is_better)):
-            win_rate = wins / (len(values) - 1)  # normalize to [0, 1]
-            win_rates_per_row[j].append(win_rate)
+        for row_index, row in enumerate(table.rows):
+            value = row[column_index].value
+            if value is not None:
+                # Normalize average wins to [0, 1].
+                win_rates_per_row[row_index].append(value_to_wins[value] / (total_count - 1))
 
     # Note: the logic up to here is somewhat general as it simply computes win rates across columns for each row.
     # Here, we simply average these win rates but we might want some more involved later (e.g., weighted average).
diff --git a/src/helm/benchmark/presentation/test_summarize.py b/src/helm/benchmark/presentation/test_summarize.py
index cb6552989f..893189fd39 100644
--- a/src/helm/benchmark/presentation/test_summarize.py
+++ b/src/helm/benchmark/presentation/test_summarize.py
@@ -163,4 +163,21 @@ def test_compute_win_rates_ties():
     ]
     rows = [[Cell(value) for value in row_values] for row_values in values]
     table = Table(title="Test Table", header=header, rows=rows)
-    assert compute_aggregate_row_win_rates(table) == [0.0, 0.25, 0.5, 0.75, 1.0]
+    assert compute_aggregate_row_win_rates(table) == [0.25, 0.25, 0.25, 0.75, 1.0]
+
+
+def test_compute_win_rates_lower_is_better():
+    header = [
+        HeaderCell(value="Model"),
+        HeaderCell(value="Scenario A", lower_is_better=True),
+    ]
+    values = [
+        ["Model A", 1],
+        ["Model B", 2],
+        ["Model C", 3],
+        ["Model D", 4],
+        ["Model E", 5],
+    ]
+    rows = [[Cell(value) for value in row_values] for row_values in values]
+    table = Table(title="Test Table", header=header, rows=rows)
+    assert compute_aggregate_row_win_rates(table) == [1, 0.75, 0.5, 0.25, 0]
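
Note on the change: the old code ranked rows by position in a single sorted list, so tied values received distinct win counts determined only by row order. The new code instead credits each value with one win per strictly worse value and half a win per tied value, i.e., fractional ("average") ranking. Below is a minimal standalone sketch of the same scheme, not the PR's implementation (the PR uses a counting pass; this uses direct pairwise comparisons), and `tie_aware_win_rates` is a hypothetical name. The example inputs are chosen to reproduce the expected outputs of the tests; the actual values of `test_compute_win_rates_ties` are defined above the hunk shown.

```python
from typing import List, Optional


def tie_aware_win_rates(values: List[Optional[float]], lower_is_better: bool = False) -> List[Optional[float]]:
    """Win rate of each value against the others, counting ties as half a win."""
    present = [v for v in values if v is not None]
    if len(present) < 2:  # don't rank a single model
        return [None] * len(values)
    win_rates: List[Optional[float]] = []
    for v in values:
        if v is None:
            win_rates.append(None)
            continue
        # One win per strictly worse value, half a win per other value tied with v.
        if lower_is_better:
            wins = sum(other > v for other in present)
        else:
            wins = sum(other < v for other in present)
        ties = sum(other == v for other in present) - 1  # exclude v itself
        win_rates.append((wins + ties / 2) / (len(present) - 1))
    return win_rates


# Three models tied at the bottom each get (0 + 2/2) / 4 = 0.25, matching the
# updated expectation in test_compute_win_rates_ties.
assert tie_aware_win_rates([1, 1, 1, 2, 3]) == [0.25, 0.25, 0.25, 0.75, 1.0]
# With lower_is_better the ordering flips, matching test_compute_win_rates_lower_is_better.
assert tie_aware_win_rates([1, 2, 3, 4, 5], lower_is_better=True) == [1.0, 0.75, 0.5, 0.25, 0.0]
```

One property of this scheme worth noting: tied models always receive identical win rates, and the mean win rate across models is unchanged at 0.5, since each half-win is shared symmetrically between the tied pair.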