Skip to content

Commit e700cbb

Browse files
committed
fixes based on review
1 parent 156d5c2 commit e700cbb

5 files changed

Lines changed: 214 additions & 25 deletions

File tree

core/queryalgebra/evaluation/src/main/java/org/eclipse/rdf4j/query/algebra/evaluation/iterator/GroupIterator.java

Lines changed: 53 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@
6767
import org.eclipse.rdf4j.query.impl.EmptyBindingSet;
6868
import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateCollector;
6969
import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateFunction;
70+
import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateFunctionFactory;
71+
import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateNAryFunctionFactory;
7072
import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateProcessor;
7173
import org.eclipse.rdf4j.query.parser.sparql.aggregate.CustomAggregateFunctionRegistry;
7274
import org.eclipse.rdf4j.query.parser.sparql.aggregate.CustomAggregateNAryFunctionRegistry;
@@ -465,29 +467,22 @@ private void operate(BindingSet bs, Predicate<?> predicate, Object t) {
465467
ge.getName());
466468
} else if (operator instanceof AggregateFunctionCall) {
467469
var aggOperator = (AggregateFunctionCall) operator;
470+
int argumentCount = aggOperator.getArguments().size();
471+
var nAryFactory = CustomAggregateNAryFunctionRegistry.getInstance().get(aggOperator.getIRI());
472+
var unaryFactory = CustomAggregateFunctionRegistry.getInstance().get(aggOperator.getIRI());
468473

469-
if (aggOperator.getArguments().size() > 1) {
470-
Supplier<Predicate<List<Value>>> predicate = createDistinctTupleValueTest(operator);
471-
var factory = CustomAggregateNAryFunctionRegistry.getInstance().get(aggOperator.getIRI());
472-
473-
var function = factory.orElseThrow(
474-
() -> new QueryEvaluationException(
475-
"Unknown n-ary aggregate function '" + aggOperator.getIRI() + "'"))
476-
.buildFunction(precompileNAryArg(aggOperator));
477-
return new AggregatePredicateCollectorSupplier<>(function, predicate,
478-
() -> factory.get().getCollector(),
479-
ge.getName());
480-
} else {
481-
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
482-
var factory = CustomAggregateFunctionRegistry.getInstance().get(aggOperator.getIRI());
483-
484-
var function = factory.orElseThrow(
485-
() -> new QueryEvaluationException("Unknown aggregate function '" + aggOperator.getIRI() + "'"))
486-
.buildFunction(precompileNAryArg(aggOperator).asUnaryEvaluator(0));
487-
return new AggregatePredicateCollectorSupplier<>(function, predicate,
488-
() -> factory.get().getCollector(),
489-
ge.getName());
474+
if (argumentCount == 1 && unaryFactory.isPresent()) {
475+
return createUnaryCustomAggregate(ge, operator, aggOperator, unaryFactory.get());
476+
}
477+
if (nAryFactory.isPresent()) {
478+
validateNAryAggregateArity(aggOperator, nAryFactory.get(), argumentCount);
479+
return createNAryCustomAggregate(ge, operator, aggOperator, nAryFactory.get());
490480
}
481+
if (unaryFactory.isPresent()) {
482+
return createUnaryCustomAggregate(ge, operator, aggOperator, unaryFactory.get());
483+
}
484+
485+
throw new QueryEvaluationException("Unknown aggregate function '" + aggOperator.getIRI() + "'");
491486
}
492487

493488
return null;
@@ -517,6 +512,43 @@ private Supplier<Predicate<List<Value>>> createDistinctTupleValueTest(AggregateO
517512
return operator.isDistinct() ? DistinctTupleValues::new : ALWAYS_TRUE_TUPLE_VALUE_SUPPLIER;
518513
}
519514

515+
private AggregatePredicateCollectorSupplier<?, ?> createUnaryCustomAggregate(GroupElem ge,
516+
AggregateOperator operator,
517+
AggregateFunctionCall aggOperator, AggregateFunctionFactory factory) {
518+
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
519+
AggregateFunction function = factory.buildFunction(precompileNAryArg(aggOperator).asUnaryEvaluator(0));
520+
return new AggregatePredicateCollectorSupplier<>(function, predicate, factory::getCollector, ge.getName());
521+
}
522+
523+
private AggregatePredicateCollectorSupplier<?, ?> createNAryCustomAggregate(GroupElem ge,
524+
AggregateOperator operator,
525+
AggregateFunctionCall aggOperator, AggregateNAryFunctionFactory factory) {
526+
Supplier<Predicate<List<Value>>> predicate = createDistinctTupleValueTest(operator);
527+
var function = factory.buildFunction(precompileNAryArg(aggOperator));
528+
return new AggregatePredicateCollectorSupplier<>(function, predicate, factory::getCollector, ge.getName());
529+
}
530+
531+
private void validateNAryAggregateArity(AggregateFunctionCall aggregateOperator,
532+
AggregateNAryFunctionFactory factory,
533+
int argumentCount) {
534+
int minimum = factory.getMinNumberOfArguments();
535+
int maximum = factory.getMaxNumberOfArguments();
536+
if (minimum < 0 || maximum < minimum) {
537+
throw new QueryEvaluationException("Custom n-ary aggregate function '" + aggregateOperator.getIRI()
538+
+ "' has invalid arity declaration");
539+
}
540+
if (argumentCount < minimum) {
541+
throw new QueryEvaluationException(
542+
"Custom n-ary aggregate function '" + aggregateOperator.getIRI() + "' expects at least " + minimum
543+
+ " arguments, got " + argumentCount);
544+
}
545+
if (argumentCount > maximum) {
546+
throw new QueryEvaluationException(
547+
"Custom n-ary aggregate function '" + aggregateOperator.getIRI() + "' expects at most " + maximum
548+
+ " arguments, got " + argumentCount);
549+
}
550+
}
551+
520552
private QueryStepEvaluator precompileUnaryArg(AggregateOperator operator) {
521553
ValueExpr ve = ((UnaryValueOperator) operator).getArg();
522554
return new QueryStepEvaluator(strategy.precompile(ve, context));

core/queryalgebra/evaluation/src/test/java/org/eclipse/rdf4j/query/algebra/evaluation/iterator/GroupIteratorTest.java

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,23 @@ public void testCustomNAryAggregateFunction_Empty() throws QueryEvaluationExcept
316316
}
317317
}
318318

319+
@Test
320+
public void testCustomNAryAggregateFunction_SingleArg_Nonempty() throws QueryEvaluationException {
321+
AggregateNAryFunctionFactory nAryFactory = new SingleArgAggregateNAryFunctionFactory();
322+
CustomAggregateNAryFunctionRegistry.getInstance().add(nAryFactory);
323+
try {
324+
Group group = new Group(NONEMPTY_ASSIGNMENT);
325+
group.addGroupElement(new GroupElem("narySingleArgSum",
326+
new AggregateFunctionCall(List.of(Var.of("a")), nAryFactory.getIri(), false)));
327+
try (GroupIterator gi = new GroupIterator(EVALUATOR, group, EmptyBindingSet.getInstance(), CONTEXT)) {
328+
assertThat(gi.next().getBinding("narySingleArgSum").getValue())
329+
.isEqualTo(VF.createLiteral("45", XSD.INTEGER));
330+
}
331+
} finally {
332+
CustomAggregateNAryFunctionRegistry.getInstance().remove(nAryFactory);
333+
}
334+
}
335+
319336
@Test
320337
public void testCustomNAryAggregateFunction_Empty_DistinctTuplePredicateInvoked() throws QueryEvaluationException {
321338
AggregateNAryFunctionFactory nAryFactory = new DistinctTupleTouchingAggregateNAryFunctionFactory();
@@ -522,6 +539,59 @@ public SumCollector getCollector() {
522539
}
523540
}
524541

542+
private static final class SingleArgAggregateNAryFunctionFactory implements AggregateNAryFunctionFactory {
543+
@Override
544+
public String getIri() {
545+
return "https://www.rdf4j.org/aggregate#nary-single-arg";
546+
}
547+
548+
@Override
549+
public int getMinNumberOfArguments() {
550+
return 1;
551+
}
552+
553+
@Override
554+
public int getMaxNumberOfArguments() {
555+
return 1;
556+
}
557+
558+
@Override
559+
public AggregateNAryFunction<SumCollector, Value> buildFunction(
560+
BiFunction<Integer, BindingSet, Value> evaluationStepByIndex) {
561+
return new AggregateNAryFunction<>(evaluationStepByIndex) {
562+
private ValueExprEvaluationException typeError = null;
563+
564+
@Override
565+
public void processAggregate(BindingSet bindingSet, Predicate<List<Value>> distinctValue,
566+
SumCollector sum)
567+
throws QueryEvaluationException {
568+
if (typeError != null) {
569+
return;
570+
}
571+
Value v = evaluate(0, bindingSet);
572+
if (v instanceof Literal) {
573+
if (distinctValue.test(List.of(v))) {
574+
Literal nextLiteral = (Literal) v;
575+
if (nextLiteral.getDatatype() != null
576+
&& XMLDatatypeUtil.isNumericDatatype(nextLiteral.getDatatype())) {
577+
sum.value = MathUtil.compute(sum.value, nextLiteral, MathExpr.MathOp.PLUS);
578+
} else {
579+
typeError = new ValueExprEvaluationException("not a number: " + v);
580+
}
581+
} else {
582+
typeError = new ValueExprEvaluationException("not a number: " + v);
583+
}
584+
}
585+
}
586+
};
587+
}
588+
589+
@Override
590+
public SumCollector getCollector() {
591+
return new SumCollector();
592+
}
593+
}
594+
525595
/**
526596
* Dummy collector to verify custom aggregate functions
527597
*/

core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/TupleExprBuilder.java

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
*
99
* SPDX-License-Identifier: BSD-3-Clause
1010
*******************************************************************************/
11+
// Some portions generated by Codex
1112
package org.eclipse.rdf4j.query.parser.sparql;
1213

1314
import java.util.ArrayList;
@@ -105,6 +106,7 @@
105106
import org.eclipse.rdf4j.query.algebra.helpers.collectors.StatementPatternCollector;
106107
import org.eclipse.rdf4j.query.algebra.helpers.collectors.VarNameCollector;
107108
import org.eclipse.rdf4j.query.impl.ListBindingSet;
109+
import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateNAryFunctionFactory;
108110
import org.eclipse.rdf4j.query.parser.sparql.aggregate.CustomAggregateFunctionRegistry;
109111
import org.eclipse.rdf4j.query.parser.sparql.aggregate.CustomAggregateNAryFunctionRegistry;
110112
import org.eclipse.rdf4j.query.parser.sparql.ast.ASTAbs;
@@ -1995,18 +1997,25 @@ public MathExpr visit(ASTMath node, Object data) throws VisitorException {
19951997
public Object visit(ASTFunctionCall node, Object data) throws VisitorException {
19961998
ValueConstant uriNode = (ValueConstant) node.jjtGetChild(0).jjtAccept(this, null);
19971999
IRI functionURI = (IRI) uriNode.getValue();
1998-
if (CustomAggregateNAryFunctionRegistry.getInstance().has(functionURI.stringValue())
1999-
|| CustomAggregateFunctionRegistry.getInstance().has(functionURI.stringValue()) || node.isDistinct()) {
2000-
AggregateFunctionCall aggregateCall = new AggregateFunctionCall(functionURI.stringValue(),
2000+
String functionIri = functionURI.stringValue();
2001+
var nAryFactory = CustomAggregateNAryFunctionRegistry.getInstance().get(functionIri);
2002+
boolean hasNAryFactory = nAryFactory.isPresent();
2003+
boolean hasUnaryFactory = CustomAggregateFunctionRegistry.getInstance().has(functionIri);
2004+
if (hasNAryFactory || hasUnaryFactory || node.isDistinct()) {
2005+
AggregateFunctionCall aggregateCall = new AggregateFunctionCall(functionIri,
20012006
node.isDistinct());
20022007
for (int i = 1; i < node.jjtGetNumChildren(); i++) {
20032008
Node argNode = node.jjtGetChild(i);
20042009
aggregateCall.addArgument(castToValueExpr(argNode.jjtAccept(this, null)));
20052010
}
20062011

2007-
if (aggregateCall.getArguments().isEmpty()) {
2012+
int argumentCount = aggregateCall.getArguments().size();
2013+
if (argumentCount == 0) {
20082014
throw new IllegalArgumentException("Aggregate function calls must have at least one argument");
20092015
}
2016+
if (hasNAryFactory && (argumentCount > 1 || !hasUnaryFactory)) {
2017+
validateNAryAggregateArity(functionIri, nAryFactory.get(), argumentCount);
2018+
}
20102019

20112020
return aggregateCall;
20122021
} else {
@@ -2022,6 +2031,24 @@ public Object visit(ASTFunctionCall node, Object data) throws VisitorException {
20222031
}
20232032
}
20242033

2034+
private void validateNAryAggregateArity(String functionIri, AggregateNAryFunctionFactory factory,
2035+
int argumentCount) {
2036+
int minimum = factory.getMinNumberOfArguments();
2037+
int maximum = factory.getMaxNumberOfArguments();
2038+
if (minimum < 0 || maximum < minimum) {
2039+
throw new IllegalArgumentException("Custom n-ary aggregate function '" + functionIri
2040+
+ "' has invalid arity declaration");
2041+
}
2042+
if (argumentCount < minimum) {
2043+
throw new IllegalArgumentException("Custom n-ary aggregate function calls must have at least " + minimum
2044+
+ " arguments");
2045+
}
2046+
if (argumentCount > maximum) {
2047+
throw new IllegalArgumentException("Custom n-ary aggregate function calls must have at most " + maximum
2048+
+ " arguments");
2049+
}
2050+
}
2051+
20252052
@Override
20262053
public FunctionCall visit(ASTEncodeForURI node, Object data) throws VisitorException {
20272054
return createFunctionCall(FN.ENCODE_FOR_URI.stringValue(), node, 1, 1);

core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/aggregate/AggregateNAryFunctionFactory.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
*
99
* SPDX-License-Identifier: BSD-3-Clause
1010
*******************************************************************************/
11+
// Some portions generated by Codex
1112
package org.eclipse.rdf4j.query.parser.sparql.aggregate;
1213

1314
import java.util.function.BiFunction;
@@ -29,6 +30,24 @@ public interface AggregateNAryFunctionFactory {
2930
*/
3031
String getIri();
3132

33+
/**
34+
* Inclusive lower bound on the number of arguments this aggregate accepts. Defaults to {@code 2}.
35+
*
36+
* @return minimum accepted number of arguments
37+
*/
38+
default int getMinNumberOfArguments() {
39+
return 2;
40+
}
41+
42+
/**
43+
* Inclusive upper bound on the number of arguments this aggregate accepts. Defaults to
* {@link Integer#MAX_VALUE}, i.e. effectively unbounded.
44+
*
45+
* @return maximum accepted number of arguments
46+
*/
47+
default int getMaxNumberOfArguments() {
48+
return Integer.MAX_VALUE;
49+
}
50+
3251
/**
3352
* Builds an aggregate function with input evaluation step
3453
*

core/queryparser/sparql/src/test/java/org/eclipse/rdf4j/query/parser/sparql/SPARQLParserTest.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
*
99
* SPDX-License-Identifier: BSD-3-Clause
1010
*******************************************************************************/
11+
// Some portions generated by Codex
1112
package org.eclipse.rdf4j.query.parser.sparql;
1213

1314
import static org.assertj.core.api.Assertions.assertThat;
@@ -33,6 +34,7 @@
3334
import java.util.Iterator;
3435
import java.util.List;
3536
import java.util.Set;
37+
import java.util.function.BiFunction;
3638
import java.util.function.Function;
3739

3840
import org.eclipse.rdf4j.model.Model;
@@ -73,7 +75,10 @@
7375
import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateCollector;
7476
import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateFunction;
7577
import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateFunctionFactory;
78+
import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateNAryFunction;
79+
import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateNAryFunctionFactory;
7680
import org.eclipse.rdf4j.query.parser.sparql.aggregate.CustomAggregateFunctionRegistry;
81+
import org.eclipse.rdf4j.query.parser.sparql.aggregate.CustomAggregateNAryFunctionRegistry;
7782
import org.eclipse.rdf4j.rio.RDFFormat;
7883
import org.eclipse.rdf4j.rio.Rio;
7984
import org.junit.jupiter.api.AfterEach;
@@ -639,6 +644,23 @@ public void testProjectionHandling_FunctionCallWithoutArgsFails() {
639644
}
640645
}
641646

647+
@Test
648+
public void testProjectionHandling_NAryFunctionCallWithTooFewArgsFails() {
649+
var factory = buildDummyNAryFactory();
650+
String query = "prefix rj: <https://www.rdf4j.org/aggregate#>"
651+
+ "SELECT (rj:nary(?o) AS ?o1) \n"
652+
+ "WHERE {\n"
653+
+ " ?s ?p ?o \n"
654+
+ "} GROUP BY ?s ?o";
655+
try {
656+
CustomAggregateNAryFunctionRegistry.getInstance().add(factory);
657+
assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> parser.parseQuery(query, null))
658+
.withMessageStartingWith("Custom n-ary aggregate function calls must have at least 2 arguments");
659+
} finally {
660+
CustomAggregateNAryFunctionRegistry.getInstance().remove(factory);
661+
}
662+
}
663+
642664
@Test
643665
public void testGroupByProjectionHandling_Aggregate_NonSimpleExpr() {
644666
String query = "SELECT (COUNT(?s) as ?count) (?o + ?s AS ?o1) \n"
@@ -1198,6 +1220,25 @@ public AggregateCollector getCollector() {
11981220
};
11991221
}
12001222

1223+
private AggregateNAryFunctionFactory buildDummyNAryFactory() {
1224+
return new AggregateNAryFunctionFactory() {
1225+
@Override
1226+
public String getIri() {
1227+
return "https://www.rdf4j.org/aggregate#nary";
1228+
}
1229+
1230+
@Override
1231+
public AggregateNAryFunction buildFunction(BiFunction<Integer, BindingSet, Value> evaluationStepByIndex) {
1232+
return null;
1233+
}
1234+
1235+
@Override
1236+
public AggregateCollector getCollector() {
1237+
return null;
1238+
}
1239+
};
1240+
}
1241+
12011242
private void verifySerializable(QueryModelNode tupleExpr) {
12021243
byte[] bytes = objectToBytes(tupleExpr);
12031244
QueryModelNode parsed = (QueryModelNode) bytesToObject(bytes);

0 commit comments

Comments
 (0)