Skip to content

Commit

Permalink
Improve frontend display of adapters, stats, requests and predictions (
Browse files Browse the repository at this point in the history
  • Loading branch information
yifanmai committed Feb 26, 2024
1 parent 0838565 commit 514b28e
Show file tree
Hide file tree
Showing 14 changed files with 142 additions and 159 deletions.
9 changes: 8 additions & 1 deletion helm-frontend/src/components/InstanceData.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,21 @@ import type DisplayRequest from "@/types/DisplayRequest";
import type DisplayPrediction from "@/types/DisplayPrediction";
import Predictions from "@/components/Predictions";
import References from "@/components/References";
import type MetricFieldMap from "@/types/MetricFieldMap";
import Preview from "@/components/Preview";

interface Props {
instance: Instance;
requests: DisplayRequest[];
predictions: DisplayPrediction[];
metricFieldMap: MetricFieldMap;
}

export default function InstanceData({
instance,
requests,
predictions,
metricFieldMap,
}: Props) {
return (
<div className="border p-4">
Expand All @@ -27,7 +30,11 @@ export default function InstanceData({
<References references={instance.references} />
) : null}
{predictions && requests ? (
<Predictions predictions={predictions} requests={requests} />
<Predictions
predictions={predictions}
requests={requests}
metricFieldMap={metricFieldMap}
/>
) : null}
</div>
);
Expand Down
16 changes: 7 additions & 9 deletions helm-frontend/src/components/MetricsList.tsx
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
import type Metric from "@/types/Metric";
import type MetricField from "@/types/MetricField";
import type MetricFieldMap from "@/types/MetricFieldMap";
import type MetricGroup from "@/types/MetricGroup";

interface Props {
metrics: Metric[];
metricFieldMap: MetricFieldMap;
metricGroups: MetricGroup[];
}

export default function MetricList({ metrics, metricGroups }: Props) {
const metricNameToMetric = new Map<string, Metric>();
metrics.forEach((metric) => metricNameToMetric.set(metric.name, metric));

export default function MetricList({ metricFieldMap, metricGroups }: Props) {
// Only count metrics that have a group and are displayed
// i.e. don't count "orphaned" metrics
// Also, don't double-count metrics that appear in multiple groups
const groupedMetricNames = new Set<string>();

const metricGroupsWithMetrics: [MetricGroup, Metric[]][] = [];
const metricGroupsWithMetrics: [MetricGroup, MetricField[]][] = [];
metricGroups.forEach((metricGroup) => {
const metricGroupMetrics: Metric[] = [];
const metricGroupMetrics: MetricField[] = [];
metricGroup.metrics.forEach((metricField) => {
const maybeMetric = metricNameToMetric.get(metricField.name);
const maybeMetric = metricFieldMap[metricField.name];
if (maybeMetric) {
metricGroupMetrics.push(maybeMetric);
groupedMetricNames.add(maybeMetric.name);
Expand Down
169 changes: 56 additions & 113 deletions helm-frontend/src/components/Predictions.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { useState } from "react";
import type DisplayPrediction from "@/types/DisplayPrediction";
import type DisplayRequest from "@/types/DisplayRequest";
import type MetricFieldMap from "@/types/MetricFieldMap";
import Indicator from "@/components/Indicator";
import Request from "@/components/Request";
import Preview from "@/components/Preview";
Expand All @@ -9,138 +9,81 @@ import { List, ListItem } from "@tremor/react";
type Props = {
predictions: DisplayPrediction[];
requests: DisplayRequest[];
metricFieldMap: MetricFieldMap;
};

/**
* @SEE https://github.com/stanford-crfm/helm/blob/cffe38eb2c814d054c778064859b6e1551e5e106/src/helm/benchmark/static/benchmarking.js#L583-L679
*/
export default function Predictions({ predictions, requests }: Props) {
const [isOpen, setIsOpen] = useState(false);

const toggleAccordion = () => {
setIsOpen(!isOpen);
};

export default function Predictions({
predictions,
requests,
metricFieldMap,
}: Props) {
if (predictions.length < 1) {
return null;
}

if (predictions && predictions[0] && predictions[0].base64_images) {
return (
<div>
<div className="flex flex-wrap justify-start items-start">
{predictions.map((prediction, idx) => (
<div className="w-full" key={idx}>
{predictions.length > 1 ? <h2>Trial {idx}</h2> : null}
<div className="mt-2 w-full">
<h3>
<span className="mr-4">Prediction image</span>
</h3>
<div>
{prediction &&
prediction.base64_images &&
prediction.base64_images[0] ? (
<img
src={"data:image;base64," + prediction.base64_images[0]}
alt="Base64 Image"
/>
) : null}
</div>
</div>
<div className="accordion-wrapper">
<button
className="accordion-title p-5 bg-gray-100 hover:bg-gray-200 w-full text-left"
onClick={toggleAccordion}
>
<h3 className="text-lg font-medium text-gray-900">
Prompt Details
</h3>
</button>

{isOpen && (
<div className="accordion-content p-5 border shadow-lg rounded-md bg-white">
<div className="mt-3 text-left">
<div className="overflow-auto">
<Request request={requests[idx]} />
</div>
<List>
{Object.keys(prediction.stats).map((statKey, idx) => (
<ListItem key={idx} className="mt-2">
<span>{statKey}:</span>
<span>
{String(
prediction.stats[
statKey as keyof typeof prediction.stats
],
)}
</span>
</ListItem>
))}
</List>
</div>
</div>
)}
</div>
</div>
))}
</div>
</div>
);
}

return (
<div>
<div className="flex flex-wrap justify-start items-start">
{predictions.map((prediction, idx) => (
<div className="w-full" key={idx}>
{predictions.length > 1 ? <h2>Trial {idx}</h2> : null}
<div className="mt-2 w-full">
<h3>
<span className="mr-4">Prediction raw text</span>
<Indicator stats={prediction.stats} />
</h3>
<Preview value={prediction.predicted_text} />
{prediction.mapped_output ? (
{prediction.base64_images && prediction.base64_images.length ? (
<>
<h3>Prediction mapped output</h3>
<Preview value={String(prediction.mapped_output)} />
<h3 className="mr-4">Prediction image</h3>
{prediction.base64_images.map((base64_image) => (
<img
src={"data:image;base64," + base64_image}
alt="Base64 Image"
/>
))}
</>
) : (
<>
<h3>
<span className="mr-4">Prediction raw text</span>
<Indicator stats={prediction.stats} />
</h3>
<Preview value={prediction.predicted_text} />
{prediction.mapped_output ? (
<>
<h3>Prediction mapped output</h3>
<Preview value={String(prediction.mapped_output)} />
</>
) : null}
</>
) : null}
</div>
<div className="accordion-wrapper">
<button
className="accordion-title p-5 bg-gray-100 hover:bg-gray-200 w-full text-left"
onClick={toggleAccordion}
>
<h3 className="text-lg font-medium text-gray-900">
Prompt Details
</h3>
</button>

{isOpen && (
<div className="accordion-content p-5 border shadow-lg rounded-md bg-white">
<div className="mt-3 text-left">
<div className="overflow-auto">
<Request request={requests[idx]} />
</div>
<List>
{Object.keys(prediction.stats).map((statKey, idx) => (
<ListItem key={idx} className="mt-2">
<span>{statKey}:</span>
<span>
{String(
prediction.stats[
statKey as keyof typeof prediction.stats
],
)}
</span>
</ListItem>
))}
</List>
</div>
</div>
)}
</div>
<h3>Metrics</h3>
<List>
{Object.keys(prediction.stats).map((statKey, idx) => (
<ListItem key={idx}>
{metricFieldMap[statKey] ? (
<span title={metricFieldMap[statKey].description}>
{metricFieldMap[statKey].display_name}
</span>
) : (
<span>{statKey}</span>
)}
<span>
{String(
prediction.stats[
statKey as keyof typeof prediction.stats
],
)}
</span>
</ListItem>
))}
</List>
<details className="collapse collapse-arrow border rounded-md bg-white">
<summary className="collapse-title">Request details</summary>
<div className="collapse-content">
<Request request={requests[idx]} />
</div>
</details>
</div>
))}
</div>
Expand Down
2 changes: 1 addition & 1 deletion helm-frontend/src/components/Preview.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export default function Preview({ value }: Props) {
onMouseOut={() => setShowButton(false)}
className="relative"
>
<div className="bg-base-200 p-2 block overflow-auto w-full max-h-72 mb-2 whitespace-pre-wrap">
<div className="bg-base-200 p-2 block overflow-auto w-full max-h-[36rem] mb-2 whitespace-pre-wrap">
{value}
</div>

Expand Down
16 changes: 12 additions & 4 deletions helm-frontend/src/components/StatNameDisplay.tsx
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
import type Stat from "@/types/Stat";
import underscoreToTitle from "@/utils/underscoreToTitle";
import type MetricFieldMap from "@/types/MetricFieldMap";

interface Props {
stat: Stat;
metricFieldMap: MetricFieldMap;
}

export default function StatNameDisplay({ stat }: Props) {
export default function StatNameDisplay({ stat, metricFieldMap }: Props) {
const value = `${
stat.name.split !== undefined ? ` on ${stat.name.split}` : ""
}${stat.name.sub_split !== undefined ? `/${stat.name.sub_split}` : ""}${
stat.name.perturbation !== undefined
? ` with ${stat.name.perturbation.name}`
: " original"
}`;
return (
return metricFieldMap[stat.name.name] ? (
<span title={metricFieldMap[stat.name.name].description}>
<strong>
{metricFieldMap[stat.name.name].display_name || stat.name.name}
</strong>
{value}
</span>
) : (
<span>
<strong>{underscoreToTitle(stat.name.name)}</strong>
<strong>{stat.name.name}</strong>
{value}
</span>
);
Expand Down
39 changes: 31 additions & 8 deletions helm-frontend/src/routes/Run.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import getDisplayPredictionsByName from "@/services/getDisplayPredictionsByName"
import type DisplayPredictionsMap from "@/types/DisplayPredictionsMap";
import getScenarioByName from "@/services/getScenarioByName";
import type Scenario from "@/types/Scenario";
import type AdapterFieldMap from "@/types/AdapterFieldMap";
import type MetricFieldMap from "@/types/MetricFieldMap";
import { getRunSpecByNameUrl } from "@/services/getRunSpecByName";
import { getScenarioStateByNameUrl } from "@/services/getScenarioStateByName";
import Tab from "@/components/Tab";
Expand Down Expand Up @@ -51,6 +53,8 @@ export default function Run() {
const [totalMetricsPages, setTotalMetricsPages] = useState<number>(1);
const [model, setModel] = useState<Model | undefined>();
const [scenario, setScenario] = useState<Scenario | undefined>();
const [adapterFieldMap, setAdapterFieldMap] = useState<AdapterFieldMap>({});
const [metricFieldMap, setMetricFieldMap] = useState<MetricFieldMap>({});
const [searchTerm, setSearchTerm] = useState("");

useEffect(() => {
Expand Down Expand Up @@ -121,6 +125,20 @@ export default function Run() {
}, {} as DisplayRequestsMap),
);
const schema = await getSchema(signal);

setMetricFieldMap(
schema.metrics.reduce((acc, cur) => {
acc[cur.name] = cur;
return acc;
}, {} as MetricFieldMap),
);
setAdapterFieldMap(
schema.adapter.reduce((acc, cur) => {
acc[cur.name] = cur;
return acc;
}, {} as AdapterFieldMap),
);

setModel(
schema.models.find(
(m) =>
Expand Down Expand Up @@ -207,7 +225,14 @@ export default function Run() {
<List className="grid md:grid-cols-2 lg:grid-cols-3 gap-x-8">
{Object.entries(runSpec.adapter_spec).map(([key, value], idx) => (
<ListItem className={idx < 3 ? "!border-0" : ""}>
<strong className="mr-1">{`${key}: `}</strong>
<strong
className="mr-1"
title={
adapterFieldMap[key]
? adapterFieldMap[key].description
: undefined
}
>{`${key}: `}</strong>
<span className="overflow-x-auto">{value}</span>
</ListItem>
))}
Expand Down Expand Up @@ -241,6 +266,7 @@ export default function Run() {
instance={instance}
requests={displayRequestsMap[instance.id]}
predictions={displayPredictionsMap[instance.id]}
metricFieldMap={metricFieldMap}
/>
))}
</div>
Expand Down Expand Up @@ -300,13 +326,10 @@ export default function Run() {
if (key === "name") {
return (
<td key={key}>
<StatNameDisplay stat={stat} />
<div className="text-sm text-gray-500">
{
/* eslint-disable-next-line @typescript-eslint/no-unsafe-member-access */
value.name
}
</div>
<StatNameDisplay
stat={stat}
metricFieldMap={metricFieldMap}
/>
</td>
);
}
Expand Down
Loading

0 comments on commit 514b28e

Please sign in to comment.