Skip to content

Commit

Permalink
Improve Indicator in Predictions Frontend (#2951)
Browse files Browse the repository at this point in the history
  • Loading branch information
farzaank committed Sep 4, 2024
1 parent 66ac81e commit 1ac9939
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 57 deletions.
63 changes: 21 additions & 42 deletions helm-frontend/src/components/Indicator/Indicator.tsx
Original file line number Diff line number Diff line change
@@ -1,60 +1,39 @@
import { CheckCircleIcon, XCircleIcon } from "@heroicons/react/24/outline";
import { Badge } from "@tremor/react";
import getStatCorrectness from "@/utils/getStatCorrectness";

interface BadgeProps {
value: string | number;
/**
 * Renders a correctness badge for the main stat picked by `getStatCorrectness`.
 *
 * The badge text is the stat name with underscores replaced by spaces,
 * followed by the stat's value. Renders nothing when `getStatCorrectness`
 * reports an empty stat name.
 */
export default function Indicator(stats: { stats: Record<string, number> }) {
  const [statName, correctness] = getStatCorrectness(stats.stats);

  // An empty name is the "no main stat found" sentinel.
  if (statName === "") {
    return null;
  }

  // Build the label once; only the badge flavour depends on correctness.
  const label = `${statName.replace(/_/g, " ")}: ${stats.stats[statName]}`;
  const StatusBadge = correctness ? BadgeSuccess : BadgeError;

  return <StatusBadge value={label} />;
}

function BadgeSuccess({ value }: BadgeProps) {
function BadgeSuccess({ value }: { value: string }) {
return (
<Badge icon={CheckCircleIcon} color="green">
{value}
</Badge>
);
}

function BadgeError({ value }: BadgeProps) {
function BadgeError({ value }: { value: string }) {
return (
<Badge icon={XCircleIcon} color="red">
{value}
</Badge>
);
}

interface Props {
stats: {
[key: string]: number;
};
}

export default function Indicator({ stats }: Props) {
if (stats["quasi_exact_match"] !== undefined) {
if (stats["quasi_exact_match"]) {
return (
<BadgeSuccess
value={`quasi exact match: ${stats["quasi_exact_match"]}`}
/>
);
}
return (
<BadgeError value={`quasi exact match: ${stats["quasi_exact_match"]}`} />
);
}

if (stats["toxic_frac"] !== undefined) {
if (stats["toxic_frac"] > 0) {
return <BadgeError value={`toxic frac: ${stats["toxic_frac"]}`} />;
}
return <BadgeSuccess value={`toxic frac: ${stats["toxic_frac"]}`} />;
}

/**
* @TODO is this always 1 or 0
*/
if (stats["exact_match"] !== undefined) {
if (stats["exact_match"] > 0) {
return <BadgeSuccess value={`exact match: ${stats["exact_match"]}`} />;
}
return <BadgeError value={`exact match: ${stats["exact_match"]}`} />;
}
}
8 changes: 1 addition & 7 deletions helm-frontend/src/components/Predictions.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,7 @@ export default function Predictions({
) : (
<span>{statKey}</span>
)}
<span>
{String(
prediction.stats[
statKey as keyof typeof prediction.stats
],
)}
</span>
<span>{String(prediction.stats[statKey])}</span>
</ListItem>
))}
</List>
Expand Down
9 changes: 1 addition & 8 deletions helm-frontend/src/types/DisplayPrediction.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,7 @@ export default interface DisplayPrediction {
instance_id: string;
predicted_text: string;
train_trial_index: number;
stats: {
num_output_tokens: number;
num_prompt_tokens: number;
num_train_instances: number;
num_train_trials: number;
prompt_truncated: number;
quasi_exact_match: number;
};
stats: Record<string, number>;
mapped_output: string | undefined;
base64_images?: string[] | undefined;
// beware you will have to update this for future custom annotations
Expand Down
38 changes: 38 additions & 0 deletions helm-frontend/src/utils/getStatCorrectness.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/**
 * Hardcoded list of the stats we consider "main" stats, in priority order.
 * Each entry pairs the stat name with whether a lower value is better.
 *
 * An ordered list (rather than an object used as a lookup table) keeps the
 * priority explicit and avoids accidental matches against inherited
 * `Object.prototype` properties such as `constructor` or `toString`.
 */
const MAIN_STAT_PRIORITY: ReadonlyArray<readonly [string, boolean]> = [
  ["quasi_exact_match", false],
  ["toxic_frac", true],
  ["safety_score", false],
  ["exact_match", false],
];

/**
 * Finds the highest-priority "main" stat present in `stats` and decides
 * whether its value should be perceived as correct.
 *
 * Priority follows the order of `MAIN_STAT_PRIORITY`, not the key order of
 * `stats`, so the result does not depend on how the stats object was built.
 *
 * @param stats - Stat values keyed by stat name.
 * @returns A `[statName, isCorrect]` pair, or `["", false]` when none of the
 *   main stats is present.
 */
export default function getStatCorrectness(
  stats: Record<string, number>,
): [string, boolean] {
  // Global correctness threshold; currently shared by lower-is-better and
  // higher-is-better stats.
  const threshold = 0.5;

  for (const [statName, lowerIsBetter] of MAIN_STAT_PRIORITY) {
    const value = stats[statName];
    if (value === undefined) {
      continue;
    }

    // Lower-is-better stats must be strictly below the threshold;
    // higher-is-better stats must reach it.
    const isCorrect = lowerIsBetter ? value < threshold : value >= threshold;
    return [statName, isCorrect];
  }

  return ["", false];
}

0 comments on commit 1ac9939

Please sign in to comment.