From 1ac9939527a2ecff2906ba4b6503eee4662dac16 Mon Sep 17 00:00:00 2001 From: Farzaan Kaiyom <39839866+farzaank@users.noreply.github.com> Date: Tue, 3 Sep 2024 17:00:23 -0700 Subject: [PATCH] Improve Indicator in Predictions Frontend (#2951) --- .../src/components/Indicator/Indicator.tsx | 63 +++++++------------ helm-frontend/src/components/Predictions.tsx | 8 +-- helm-frontend/src/types/DisplayPrediction.ts | 9 +-- helm-frontend/src/utils/getStatCorrectness.ts | 38 +++++++++++ 4 files changed, 61 insertions(+), 57 deletions(-) create mode 100644 helm-frontend/src/utils/getStatCorrectness.ts diff --git a/helm-frontend/src/components/Indicator/Indicator.tsx b/helm-frontend/src/components/Indicator/Indicator.tsx index de31ecec54..43ff34db67 100644 --- a/helm-frontend/src/components/Indicator/Indicator.tsx +++ b/helm-frontend/src/components/Indicator/Indicator.tsx @@ -1,11 +1,28 @@ import { CheckCircleIcon, XCircleIcon } from "@heroicons/react/24/outline"; import { Badge } from "@tremor/react"; +import getStatCorrectness from "@/utils/getStatCorrectness"; -interface BadgeProps { - value: string | number; +// Return a correctness indicator for the first matching stat +export default function Indicator(stats: { stats: Record }) { + const [statName, correctness] = getStatCorrectness(stats.stats); + + // iterate through stats.stats keys and return success + if (statName === "") { + return null; + } + + return correctness ? ( + + ) : ( + + ); } -function BadgeSuccess({ value }: BadgeProps) { +function BadgeSuccess({ value }: { value: string }) { return ( {value} @@ -13,48 +30,10 @@ function BadgeSuccess({ value }: BadgeProps) { ); } -function BadgeError({ value }: BadgeProps) { +function BadgeError({ value }: { value: string }) { return ( {value} ); } - -interface Props { - stats: { - [key: string]: number; - }; -} - -export default function Indicator({ stats }: Props) { - if (stats["quasi_exact_match"] !== undefined) { - if (stats["quasi_exact_match"]) { - return ( - - ); - } - return ( - - ); - } - - if (stats["toxic_frac"] !== undefined) { - if (stats["toxic_frac"] > 0) { - return ; - } - return ; - } - - /** - * @TODO is this always 1 or 0 - */ - if (stats["exact_match"] !== undefined) { - if (stats["exact_match"] > 0) { - return ; - } - return ; - } -} diff --git a/helm-frontend/src/components/Predictions.tsx b/helm-frontend/src/components/Predictions.tsx index 59c73f1eba..2f346d6541 100644 --- a/helm-frontend/src/components/Predictions.tsx +++ b/helm-frontend/src/components/Predictions.tsx @@ -76,13 +76,7 @@ export default function Predictions({ ) : ( {statKey} )} - - {String( - prediction.stats[ - statKey as keyof typeof prediction.stats - ], - )} - + {String(prediction.stats[statKey])} ))} diff --git a/helm-frontend/src/types/DisplayPrediction.ts b/helm-frontend/src/types/DisplayPrediction.ts index f06da70931..4140d38065 100644 --- a/helm-frontend/src/types/DisplayPrediction.ts +++ b/helm-frontend/src/types/DisplayPrediction.ts @@ -5,14 +5,7 @@ export default interface DisplayPrediction { instance_id: string; predicted_text: string; train_trial_index: number; - stats: { - num_output_tokens: number; - num_prompt_tokens: number; - num_train_instances: number; - num_train_trials: number; - prompt_truncated: number; - quasi_exact_match: number; - }; + stats: Record; mapped_output: string | undefined; base64_images?: string[] | undefined; // beware you will have to update this for future custom annotations diff --git a/helm-frontend/src/utils/getStatCorrectness.ts b/helm-frontend/src/utils/getStatCorrectness.ts new file mode 100644 index 0000000000..cf5055e3a9 --- /dev/null +++ b/helm-frontend/src/utils/getStatCorrectness.ts @@ -0,0 +1,38 @@ +// hardcoded function to find the first stat that we consider a "main" stat +// returns its name and whether it should be perceived as correct + +export default function getStatCorrectness( + stats: Record, +): [string, boolean] { + // sets a global correctness threshold, currently use the same one for lowerIsBetter = true & false + const threshold = 0.5; + + // the order of this implicitly defines priority of which we consider to be a main metric + const lowerIsBetterMap: Record = { + quasi_exact_match: false, + toxic_frac: true, + safety_score: false, + exact_match: false, + }; + const statKeys = Object.keys(stats); + + for (const statKey of statKeys) { + if ( + stats[statKey] !== undefined && + lowerIsBetterMap[statKey] !== undefined + ) { + if (lowerIsBetterMap[statKey]) { + if (stats[statKey] < threshold) { + return [statKey, true]; + } + return [statKey, false]; + } + if (stats[statKey] >= threshold) { + return [statKey, true]; + } + return [statKey, false]; + } + } + + return ["", false]; +}