Skip to content

Commit c393a01

Browse files
authored
Merge pull request #3161 from perspective-dev/weighted-mean-bug
Fix multi-column aggregates by expressions
2 parents 7e6d181 + 961bdd7 commit c393a01

3 files changed

Lines changed: 67 additions & 11 deletions

File tree

rust/perspective-js/test/js/aggregates.spec.js

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,25 @@ const std = (nums) => {
221221
table.delete();
222222
});
223223

224+
test("['z'], weighted mean by expression", async function () {
225+
var table = await perspective.table(data2);
226+
var view = await table.view({
227+
group_by: ["z"],
228+
aggregates: { x: ["weighted mean", ["q"]] },
229+
expressions: { q: `"y"` },
230+
columns: ["x"],
231+
});
232+
var answer = [
233+
{ __ROW_PATH__: [], x: 2.8333333333333335 },
234+
{ __ROW_PATH__: [false], x: 3.3333333333333335 },
235+
{ __ROW_PATH__: [true], x: 2.3333333333333335 },
236+
];
237+
let result = await view.to_json();
238+
expect(result).toEqual(answer);
239+
view.delete();
240+
table.delete();
241+
});
242+
224243
test("['z'], weighted mean on a table created from schema should return valid values after update", async function () {
225244
const table = await perspective.table({
226245
x: "integer",

rust/perspective-python/perspective/tests/table/test_view.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,44 @@ def test_view_aggregate_weighted_mean(self):
455455
{"__ROW_PATH__": ["a"], "y": (1.0 * 200 + 2 * 100) / (1.0 + 2)},
456456
]
457457

458+
def test_view_aggregate_weighted_mean_by_expression(self):
459+
data = [
460+
{"a": "a", "x": 1, "y": 200},
461+
{"a": "a", "x": 2, "y": 100},
462+
{"a": "a", "x": 3, "y": None},
463+
]
464+
tbl = Table(data)
465+
view = tbl.view(
466+
aggregates={"y": ("weighted mean", ["z"])},
467+
group_by=["a"],
468+
columns=["y", "z"],
469+
expressions={"z": '"x"'},
470+
)
471+
472+
assert view.to_records() == [
473+
{"__ROW_PATH__": [], "y": (1.0 * 200 + 2 * 100) / (1.0 + 2), "z": 6},
474+
{"__ROW_PATH__": ["a"], "y": (1.0 * 200 + 2 * 100) / (1.0 + 2), "z": 6},
475+
]
476+
477+
def test_view_aggregate_weighted_mean_by_expression_without_column_ref(self):
478+
data = [
479+
{"a": "a", "x": 1, "y": 200},
480+
{"a": "a", "x": 2, "y": 100},
481+
{"a": "a", "x": 3, "y": None},
482+
]
483+
tbl = Table(data)
484+
view = tbl.view(
485+
aggregates={"y": ("weighted mean", ["z"])},
486+
group_by=["a"],
487+
columns=["y"],
488+
expressions={"z": '"x" + 1'},
489+
)
490+
491+
assert view.to_records() == [
492+
{"__ROW_PATH__": [], "y": (2 * 200 + 3 * 100) / (2 + 3)},
493+
{"__ROW_PATH__": ["a"], "y": (2 * 200 + 3 * 100) / (2 + 3)},
494+
]
495+
458496
def test_view_aggregate_weighted_mean_with_negative_weights(self):
459497
data = [
460498
{"a": "a", "x": 1, "y": 200},

rust/perspective-server/cpp/perspective/src/cpp/view_config.cpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -151,17 +151,16 @@ t_view_config::get_used_expressions() {
151151
std::inserter(used_cols, used_cols.end())
152152
);
153153

154-
std::copy(
155-
m_row_pivots.begin(),
156-
m_row_pivots.end(),
157-
std::inserter(used_cols, used_cols.end())
158-
);
159-
160-
std::copy(
161-
m_row_pivots.begin(),
162-
m_row_pivots.end(),
163-
std::inserter(used_cols, used_cols.end())
164-
);
154+
for (const auto& agg : m_aggregates) {
155+
if (std::find(m_columns.begin(), m_columns.end(), agg.first)
156+
== m_columns.end()) {
157+
continue;
158+
}
159+
const auto& aggregate = agg.second;
160+
for (std::size_t i = 1; i < aggregate.size(); ++i) {
161+
used_cols.insert(aggregate[i]);
162+
}
163+
}
165164

166165
auto iter = std::remove_if(
167166
exprs.begin(),

0 commit comments

Comments
 (0)