Skip to content

Commit fed0f8c

Browse files
committed
Aggregation improvements for the legend
1 parent 78c789a commit fed0f8c

File tree

4 files changed

+143
-45
lines changed

4 files changed

+143
-45
lines changed

apps/webapp/app/components/code/QueryResultsChart.tsx

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ import { memo, useMemo } from "react";
33
import type { ChartConfig } from "~/components/primitives/charts/Chart";
44
import { Chart } from "~/components/primitives/charts/ChartCompound";
55
import { Paragraph } from "../primitives/Paragraph";
6-
import { AggregationType, ChartConfiguration } from "../metrics/QueryWidget";
6+
import type { AggregationType, ChartConfiguration } from "../metrics/QueryWidget";
7+
import { aggregateValues } from "../primitives/charts/aggregation";
78
import { getRunStatusHexColor } from "~/components/runs/v3/TaskRunStatus";
89
import { getSeriesColor } from "./chartColors";
910

@@ -671,25 +672,6 @@ function toNumber(value: unknown): number {
671672
return 0;
672673
}
673674

674-
/**
675-
* Aggregate an array of numbers using the specified aggregation function
676-
*/
677-
function aggregateValues(values: number[], aggregation: AggregationType): number {
678-
if (values.length === 0) return 0;
679-
switch (aggregation) {
680-
case "sum":
681-
return values.reduce((a, b) => a + b, 0);
682-
case "avg":
683-
return values.reduce((a, b) => a + b, 0) / values.length;
684-
case "count":
685-
return values.length;
686-
case "min":
687-
return Math.min(...values);
688-
case "max":
689-
return Math.max(...values);
690-
}
691-
}
692-
693675
/**
694676
* Sort data array by a specified column
695677
*/
@@ -1032,6 +1014,7 @@ export const QueryResultsChart = memo(function QueryResultsChart({
10321014
labelFormatter={legendLabelFormatter}
10331015
showLegend={showLegend}
10341016
maxLegendItems={fullLegend ? Infinity : 5}
1017+
legendAggregation={config.aggregation}
10351018
minHeight="300px"
10361019
fillContainer
10371020
onViewAllLegendItems={onViewAllLegendItems}
@@ -1058,6 +1041,7 @@ export const QueryResultsChart = memo(function QueryResultsChart({
10581041
labelFormatter={legendLabelFormatter}
10591042
showLegend={showLegend}
10601043
maxLegendItems={fullLegend ? Infinity : 5}
1044+
legendAggregation={config.aggregation}
10611045
minHeight="300px"
10621046
fillContainer
10631047
onViewAllLegendItems={onViewAllLegendItems}

apps/webapp/app/components/primitives/charts/ChartLegendCompound.tsx

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,30 @@
11
import React, { useMemo } from "react";
2+
import type { AggregationType } from "~/components/metrics/QueryWidget";
23
import { useChartContext } from "./ChartContext";
34
import { useSeriesTotal } from "./ChartRoot";
5+
import { aggregateValues } from "./aggregation";
46
import { cn } from "~/utils/cn";
57
import { AnimatedNumber } from "../AnimatedNumber";
68

9+
const aggregationLabels: Record<AggregationType, string> = {
10+
sum: "Sum",
11+
avg: "Average",
12+
count: "Count",
13+
min: "Min",
14+
max: "Max",
15+
};
16+
717
export type ChartLegendCompoundProps = {
818
/** Maximum number of legend items to show before collapsing */
919
maxItems?: number;
1020
/** Hide the legend entirely (useful for conditional rendering) */
1121
hidden?: boolean;
1222
/** Additional className */
1323
className?: string;
14-
/** Label for the total row */
24+
/** Label for the total row (derived from aggregation when not provided) */
1525
totalLabel?: string;
26+
/** Aggregation method – controls the header label and how totals are computed */
27+
aggregation?: AggregationType;
1628
/** Callback when "View all" button is clicked */
1729
onViewAllLegendItems?: () => void;
1830
/** When true, constrains legend to max 50% height with scrolling */
@@ -35,45 +47,59 @@ export function ChartLegendCompound({
3547
maxItems = Infinity,
3648
hidden = false,
3749
className,
38-
totalLabel = "Total",
50+
totalLabel,
51+
aggregation,
3952
onViewAllLegendItems,
4053
scrollable = false,
4154
}: ChartLegendCompoundProps) {
4255
const { config, dataKey, dataKeys, highlight, labelFormatter } = useChartContext();
43-
const totals = useSeriesTotal();
56+
const totals = useSeriesTotal(aggregation);
4457

45-
// Calculate grand total (sum of all series totals)
58+
// Derive the effective label from the aggregation type when no explicit label is provided
59+
const effectiveTotalLabel = totalLabel ?? (aggregation ? aggregationLabels[aggregation] : "Total");
60+
61+
// Calculate grand total by aggregating across all per-series values
4662
const grandTotal = useMemo(() => {
47-
return dataKeys.reduce((sum, key) => sum + (totals[key] || 0), 0);
48-
}, [totals, dataKeys]);
63+
const values = dataKeys.map((key) => totals[key] || 0);
64+
if (!aggregation) {
65+
// Default: sum
66+
return values.reduce((a, b) => a + b, 0);
67+
}
68+
return aggregateValues(values, aggregation);
69+
}, [totals, dataKeys, aggregation]);
4970

5071
// Calculate current total based on hover state
5172
const currentTotal = useMemo(() => {
5273
if (!highlight.activePayload?.length) return grandTotal;
5374

54-
// Sum all values from the hovered data point
55-
return highlight.activePayload.reduce((sum, item) => {
56-
if (item.value !== undefined && dataKeys.includes(item.dataKey as string)) {
57-
return sum + (Number(item.value) || 0);
58-
}
59-
return sum;
60-
}, 0);
61-
}, [highlight.activePayload, grandTotal, dataKeys]);
75+
// Collect all series values from the hovered data point
76+
const values = highlight.activePayload
77+
.filter((item) => item.value !== undefined && dataKeys.includes(item.dataKey as string))
78+
.map((item) => Number(item.value) || 0);
79+
80+
if (values.length === 0) return 0;
81+
82+
if (!aggregation) {
83+
// Default: sum
84+
return values.reduce((a, b) => a + b, 0);
85+
}
86+
return aggregateValues(values, aggregation);
87+
}, [highlight.activePayload, grandTotal, dataKeys, aggregation]);
6288

63-
// Get the label for the total row - x-axis value when hovering, totalLabel otherwise
89+
// Get the label for the total row - x-axis value when hovering, effectiveTotalLabel otherwise
6490
const currentTotalLabel = useMemo(() => {
65-
if (!highlight.activePayload?.length) return totalLabel;
91+
if (!highlight.activePayload?.length) return effectiveTotalLabel;
6692

6793
// Get the x-axis label from the payload's original data
6894
const firstPayloadItem = highlight.activePayload[0];
6995
const xAxisValue = firstPayloadItem?.payload?.[dataKey];
7096

71-
if (xAxisValue === undefined) return totalLabel;
97+
if (xAxisValue === undefined) return effectiveTotalLabel;
7298

7399
// Apply the formatter if provided, otherwise just stringify the value
74100
const stringValue = String(xAxisValue);
75101
return labelFormatter ? labelFormatter(stringValue) : stringValue;
76-
}, [highlight.activePayload, dataKey, totalLabel, labelFormatter]);
102+
}, [highlight.activePayload, dataKey, effectiveTotalLabel, labelFormatter]);
77103

78104
// Get current data for the legend based on hover state
79105
const currentData = useMemo(() => {

apps/webapp/app/components/primitives/charts/ChartRoot.tsx

Lines changed: 72 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import React, { useMemo } from "react";
22
import type * as RechartsPrimitive from "recharts";
3+
import type { AggregationType } from "~/components/metrics/QueryWidget";
34
import { ChartContainer, type ChartConfig, type ChartState } from "./Chart";
45
import { ChartProvider, useChartContext, type LabelFormatter } from "./ChartContext";
56
import { ChartLegendCompound } from "./ChartLegendCompound";
@@ -29,6 +30,8 @@ export type ChartRootProps = {
2930
maxLegendItems?: number;
3031
/** Label for the total row in the legend */
3132
legendTotalLabel?: string;
33+
/** Aggregation method used by the legend to compute totals (defaults to sum behavior) */
34+
legendAggregation?: AggregationType;
3235
/** Callback when "View all" legend button is clicked */
3336
onViewAllLegendItems?: () => void;
3437
/** When true, constrains legend to max 50% height with scrolling */
@@ -73,6 +76,7 @@ export function ChartRoot({
7376
showLegend = false,
7477
maxLegendItems = 5,
7578
legendTotalLabel,
79+
legendAggregation,
7680
onViewAllLegendItems,
7781
legendScrollable = false,
7882
fillContainer = false,
@@ -96,6 +100,7 @@ export function ChartRoot({
96100
showLegend={showLegend}
97101
maxLegendItems={maxLegendItems}
98102
legendTotalLabel={legendTotalLabel}
103+
legendAggregation={legendAggregation}
99104
onViewAllLegendItems={onViewAllLegendItems}
100105
legendScrollable={legendScrollable}
101106
fillContainer={fillContainer}
@@ -112,6 +117,7 @@ type ChartRootInnerProps = {
112117
showLegend?: boolean;
113118
maxLegendItems?: number;
114119
legendTotalLabel?: string;
120+
legendAggregation?: AggregationType;
115121
onViewAllLegendItems?: () => void;
116122
legendScrollable?: boolean;
117123
fillContainer?: boolean;
@@ -124,6 +130,7 @@ function ChartRootInner({
124130
showLegend = false,
125131
maxLegendItems = 5,
126132
legendTotalLabel,
133+
legendAggregation,
127134
onViewAllLegendItems,
128135
legendScrollable = false,
129136
fillContainer = false,
@@ -165,6 +172,7 @@ function ChartRootInner({
165172
<ChartLegendCompound
166173
maxItems={maxLegendItems}
167174
totalLabel={legendTotalLabel}
175+
aggregation={legendAggregation}
168176
onViewAllLegendItems={onViewAllLegendItems}
169177
scrollable={legendScrollable}
170178
/>
@@ -194,18 +202,75 @@ export function useHasNoData(): boolean {
194202
}
195203

196204
/**
197-
* Hook to calculate totals for each series across all data points.
205+
* Hook to calculate aggregated values for each series across all data points.
206+
* When no aggregation is provided, defaults to sum (original behavior).
198207
* Useful for legend displays.
199208
*/
200-
export function useSeriesTotal(): Record<string, number> {
209+
export function useSeriesTotal(aggregation?: AggregationType): Record<string, number> {
201210
const { data, dataKeys } = useChartContext();
202211

203212
return useMemo(() => {
204-
return data.reduce((acc, item) => {
213+
// Sum (default) and count use additive accumulation
214+
if (!aggregation || aggregation === "sum" || aggregation === "count") {
215+
return data.reduce(
216+
(acc, item) => {
217+
for (const seriesKey of dataKeys) {
218+
acc[seriesKey] = (acc[seriesKey] || 0) + Number(item[seriesKey] || 0);
219+
}
220+
return acc;
221+
},
222+
{} as Record<string, number>
223+
);
224+
}
225+
226+
if (aggregation === "avg") {
227+
const sums: Record<string, number> = {};
228+
const counts: Record<string, number> = {};
229+
for (const item of data) {
230+
for (const seriesKey of dataKeys) {
231+
const val = Number(item[seriesKey] || 0);
232+
sums[seriesKey] = (sums[seriesKey] || 0) + val;
233+
counts[seriesKey] = (counts[seriesKey] || 0) + 1;
234+
}
235+
}
236+
const result: Record<string, number> = {};
237+
for (const key of dataKeys) {
238+
result[key] = counts[key] ? sums[key]! / counts[key]! : 0;
239+
}
240+
return result;
241+
}
242+
243+
if (aggregation === "min") {
244+
const result: Record<string, number> = {};
245+
for (const item of data) {
246+
for (const seriesKey of dataKeys) {
247+
const val = Number(item[seriesKey] || 0);
248+
if (result[seriesKey] === undefined || val < result[seriesKey]) {
249+
result[seriesKey] = val;
250+
}
251+
}
252+
}
253+
// Default to 0 for series with no data
254+
for (const key of dataKeys) {
255+
if (result[key] === undefined) result[key] = 0;
256+
}
257+
return result;
258+
}
259+
260+
// aggregation === "max"
261+
const result: Record<string, number> = {};
262+
for (const item of data) {
205263
for (const seriesKey of dataKeys) {
206-
acc[seriesKey] = (acc[seriesKey] || 0) + Number(item[seriesKey] || 0);
264+
const val = Number(item[seriesKey] || 0);
265+
if (result[seriesKey] === undefined || val > result[seriesKey]) {
266+
result[seriesKey] = val;
267+
}
207268
}
208-
return acc;
209-
}, {} as Record<string, number>);
210-
}, [data, dataKeys]);
269+
}
270+
// Default to 0 for series with no data
271+
for (const key of dataKeys) {
272+
if (result[key] === undefined) result[key] = 0;
273+
}
274+
return result;
275+
}, [data, dataKeys, aggregation]);
211276
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import type { AggregationType } from "~/components/metrics/QueryWidget";
2+
3+
/**
4+
* Aggregate an array of numbers using the specified aggregation function.
5+
*
6+
* Shared utility so both QueryResultsChart (data transformation) and chart
7+
* legend components can reuse the same logic without circular imports.
8+
*/
9+
export function aggregateValues(values: number[], aggregation: AggregationType): number {
10+
if (values.length === 0) return 0;
11+
switch (aggregation) {
12+
case "sum":
13+
return values.reduce((a, b) => a + b, 0);
14+
case "avg":
15+
return values.reduce((a, b) => a + b, 0) / values.length;
16+
case "count":
17+
return values.length;
18+
case "min":
19+
return Math.min(...values);
20+
case "max":
21+
return Math.max(...values);
22+
}
23+
}

0 commit comments

Comments
 (0)