Skip to content

Commit

Permalink
Handle ties in win rate computation (#3001)
Browse files Browse the repository at this point in the history
  • Loading branch information
yifanmai committed Sep 18, 2024
1 parent ee6f10b commit 1f6f4b2
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 8 deletions.
24 changes: 17 additions & 7 deletions src/helm/benchmark/presentation/summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,17 +226,27 @@ def compute_aggregate_row_win_rates(table: Table, aggregation: str = "mean") ->
"""
assert aggregation in ["mean", "median"]
win_rates_per_row: List[List[float]] = [[] for _ in table.rows]
for i, header_cell in enumerate(table.header):
for column_index, header_cell in enumerate(table.header):
lower_is_better = header_cell.lower_is_better
if lower_is_better is None: # column does not have a meaningful ordering
continue

values = [(row[i].value, j) for j, row in enumerate(table.rows) if row[i].value is not None]
if len(values) < 2: # don't rank a single model
value_to_count: Dict[float, int] = defaultdict(int)
for row in table.rows:
value = row[column_index].value
if value is not None:
value_to_count[value] += 1
value_to_wins: Dict[float, float] = {}
acc_count = 0
for value, value_count in sorted(value_to_count.items(), reverse=lower_is_better):
value_to_wins[value] = acc_count + ((value_count - 1) / 2)
acc_count += value_count
total_count = acc_count
if total_count < 2:
continue
for wins, (v, j) in enumerate(sorted(values, reverse=lower_is_better)):
win_rate = wins / (len(values) - 1) # normalize to [0, 1]
win_rates_per_row[j].append(win_rate)
for row_index, row in enumerate(table.rows):
value = row[column_index].value
if value is not None:
win_rates_per_row[row_index].append(value_to_wins[row[column_index].value] / (total_count - 1))

# Note: the logic up to here is somewhat general as it simply computes win rates across columns for each row.
# Here, we simply average these win rates but we might want some more involved later (e.g., weighted average).
Expand Down
19 changes: 18 additions & 1 deletion src/helm/benchmark/presentation/test_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,21 @@ def test_compute_win_rates_ties():
]
rows = [[Cell(value) for value in row_values] for row_values in values]
table = Table(title="Test Table", header=header, rows=rows)
assert compute_aggregate_row_win_rates(table) == [0.0, 0.25, 0.5, 0.75, 1.0]
assert compute_aggregate_row_win_rates(table) == [0.25, 0.25, 0.25, 0.75, 1.0]


def test_compute_win_rates_lower_is_better():
header = [
HeaderCell(value="Model"),
HeaderCell(value="Scenario A", lower_is_better=False),
]
values = [
["Model A", 1],
["Model B", 2],
["Model C", 3],
["Model D", 4],
["Model E", 5],
]
rows = [[Cell(value) for value in row_values] for row_values in values]
table = Table(title="Test Table", header=header, rows=rows)
assert compute_aggregate_row_win_rates(table) == [0, 0.25, 0.5, 0.75, 1]

0 comments on commit 1f6f4b2

Please sign in to comment.