diff --git a/core/queryalgebra/evaluation/src/main/java/org/eclipse/rdf4j/query/algebra/evaluation/iterator/GroupIterator.java b/core/queryalgebra/evaluation/src/main/java/org/eclipse/rdf4j/query/algebra/evaluation/iterator/GroupIterator.java index d698e83d9cc..8d83907246b 100644 --- a/core/queryalgebra/evaluation/src/main/java/org/eclipse/rdf4j/query/algebra/evaluation/iterator/GroupIterator.java +++ b/core/queryalgebra/evaluation/src/main/java/org/eclipse/rdf4j/query/algebra/evaluation/iterator/GroupIterator.java @@ -8,6 +8,7 @@ * * SPDX-License-Identifier: BSD-3-Clause *******************************************************************************/ +// Some portions generated by Codex package org.eclipse.rdf4j.query.algebra.evaluation.iterator; import java.util.ArrayList; @@ -19,6 +20,7 @@ import java.util.Random; import java.util.Set; import java.util.function.BiConsumer; +import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Predicate; import java.util.function.Supplier; @@ -50,6 +52,7 @@ import org.eclipse.rdf4j.query.algebra.MathExpr.MathOp; import org.eclipse.rdf4j.query.algebra.Max; import org.eclipse.rdf4j.query.algebra.Min; +import org.eclipse.rdf4j.query.algebra.NAryValueOperator; import org.eclipse.rdf4j.query.algebra.Sample; import org.eclipse.rdf4j.query.algebra.Sum; import org.eclipse.rdf4j.query.algebra.UnaryValueOperator; @@ -65,7 +68,11 @@ import org.eclipse.rdf4j.query.impl.EmptyBindingSet; import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateCollector; import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateFunction; +import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateFunctionFactory; +import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateNAryFunctionFactory; +import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateProcessor; import org.eclipse.rdf4j.query.parser.sparql.aggregate.CustomAggregateFunctionRegistry; +import 
org.eclipse.rdf4j.query.parser.sparql.aggregate.CustomAggregateNAryFunctionRegistry; /** * @author David Huynh @@ -349,7 +356,7 @@ private List emptySolutionSpecialCase(List { public final String name; - private final AggregateFunction agg; + private final AggregateProcessor agg; private final Supplier> makePotentialDistinctTest; private final Supplier makeAggregateCollector; - public AggregatePredicateCollectorSupplier(AggregateFunction agg, + public AggregatePredicateCollectorSupplier(AggregateProcessor agg, Supplier> makePotentialDistinctTest, Supplier makeAggregateCollector, String name) { super(); this.agg = agg; @@ -431,6 +438,7 @@ private void operate(BindingSet bs, Predicate predicate, Object t) { private static final Predicate ALWAYS_TRUE_VALUE = t -> true; private static final Predicate ALWAYS_FALSE_VALUE = t -> false; private static final Supplier> ALWAYS_TRUE_VALUE_SUPPLIER = () -> ALWAYS_TRUE_VALUE; + private static final Supplier>> ALWAYS_TRUE_TUPLE_VALUE_SUPPLIER = () -> t -> true; private AggregatePredicateCollectorSupplier create(GroupElem ge, ValueFactory vf) throws QueryEvaluationException { @@ -444,57 +452,68 @@ private void operate(BindingSet bs, Predicate predicate, Object t) { return new AggregatePredicateCollectorSupplier<>(wildCardCountAggregate, potentialDistinctTest, () -> new CountCollector(vf), ge.getName()); } else { - QueryStepEvaluator f = precompileArg(operator); + QueryStepEvaluator f = precompileUnaryArg(operator); CountAggregate agg = new CountAggregate(f); - Supplier> predicate = createDistinctValueTest(operator); + Supplier> predicate = createDistinctSingleValueTest(operator); return new AggregatePredicateCollectorSupplier<>(agg, predicate, () -> new CountCollector(vf), ge.getName()); } } else if (operator instanceof Min) { - MinAggregate agg = new MinAggregate(precompileArg(operator), shouldValueComparisonBeStrict()); - Supplier> predicate = createDistinctValueTest(operator); + MinAggregate agg = new 
MinAggregate(precompileUnaryArg(operator), shouldValueComparisonBeStrict()); + Supplier> predicate = createDistinctSingleValueTest(operator); return new AggregatePredicateCollectorSupplier<>(agg, predicate, ValueCollector::new, ge.getName()); } else if (operator instanceof Max) { - MaxAggregate agg = new MaxAggregate(precompileArg(operator), shouldValueComparisonBeStrict()); - Supplier> predicate = createDistinctValueTest(operator); + MaxAggregate agg = new MaxAggregate(precompileUnaryArg(operator), shouldValueComparisonBeStrict()); + Supplier> predicate = createDistinctSingleValueTest(operator); return new AggregatePredicateCollectorSupplier<>(agg, predicate, ValueCollector::new, ge.getName()); } else if (operator instanceof Sum) { - SumAggregate agg = new SumAggregate(precompileArg(operator)); - Supplier> predicate = createDistinctValueTest(operator); + SumAggregate agg = new SumAggregate(precompileUnaryArg(operator)); + Supplier> predicate = createDistinctSingleValueTest(operator); return new AggregatePredicateCollectorSupplier<>(agg, predicate, () -> new IntegerCollector(vf), ge.getName()); } else if (operator instanceof Avg) { - AvgAggregate agg = new AvgAggregate(precompileArg(operator)); - Supplier> predicate = createDistinctValueTest(operator); + AvgAggregate agg = new AvgAggregate(precompileUnaryArg(operator)); + Supplier> predicate = createDistinctSingleValueTest(operator); return new AggregatePredicateCollectorSupplier<>(agg, predicate, () -> new AvgCollector(vf), ge.getName()); } else if (operator instanceof Sample) { - SampleAggregate agg = new SampleAggregate(precompileArg(operator)); - Supplier> predicate = createDistinctValueTest(operator); + SampleAggregate agg = new SampleAggregate(precompileUnaryArg(operator)); + Supplier> predicate = createDistinctSingleValueTest(operator); return new AggregatePredicateCollectorSupplier<>(agg, predicate, SampleCollector::new, ge.getName()); } else if (operator instanceof GroupConcat) { ValueExpr separatorExpr = 
((GroupConcat) operator).getSeparator(); ConcatAggregate agg; if (separatorExpr != null) { Value separatorValue = strategy.evaluate(separatorExpr, parentBindings); - agg = new ConcatAggregate(precompileArg(operator), separatorValue.stringValue()); + agg = new ConcatAggregate(precompileUnaryArg(operator), separatorValue.stringValue()); } else { - agg = new ConcatAggregate(precompileArg(operator)); + agg = new ConcatAggregate(precompileUnaryArg(operator)); } - Supplier> predicate = createDistinctValueTest(operator); + Supplier> predicate = createDistinctSingleValueTest(operator); return new AggregatePredicateCollectorSupplier<>(agg, predicate, () -> new StringBuilderCollector(vf), ge.getName()); } else if (operator instanceof AggregateFunctionCall) { var aggOperator = (AggregateFunctionCall) operator; - Supplier> predicate = createDistinctValueTest(operator); - var factory = CustomAggregateFunctionRegistry.getInstance().get(aggOperator.getIRI()); + int argumentCount = aggOperator.getArguments().size(); + var nAryFactory = CustomAggregateNAryFunctionRegistry.getInstance().get(aggOperator.getIRI()); + var unaryFactory = CustomAggregateFunctionRegistry.getInstance().get(aggOperator.getIRI()); - var function = factory.orElseThrow( - () -> new QueryEvaluationException("Unknown aggregate function '" + aggOperator.getIRI() + "'")) - .buildFunction(precompileArg(aggOperator)); - return new AggregatePredicateCollectorSupplier<>(function, predicate, () -> factory.get().getCollector(), - ge.getName()); + if (argumentCount == 1 && unaryFactory.isPresent()) { + return createUnaryCustomAggregate(ge, operator, aggOperator, unaryFactory.get()); + } + if (nAryFactory.isPresent()) { + validateNAryAggregateArity(aggOperator, nAryFactory.get(), argumentCount); + return createNAryCustomAggregate(ge, operator, aggOperator, nAryFactory.get()); + } + if (unaryFactory.isPresent()) { + if (argumentCount != 1) { + throw new QueryEvaluationException("Custom unary aggregate function '" + 
aggOperator.getIRI() + + "' expects exactly 1 argument, got " + argumentCount); + } + return createUnaryCustomAggregate(ge, operator, aggOperator, unaryFactory.get()); + } + throw new QueryEvaluationException("Unknown aggregate function '" + aggOperator.getIRI() + "'"); } return null; @@ -508,15 +527,73 @@ private void operate(BindingSet bs, Predicate predicate, Object t) { * @return a supplier that returns a predicate that tests if the value is distinct, or always returns true if the * operator is not distinct. */ - private Supplier> createDistinctValueTest(AggregateOperator operator) { + private Supplier> createDistinctSingleValueTest(AggregateOperator operator) { return operator.isDistinct() ? DistinctValues::new : ALWAYS_TRUE_VALUE_SUPPLIER; } - private QueryStepEvaluator precompileArg(AggregateOperator operator) { + /** + * Create a predicate that tests if the tuple of values is distinct (returning true if the tuple was not seen yet), + * or always returns true if the operator is not distinct. + * + * @param operator + * @return a supplier that returns a predicate that tests if the tuple of values is distinct, or always returns true + * if the operator is not distinct. + */ + private Supplier>> createDistinctTupleValueTest(AggregateOperator operator) { + return operator.isDistinct() ? 
DistinctTupleValues::new : ALWAYS_TRUE_TUPLE_VALUE_SUPPLIER; + } + + private AggregatePredicateCollectorSupplier createUnaryCustomAggregate(GroupElem ge, + AggregateOperator operator, + AggregateFunctionCall aggOperator, AggregateFunctionFactory factory) { + Supplier> predicate = createDistinctSingleValueTest(operator); + AggregateFunction function = factory.buildFunction(precompileNAryArg(aggOperator).asUnaryEvaluator(0)); + return new AggregatePredicateCollectorSupplier<>(function, predicate, factory::getCollector, ge.getName()); + } + + private AggregatePredicateCollectorSupplier createNAryCustomAggregate(GroupElem ge, + AggregateOperator operator, + AggregateFunctionCall aggOperator, AggregateNAryFunctionFactory factory) { + Supplier>> predicate = createDistinctTupleValueTest(operator); + var function = factory.buildFunction(precompileNAryArg(aggOperator)); + return new AggregatePredicateCollectorSupplier<>(function, predicate, factory::getCollector, ge.getName()); + } + + private void validateNAryAggregateArity(AggregateFunctionCall aggregateOperator, + AggregateNAryFunctionFactory factory, + int argumentCount) { + int minimum = factory.getMinNumberOfArguments(); + int maximum = factory.getMaxNumberOfArguments(); + if (minimum < 0 || maximum < minimum) { + throw new QueryEvaluationException("Custom n-ary aggregate function '" + aggregateOperator.getIRI() + + "' has invalid arity declaration"); + } + if (argumentCount < minimum) { + throw new QueryEvaluationException( + "Custom n-ary aggregate function '" + aggregateOperator.getIRI() + "' expects at least " + minimum + + " arguments, got " + argumentCount); + } + if (argumentCount > maximum) { + throw new QueryEvaluationException( + "Custom n-ary aggregate function '" + aggregateOperator.getIRI() + "' expects at most " + maximum + + " arguments, got " + argumentCount); + } + } + + private QueryStepEvaluator precompileUnaryArg(AggregateOperator operator) { ValueExpr ve = ((UnaryValueOperator) operator).getArg(); 
return new QueryStepEvaluator(strategy.precompile(ve, context)); } + private NAryQueryStepEvaluator precompileNAryArg(AggregateOperator operator) { + List args = ((NAryValueOperator) operator).getArguments(); + List precompiledArgs = new ArrayList<>(args.size()); + for (ValueExpr arg : args) { + precompiledArgs.add(strategy.precompile(arg, context)); + } + return new NAryQueryStepEvaluator(precompiledArgs::get); + } + private boolean shouldValueComparisonBeStrict() { return strategy.getQueryEvaluationMode() == QueryEvaluationMode.STRICT; } @@ -623,6 +700,19 @@ public boolean test(Value value) { } } + private class DistinctTupleValues implements Predicate> { + private final Set> distinctTuples; + + public DistinctTupleValues() { + distinctTuples = cf.createSet(); + } + + @Override + public boolean test(List valueTuple) { + return distinctTuples.add(valueTuple); + } + } + private class DistinctBindingSets implements Predicate { private final Set distinctValues; @@ -884,4 +974,25 @@ public Value apply(BindingSet bindings) { } } } + + private static class NAryQueryStepEvaluator implements BiFunction { + private final Function evaluationStepFunction; + + public NAryQueryStepEvaluator(Function evaluationStepFunction) { + this.evaluationStepFunction = evaluationStepFunction; + } + + @Override + public Value apply(Integer index, BindingSet bindings) { + try { + return evaluationStepFunction.apply(index).evaluate(bindings); + } catch (ValueExprEvaluationException e) { + return null; // treat missing or invalid expressions as null + } + } + + public QueryStepEvaluator asUnaryEvaluator(Integer index) { + return new QueryStepEvaluator(evaluationStepFunction.apply(index)); + } + } } diff --git a/core/queryalgebra/evaluation/src/test/java/org/eclipse/rdf4j/query/algebra/evaluation/iterator/GroupIteratorTest.java b/core/queryalgebra/evaluation/src/test/java/org/eclipse/rdf4j/query/algebra/evaluation/iterator/GroupIteratorTest.java index 0e35107c914..9010f061938 100644 --- 
a/core/queryalgebra/evaluation/src/test/java/org/eclipse/rdf4j/query/algebra/evaluation/iterator/GroupIteratorTest.java +++ b/core/queryalgebra/evaluation/src/test/java/org/eclipse/rdf4j/query/algebra/evaluation/iterator/GroupIteratorTest.java @@ -8,6 +8,7 @@ * * SPDX-License-Identifier: BSD-3-Clause *******************************************************************************/ +// Some portions generated by Codex package org.eclipse.rdf4j.query.algebra.evaluation.iterator; import static org.assertj.core.api.Assertions.assertThat; @@ -20,10 +21,12 @@ import java.util.ArrayList; import java.util.Collections; import java.util.Date; +import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; +import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Predicate; @@ -63,7 +66,10 @@ import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateCollector; import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateFunction; import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateFunctionFactory; +import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateNAryFunction; +import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateNAryFunctionFactory; import org.eclipse.rdf4j.query.parser.sparql.aggregate.CustomAggregateFunctionRegistry; +import org.eclipse.rdf4j.query.parser.sparql.aggregate.CustomAggregateNAryFunctionRegistry; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -241,7 +247,7 @@ public void testSumNotZero() throws QueryEvaluationException { public void testCustomAggregateFunction_Nonempty() throws QueryEvaluationException { Group group = new Group(NONEMPTY_ASSIGNMENT); group.addGroupElement(new GroupElem("customSum", - new AggregateFunctionCall(Var.of("a"), AGGREGATE_FUNCTION_FACTORY.getIri(), 
false))); + new AggregateFunctionCall(List.of(Var.of("a")), AGGREGATE_FUNCTION_FACTORY.getIri(), false))); try (GroupIterator gi = new GroupIterator(EVALUATOR, group, EmptyBindingSet.getInstance(), CONTEXT)) { assertThat(gi.next().getBinding("customSum").getValue()).isEqualTo(VF.createLiteral("45", XSD.INTEGER)); } @@ -251,7 +257,7 @@ public void testCustomAggregateFunction_Nonempty() throws QueryEvaluationExcepti public void testCustomAggregateFunction_Empty() throws QueryEvaluationException { Group group = new Group(EMPTY_ASSIGNMENT); group.addGroupElement(new GroupElem("customSum", - new AggregateFunctionCall(Var.of("a"), AGGREGATE_FUNCTION_FACTORY.getIri(), false))); + new AggregateFunctionCall(List.of(Var.of("a")), AGGREGATE_FUNCTION_FACTORY.getIri(), false))); try (GroupIterator gi = new GroupIterator(EVALUATOR, group, EmptyBindingSet.getInstance(), CONTEXT)) { assertThat(gi.next().getBinding("customSum").getValue()).isEqualTo(VF.createLiteral("0", XSD.INTEGER)); } @@ -260,13 +266,124 @@ public void testCustomAggregateFunction_Empty() throws QueryEvaluationException @Test public void testCustomAggregateFunction_WrongIri() throws QueryEvaluationException { Group group = new Group(EMPTY_ASSIGNMENT); - group.addGroupElement(new GroupElem("customSum", new AggregateFunctionCall(Var.of("a"), "urn:i", false))); + group.addGroupElement( + new GroupElem("customSum", new AggregateFunctionCall(List.of(Var.of("a")), "urn:i", false))); try (GroupIterator gi = new GroupIterator(EVALUATOR, group, EmptyBindingSet.getInstance(), CONTEXT)) { assertThatExceptionOfType(QueryEvaluationException.class) .isThrownBy(() -> gi.next().getBinding("customSum").getValue()); } } + @Test + public void testCustomAggregateFunction_MultipleArgsRejected() throws QueryEvaluationException { + BindingSetAssignment assignment = new BindingSetAssignment(); + var list = new ArrayList(); + for (int i = 1; i < 10; i++) { + var bindings = new QueryBindingSet(); + bindings.addBinding("a", 
VF.createLiteral(i)); + bindings.addBinding("b", VF.createLiteral(i * 2)); + list.add(bindings); + } + assignment.setBindingSets(Collections.unmodifiableList(list)); + + Group group = new Group(assignment); + group.addGroupElement(new GroupElem("customSum", + new AggregateFunctionCall(List.of(Var.of("a"), Var.of("b")), AGGREGATE_FUNCTION_FACTORY.getIri(), + false))); + try (GroupIterator gi = new GroupIterator(EVALUATOR, group, EmptyBindingSet.getInstance(), CONTEXT)) { + assertThatExceptionOfType(QueryEvaluationException.class) + .isThrownBy(() -> gi.next().getBinding("customSum").getValue()) + .withMessageContaining("expects exactly 1 argument"); + } + } + + @Test + public void testCustomNAryAggregateFunction_Nonempty() throws QueryEvaluationException { + AggregateNAryFunctionFactory nAryFactory = new FakeAggregateNAryFunctionFactory(); + CustomAggregateNAryFunctionRegistry.getInstance().add(nAryFactory); + try { + BindingSetAssignment assignment = new BindingSetAssignment(); + var list = new ArrayList(); + for (int i = 1; i < 10; i++) { + var bindings = new QueryBindingSet(); + bindings.addBinding("a", VF.createLiteral(i)); + bindings.addBinding("b", VF.createLiteral(i * 2)); + list.add(bindings); + } + assignment.setBindingSets(Collections.unmodifiableList(list)); + + Group group = new Group(assignment); + group.addGroupElement(new GroupElem("narySum", + new AggregateFunctionCall(List.of(Var.of("a"), Var.of("b")), nAryFactory.getIri(), false))); + try (GroupIterator gi = new GroupIterator(EVALUATOR, group, EmptyBindingSet.getInstance(), CONTEXT)) { + assertThat(gi.next().getBinding("narySum").getValue()).isEqualTo(VF.createLiteral("135", XSD.INTEGER)); + } + } finally { + CustomAggregateNAryFunctionRegistry.getInstance().remove(nAryFactory); + } + } + + @Test + public void testCustomNAryAggregateFunction_Empty() throws QueryEvaluationException { + AggregateNAryFunctionFactory nAryFactory = new FakeAggregateNAryFunctionFactory(); + 
CustomAggregateNAryFunctionRegistry.getInstance().add(nAryFactory); + try { + Group group = new Group(EMPTY_ASSIGNMENT); + group.addGroupElement(new GroupElem("narySum", + new AggregateFunctionCall(List.of(Var.of("a"), Var.of("b")), nAryFactory.getIri(), false))); + try (GroupIterator gi = new GroupIterator(EVALUATOR, group, EmptyBindingSet.getInstance(), CONTEXT)) { + assertThat(gi.next().getBinding("narySum").getValue()).isEqualTo(VF.createLiteral("0", XSD.INTEGER)); + } + } finally { + CustomAggregateNAryFunctionRegistry.getInstance().remove(nAryFactory); + } + } + + @Test + public void testCustomNAryAggregateFunction_SingleArg_Nonempty() throws QueryEvaluationException { + AggregateNAryFunctionFactory nAryFactory = new SingleArgAggregateNAryFunctionFactory(); + CustomAggregateNAryFunctionRegistry.getInstance().add(nAryFactory); + try { + Group group = new Group(NONEMPTY_ASSIGNMENT); + group.addGroupElement(new GroupElem("narySingleArgSum", + new AggregateFunctionCall(List.of(Var.of("a")), nAryFactory.getIri(), false))); + try (GroupIterator gi = new GroupIterator(EVALUATOR, group, EmptyBindingSet.getInstance(), CONTEXT)) { + assertThat(gi.next().getBinding("narySingleArgSum").getValue()) + .isEqualTo(VF.createLiteral("45", XSD.INTEGER)); + } + } finally { + CustomAggregateNAryFunctionRegistry.getInstance().remove(nAryFactory); + } + } + + @Test + public void testCustomNAryAggregateFunction_Empty_DistinctTuplePredicateInvoked() throws QueryEvaluationException { + AggregateNAryFunctionFactory nAryFactory = new DistinctTupleTouchingAggregateNAryFunctionFactory(); + CustomAggregateNAryFunctionRegistry.getInstance().add(nAryFactory); + try { + Group group = new Group(EMPTY_ASSIGNMENT); + group.addGroupElement(new GroupElem("naryDistinct", + new AggregateFunctionCall(List.of(Var.of("a"), Var.of("b")), nAryFactory.getIri(), false))); + try (GroupIterator gi = new GroupIterator(EVALUATOR, group, EmptyBindingSet.getInstance(), CONTEXT)) { + 
assertThat(gi.next().getBinding("naryDistinct").getValue()) + .isEqualTo(VF.createLiteral("0", XSD.INTEGER)); + } + } finally { + CustomAggregateNAryFunctionRegistry.getInstance().remove(nAryFactory); + } + } + + @Test + public void testCustomNAryAggregateFunction_WrongIri() throws QueryEvaluationException { + Group group = new Group(EMPTY_ASSIGNMENT); + group.addGroupElement(new GroupElem("narySum", + new AggregateFunctionCall(List.of(Var.of("a"), Var.of("b")), "urn:unknown:nary", false))); + try (GroupIterator gi = new GroupIterator(EVALUATOR, group, EmptyBindingSet.getInstance(), CONTEXT)) { + assertThatExceptionOfType(QueryEvaluationException.class) + .isThrownBy(() -> gi.next().getBinding("narySum").getValue()); + } + } + @Test public void testGroupIteratorClose() throws QueryEvaluationException, InterruptedException { // Lock which is already locked to block the thread driving the iteration @@ -356,6 +473,148 @@ public SumCollector getCollector() { } } + private static final class FakeAggregateNAryFunctionFactory implements AggregateNAryFunctionFactory { + @Override + public String getIri() { + return "https://www.rdf4j.org/aggregate#nary"; + } + + @Override + public AggregateNAryFunction buildFunction( + BiFunction evaluationStepByIndex) { + return new AggregateNAryFunction<>(evaluationStepByIndex) { + private ValueExprEvaluationException typeError = null; + + @Override + public void processAggregate(BindingSet bindingSet, Predicate> distinctValue, + SumCollector sum) + throws QueryEvaluationException { + if (typeError != null) { + // halt further processing if a type error has been raised + return; + } + Value v = evaluate(0, bindingSet); // take first argument as binding set + Value v2 = evaluate(1, bindingSet); // take second argument as binding set + + if (v instanceof Literal) { + if (distinctValue.test(List.of(v))) { + Literal nextLiteral = (Literal) v; + if (nextLiteral.getDatatype() != null + && 
XMLDatatypeUtil.isNumericDatatype(nextLiteral.getDatatype())) { + sum.value = MathUtil.compute(sum.value, nextLiteral, MathExpr.MathOp.PLUS); + } else { + typeError = new ValueExprEvaluationException("not a number: " + v); + } + } else { + typeError = new ValueExprEvaluationException("not a number: " + v); + } + } + + if (v2 instanceof Literal) { + if (distinctValue.test(List.of(v2))) { + Literal nextLiteral = (Literal) v2; + if (nextLiteral.getDatatype() != null + && XMLDatatypeUtil.isNumericDatatype(nextLiteral.getDatatype())) { + sum.value = MathUtil.compute(sum.value, nextLiteral, MathExpr.MathOp.PLUS); + } else { + typeError = new ValueExprEvaluationException("not a number: " + v2); + } + } else { + typeError = new ValueExprEvaluationException("not a number: " + v2); + } + } + } + }; + } + + @Override + public SumCollector getCollector() { + return new SumCollector(); + } + } + + private static final class DistinctTupleTouchingAggregateNAryFunctionFactory + implements AggregateNAryFunctionFactory { + @Override + public String getIri() { + return "https://www.rdf4j.org/aggregate#nary-distinct-touching"; + } + + @Override + public AggregateNAryFunction buildFunction( + BiFunction evaluationStepByIndex) { + return new AggregateNAryFunction<>(evaluationStepByIndex) { + + @Override + public void processAggregate(BindingSet bindingSet, Predicate> distinctValue, + SumCollector sumCollector) throws QueryEvaluationException { + List tuple = new ArrayList<>(2); + tuple.add(evaluate(0, bindingSet)); + tuple.add(evaluate(1, bindingSet)); + distinctValue.test(tuple); + } + }; + } + + @Override + public SumCollector getCollector() { + return new SumCollector(); + } + } + + private static final class SingleArgAggregateNAryFunctionFactory implements AggregateNAryFunctionFactory { + @Override + public String getIri() { + return "https://www.rdf4j.org/aggregate#nary-single-arg"; + } + + @Override + public int getMinNumberOfArguments() { + return 1; + } + + @Override + public 
int getMaxNumberOfArguments() { + return 1; + } + + @Override + public AggregateNAryFunction buildFunction( + BiFunction evaluationStepByIndex) { + return new AggregateNAryFunction<>(evaluationStepByIndex) { + private ValueExprEvaluationException typeError = null; + + @Override + public void processAggregate(BindingSet bindingSet, Predicate> distinctValue, + SumCollector sum) + throws QueryEvaluationException { + if (typeError != null) { + return; + } + Value v = evaluate(0, bindingSet); + if (v instanceof Literal) { + if (distinctValue.test(List.of(v))) { + Literal nextLiteral = (Literal) v; + if (nextLiteral.getDatatype() != null + && XMLDatatypeUtil.isNumericDatatype(nextLiteral.getDatatype())) { + sum.value = MathUtil.compute(sum.value, nextLiteral, MathExpr.MathOp.PLUS); + } else { + typeError = new ValueExprEvaluationException("not a number: " + v); + } + } else { + typeError = new ValueExprEvaluationException("not a number: " + v); + } + } + } + }; + } + + @Override + public SumCollector getCollector() { + return new SumCollector(); + } + } + /** * Dummy collector to verify custom aggregate functions */ diff --git a/core/queryalgebra/model/src/main/java/org/eclipse/rdf4j/query/algebra/AbstractNAryAggregateOperator.java b/core/queryalgebra/model/src/main/java/org/eclipse/rdf4j/query/algebra/AbstractNAryAggregateOperator.java new file mode 100644 index 00000000000..ff05809c7ef --- /dev/null +++ b/core/queryalgebra/model/src/main/java/org/eclipse/rdf4j/query/algebra/AbstractNAryAggregateOperator.java @@ -0,0 +1,67 @@ +/******************************************************************************* + * Copyright (c) 2025 Eclipse RDF4J contributors. + * + * All rights reserved. This program and the accompanying materials + * are made available under the terms of the Eclipse Distribution License v1.0 + * which accompanies this distribution, and is available at + * http://www.eclipse.org/org/documents/edl-v10.php. 
+ * + * SPDX-License-Identifier: BSD-3-Clause + *******************************************************************************/ +package org.eclipse.rdf4j.query.algebra; + +import java.util.List; +import java.util.Objects; + +/** + * Base class for n-ary aggregate operators. + * + * @author Nik Kozlov + */ +public abstract class AbstractNAryAggregateOperator extends NAryValueOperator implements AggregateOperator { + + private static final long serialVersionUID = 1L; + + private boolean distinct = false; + + protected AbstractNAryAggregateOperator(List args) { + this(args, false); + } + + protected AbstractNAryAggregateOperator(List args, boolean distinct) { + super(args); + this.distinct = distinct; + } + + @Override + public void setDistinct(boolean distinct) { + this.distinct = distinct; + } + + @Override + public boolean isDistinct() { + return this.distinct; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof AbstractNAryAggregateOperator)) { + return false; + } + if (!super.equals(o)) { + return false; + } + AbstractNAryAggregateOperator that = (AbstractNAryAggregateOperator) o; + return distinct == that.distinct; + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), distinct); + } + + @Override + public AbstractNAryAggregateOperator clone() { + return (AbstractNAryAggregateOperator) super.clone(); + } +} diff --git a/core/queryalgebra/model/src/main/java/org/eclipse/rdf4j/query/algebra/AggregateFunctionCall.java b/core/queryalgebra/model/src/main/java/org/eclipse/rdf4j/query/algebra/AggregateFunctionCall.java index d63a3235103..15a5b38fa1a 100644 --- a/core/queryalgebra/model/src/main/java/org/eclipse/rdf4j/query/algebra/AggregateFunctionCall.java +++ b/core/queryalgebra/model/src/main/java/org/eclipse/rdf4j/query/algebra/AggregateFunctionCall.java @@ -8,9 +8,13 @@ * * SPDX-License-Identifier: BSD-3-Clause ******************************************************************************/ +// Some portions 
generated by Codex package org.eclipse.rdf4j.query.algebra; +import java.util.ArrayList; +import java.util.List; + import org.eclipse.rdf4j.common.annotation.Experimental; /** @@ -20,17 +24,17 @@ * @author Tomas Kovachev t.kovachev1996@gmail.com */ @Experimental -public class AggregateFunctionCall extends AbstractAggregateOperator { +public class AggregateFunctionCall extends AbstractNAryAggregateOperator { protected String iri; public AggregateFunctionCall(String iri, boolean distinct) { - super(null, distinct); + super(new ArrayList<>(), distinct); this.iri = iri; } - public AggregateFunctionCall(ValueExpr arg, String iri, boolean distinct) { - super(arg, distinct); + public AggregateFunctionCall(List args, String iri, boolean distinct) { + super(args instanceof ArrayList ? args : new ArrayList<>(args), distinct); this.iri = iri; } diff --git a/core/queryalgebra/model/src/main/java/org/eclipse/rdf4j/query/algebra/NAryValueOperator.java b/core/queryalgebra/model/src/main/java/org/eclipse/rdf4j/query/algebra/NAryValueOperator.java index d87e7ba0a27..f4273c8524e 100644 --- a/core/queryalgebra/model/src/main/java/org/eclipse/rdf4j/query/algebra/NAryValueOperator.java +++ b/core/queryalgebra/model/src/main/java/org/eclipse/rdf4j/query/algebra/NAryValueOperator.java @@ -51,6 +51,9 @@ protected NAryValueOperator(List args) { public void setArguments(List args) { this.args = args; + for (ValueExpr arg : args) { + arg.setParentNode(this); + } } public List getArguments() { diff --git a/core/queryalgebra/model/src/test/java/org/eclipse/rdf4j/query/algebra/AggregateFunctionCallTest.java b/core/queryalgebra/model/src/test/java/org/eclipse/rdf4j/query/algebra/AggregateFunctionCallTest.java new file mode 100644 index 00000000000..a61713dd70c --- /dev/null +++ b/core/queryalgebra/model/src/test/java/org/eclipse/rdf4j/query/algebra/AggregateFunctionCallTest.java @@ -0,0 +1,41 @@ +/******************************************************************************* + * Copyright (c) 
2026 Eclipse RDF4J contributors. + * + * All rights reserved. This program and the accompanying materials + * are made available under the terms of the Eclipse Distribution License v1.0 + * which accompanies this distribution, and is available at + * http://www.eclipse.org/org/documents/edl-v10.php. + * + * SPDX-License-Identifier: BSD-3-Clause + *******************************************************************************/ +// Some portions generated by Codex +package org.eclipse.rdf4j.query.algebra; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; + +import org.eclipse.rdf4j.model.impl.SimpleValueFactory; +import org.junit.jupiter.api.Test; + +class AggregateFunctionCallTest { + + @Test + void constructorSetsParentForEachArgument() { + Var argument = Var.of("a"); + AggregateFunctionCall aggregateFunctionCall = new AggregateFunctionCall(List.of(argument), "urn:test", false); + + assertThat(argument.getParentNode()).isSameAs(aggregateFunctionCall); + } + + @Test + void constructorCopiesArgumentsIntoMutableList() { + Var argument = Var.of("a"); + AggregateFunctionCall aggregateFunctionCall = new AggregateFunctionCall(List.of(argument), "urn:test", false); + ValueConstant replacement = new ValueConstant(SimpleValueFactory.getInstance().createLiteral(1)); + + aggregateFunctionCall.replaceChildNode(argument, replacement); + + assertThat(aggregateFunctionCall.getArguments()).containsExactly(replacement); + } +} diff --git a/core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/TupleExprBuilder.java b/core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/TupleExprBuilder.java index 67336982f22..0535abbc46d 100644 --- a/core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/TupleExprBuilder.java +++ b/core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/TupleExprBuilder.java @@ -8,6 +8,7 @@ * * SPDX-License-Identifier: BSD-3-Clause 
*******************************************************************************/ +// Some portions generated by Codex package org.eclipse.rdf4j.query.parser.sparql; import java.util.ArrayList; @@ -105,7 +106,9 @@ import org.eclipse.rdf4j.query.algebra.helpers.collectors.StatementPatternCollector; import org.eclipse.rdf4j.query.algebra.helpers.collectors.VarNameCollector; import org.eclipse.rdf4j.query.impl.ListBindingSet; +import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateNAryFunctionFactory; import org.eclipse.rdf4j.query.parser.sparql.aggregate.CustomAggregateFunctionRegistry; +import org.eclipse.rdf4j.query.parser.sparql.aggregate.CustomAggregateNAryFunctionRegistry; import org.eclipse.rdf4j.query.parser.sparql.ast.ASTAbs; import org.eclipse.rdf4j.query.parser.sparql.ast.ASTAnd; import org.eclipse.rdf4j.query.parser.sparql.ast.ASTAskQuery; @@ -1994,14 +1997,26 @@ public MathExpr visit(ASTMath node, Object data) throws VisitorException { public Object visit(ASTFunctionCall node, Object data) throws VisitorException { ValueConstant uriNode = (ValueConstant) node.jjtGetChild(0).jjtAccept(this, null); IRI functionURI = (IRI) uriNode.getValue(); - if (CustomAggregateFunctionRegistry.getInstance().has(functionURI.stringValue()) || node.isDistinct()) { - AggregateFunctionCall aggregateCall = new AggregateFunctionCall(functionURI.stringValue(), + String functionIri = functionURI.stringValue(); + var nAryFactory = CustomAggregateNAryFunctionRegistry.getInstance().get(functionIri); + boolean hasNAryFactory = nAryFactory.isPresent(); + boolean hasUnaryFactory = CustomAggregateFunctionRegistry.getInstance().has(functionIri); + if (hasNAryFactory || hasUnaryFactory || node.isDistinct()) { + AggregateFunctionCall aggregateCall = new AggregateFunctionCall(functionIri, node.isDistinct()); - if (node.jjtGetNumChildren() > 2) { - throw new IllegalArgumentException("Custom aggregate functions cannot have more than one argument"); + for (int i = 1; i < 
node.jjtGetNumChildren(); i++) { + Node argNode = node.jjtGetChild(i); + aggregateCall.addArgument(castToValueExpr(argNode.jjtAccept(this, null))); + } + + int argumentCount = aggregateCall.getArguments().size(); + if (argumentCount == 0) { + throw new IllegalArgumentException("Aggregate function calls must have at least one argument"); + } + if (hasNAryFactory && (argumentCount > 1 || !hasUnaryFactory)) { + validateNAryAggregateArity(functionIri, nAryFactory.get(), argumentCount); } - Node argNode = node.jjtGetChild(1); - aggregateCall.setArg(castToValueExpr(argNode.jjtAccept(this, null))); + return aggregateCall; } else { if (node.isDistinct()) { @@ -2016,6 +2031,24 @@ public Object visit(ASTFunctionCall node, Object data) throws VisitorException { } } + private void validateNAryAggregateArity(String functionIri, AggregateNAryFunctionFactory factory, + int argumentCount) { + int minimum = factory.getMinNumberOfArguments(); + int maximum = factory.getMaxNumberOfArguments(); + if (minimum < 0 || maximum < minimum) { + throw new IllegalArgumentException("Custom n-ary aggregate function '" + functionIri + + "' has invalid arity declaration"); + } + if (argumentCount < minimum) { + throw new IllegalArgumentException("Custom n-ary aggregate function calls must have at least " + minimum + + " arguments"); + } + if (argumentCount > maximum) { + throw new IllegalArgumentException("Custom n-ary aggregate function calls must have at most " + maximum + + " arguments"); + } + } + @Override public FunctionCall visit(ASTEncodeForURI node, Object data) throws VisitorException { return createFunctionCall(FN.ENCODE_FOR_URI.stringValue(), node, 1, 1); diff --git a/core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/aggregate/AggregateFunction.java b/core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/aggregate/AggregateFunction.java index a260eb94cad..e339a862785 100644 --- 
a/core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/aggregate/AggregateFunction.java +++ b/core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/aggregate/AggregateFunction.java @@ -23,7 +23,7 @@ * @param */ @Experimental -public abstract class AggregateFunction { +public abstract class AggregateFunction implements AggregateProcessor { protected final Function evaluationStep; @@ -31,6 +31,7 @@ public AggregateFunction(Function evaluationStep) { this.evaluationStep = evaluationStep; } + @Override public abstract void processAggregate(BindingSet bindingSet, Predicate distinctValue, T agv) throws QueryEvaluationException; diff --git a/core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/aggregate/AggregateNAryFunction.java b/core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/aggregate/AggregateNAryFunction.java new file mode 100644 index 00000000000..172da2afc2b --- /dev/null +++ b/core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/aggregate/AggregateNAryFunction.java @@ -0,0 +1,60 @@ +/******************************************************************************* + * Copyright (c) 2025 Eclipse RDF4J contributors. + * + * All rights reserved. This program and the accompanying materials + * are made available under the terms of the Eclipse Distribution License v1.0 + * which accompanies this distribution, and is available at + * http://www.eclipse.org/org/documents/edl-v10.php. 
+ * + * SPDX-License-Identifier: BSD-3-Clause + *******************************************************************************/ +// Some portions generated by Codex +package org.eclipse.rdf4j.query.parser.sparql.aggregate; + +import java.util.List; +import java.util.function.BiFunction; +import java.util.function.Predicate; + +import org.eclipse.rdf4j.common.annotation.Experimental; +import org.eclipse.rdf4j.model.Value; +import org.eclipse.rdf4j.query.BindingSet; +import org.eclipse.rdf4j.query.QueryEvaluationException; + +/** + * N-ary aggregate function processor. + * + * @param + * @param + * + * @author Nik Kozlov + */ +@Experimental +public abstract class AggregateNAryFunction implements AggregateProcessor> { + + protected final BiFunction evaluationStepByIndex; + + protected AggregateNAryFunction(BiFunction evaluationStepByIndex) { + this.evaluationStepByIndex = evaluationStepByIndex; + } + + /** + * Process an aggregate with tuple-level distinctness for n-ary functions. + * + * @param bindingSet the current binding set + * @param distinctTuple predicate to check if the tuple of argument values is distinct. The tuple may contain an + * arbitrary number of arguments; if necessary, single-argument distinctness can be + * checked inside the predicate. Mixing tuples of different sizes is not recommended. Note: do + * not mutate (or reuse) a mutable {@link List} instance after passing it to + * {@code distinctTuple.test(...)} because the predicate may keep it in a hash-based set. If + * you construct tuples using a mutable list, pass an immutable snapshot (e.g. + * {@code List.copyOf(tuple)}).
+ * @param agv the aggregate collector + * @throws QueryEvaluationException if evaluation fails + */ + public abstract void processAggregate(BindingSet bindingSet, Predicate> distinctTuple, T agv) + throws QueryEvaluationException; + + protected Value evaluate(Integer index, BindingSet s) throws QueryEvaluationException { + return evaluationStepByIndex.apply(index, s); + } +} diff --git a/core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/aggregate/AggregateNAryFunctionFactory.java b/core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/aggregate/AggregateNAryFunctionFactory.java new file mode 100644 index 00000000000..9168e789137 --- /dev/null +++ b/core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/aggregate/AggregateNAryFunctionFactory.java @@ -0,0 +1,63 @@ +/******************************************************************************* + * Copyright (c) 2025 Eclipse RDF4J contributors. + * + * All rights reserved. This program and the accompanying materials + * are made available under the terms of the Eclipse Distribution License v1.0 + * which accompanies this distribution, and is available at + * http://www.eclipse.org/org/documents/edl-v10.php. + * + * SPDX-License-Identifier: BSD-3-Clause + *******************************************************************************/ +// Some portions generated by Codex +package org.eclipse.rdf4j.query.parser.sparql.aggregate; + +import java.util.function.BiFunction; + +import org.eclipse.rdf4j.common.annotation.Experimental; +import org.eclipse.rdf4j.model.Value; +import org.eclipse.rdf4j.query.BindingSet; + +/** + * Factory for a registered {@link AggregateNAryFunction} that is evaluated over multiple arguments. 
+ * + * @author Nik Kozlov + */ +@Experimental +public interface AggregateNAryFunctionFactory { + + /** + * @return the IRI that identifies this aggregate function + */ + String getIri(); + + /** + * Lower bound for the accepted number of aggregate arguments. + * + * @return the minimum accepted number of arguments (2 by default) + */ + default int getMinNumberOfArguments() { + return 2; + } + + /** + * Upper bound for the accepted number of aggregate arguments. + * + * @return the maximum accepted number of arguments ({@code Integer.MAX_VALUE} by default) + */ + default int getMaxNumberOfArguments() { + return Integer.MAX_VALUE; + } + + /** + * Builds an aggregate function from the given per-argument evaluation step. + * + * @param evaluationStepByIndex used to process values from an iterator's binding set + * @return an aggregate function evaluator + */ + AggregateNAryFunction buildFunction(BiFunction evaluationStepByIndex); + + /** + * @return the result collector associated with this function type + */ + AggregateCollector getCollector(); +} diff --git a/core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/aggregate/AggregateProcessor.java b/core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/aggregate/AggregateProcessor.java new file mode 100644 index 00000000000..234d835d544 --- /dev/null +++ b/core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/aggregate/AggregateProcessor.java @@ -0,0 +1,29 @@ +/******************************************************************************* + * Copyright (c) 2025 Eclipse RDF4J contributors. + * + * All rights reserved. This program and the accompanying materials + * are made available under the terms of the Eclipse Distribution License v1.0 + * which accompanies this distribution, and is available at + * http://www.eclipse.org/org/documents/edl-v10.php.
+ * + * SPDX-License-Identifier: BSD-3-Clause + *******************************************************************************/ +package org.eclipse.rdf4j.query.parser.sparql.aggregate; + +import java.util.function.Predicate; + +import org.eclipse.rdf4j.query.BindingSet; +import org.eclipse.rdf4j.query.QueryEvaluationException; + +/** + * Common interface for processing aggregate functions. + * + * @param + * @param + * + * @author Nik Kozlov + */ +public interface AggregateProcessor { + void processAggregate(BindingSet bindingSet, Predicate distinctValue, T agv) + throws QueryEvaluationException; +} diff --git a/core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/aggregate/CustomAggregateNAryFunctionRegistry.java b/core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/aggregate/CustomAggregateNAryFunctionRegistry.java new file mode 100644 index 00000000000..3660fdef1fb --- /dev/null +++ b/core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/aggregate/CustomAggregateNAryFunctionRegistry.java @@ -0,0 +1,39 @@ +/******************************************************************************* + * Copyright (c) 2025 Eclipse RDF4J contributors. + * + * All rights reserved. This program and the accompanying materials + * are made available under the terms of the Eclipse Distribution License v1.0 + * which accompanies this distribution, and is available at + * http://www.eclipse.org/org/documents/edl-v10.php. + * + * SPDX-License-Identifier: BSD-3-Clause + *******************************************************************************/ +package org.eclipse.rdf4j.query.parser.sparql.aggregate; + +import org.eclipse.rdf4j.common.annotation.Experimental; +import org.eclipse.rdf4j.common.lang.service.ServiceRegistry; + +/** + * {@link ServiceRegistry} implementation that stores available custom aggregate n-ary functions that can be used during + * query evaluation. 
+ * + * @author Nik Kozlov + */ +@Experimental +public class CustomAggregateNAryFunctionRegistry extends ServiceRegistry { + + private static final CustomAggregateNAryFunctionRegistry instance = new CustomAggregateNAryFunctionRegistry(); + + public static CustomAggregateNAryFunctionRegistry getInstance() { + return CustomAggregateNAryFunctionRegistry.instance; + } + + public CustomAggregateNAryFunctionRegistry() { + super(AggregateNAryFunctionFactory.class); + } + + @Override + protected String getKey(AggregateNAryFunctionFactory service) { + return service.getIri(); + } +} diff --git a/core/queryparser/sparql/src/test/java/org/eclipse/rdf4j/query/parser/sparql/SPARQLParserTest.java b/core/queryparser/sparql/src/test/java/org/eclipse/rdf4j/query/parser/sparql/SPARQLParserTest.java index 1070a99dfa7..ae46c66bc72 100644 --- a/core/queryparser/sparql/src/test/java/org/eclipse/rdf4j/query/parser/sparql/SPARQLParserTest.java +++ b/core/queryparser/sparql/src/test/java/org/eclipse/rdf4j/query/parser/sparql/SPARQLParserTest.java @@ -8,6 +8,7 @@ * * SPDX-License-Identifier: BSD-3-Clause *******************************************************************************/ +// Some portions generated by Codex package org.eclipse.rdf4j.query.parser.sparql; import static org.assertj.core.api.Assertions.assertThat; @@ -33,6 +34,7 @@ import java.util.Iterator; import java.util.List; import java.util.Set; +import java.util.function.BiFunction; import java.util.function.Function; import org.eclipse.rdf4j.model.Model; @@ -73,7 +75,10 @@ import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateCollector; import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateFunction; import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateFunctionFactory; +import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateNAryFunction; +import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateNAryFunctionFactory; import 
org.eclipse.rdf4j.query.parser.sparql.aggregate.CustomAggregateFunctionRegistry; +import org.eclipse.rdf4j.query.parser.sparql.aggregate.CustomAggregateNAryFunctionRegistry; import org.eclipse.rdf4j.rio.RDFFormat; import org.eclipse.rdf4j.rio.Rio; import org.junit.jupiter.api.AfterEach; @@ -596,7 +601,7 @@ public void testProjectionHandling_UnknownFunctionCallWithDistinct() { } @Test - public void testProjectionHandling_FunctionCallWithArgsFails() { + public void testProjectionHandling_FunctionCallWithArgsDoesNotFail() { var factory = buildDummyFactory(); String query = "prefix rj: " + "SELECT (rj:x(?o, ?p) AS ?o1) \n" @@ -605,15 +610,57 @@ public void testProjectionHandling_FunctionCallWithArgsFails() { + "} GROUP BY ?s ?o"; try { CustomAggregateFunctionRegistry.getInstance().add(factory); - parser.parseQuery(query, null); - fail("Should not be able to parse function calls with multiple args"); - } catch (Exception e) { - assertTrue(e instanceof IllegalArgumentException); + var tupleExpr = parser.parseQuery(query, null).getTupleExpr(); + assertTrue(tupleExpr instanceof QueryRoot); + tupleExpr = ((QueryRoot) tupleExpr).getArg(); + assertTrue(tupleExpr instanceof Projection); + tupleExpr = ((Projection) tupleExpr).getArg(); + assertTrue(tupleExpr instanceof Extension); + var extensionElements = ((Extension) tupleExpr).getElements(); + assertEquals(1, extensionElements.size()); + assertTrue(extensionElements.get(0).getExpr() instanceof AggregateFunctionCall); + var aggregateCall = (AggregateFunctionCall) extensionElements.get(0).getExpr(); + assertEquals(factory.getIri(), aggregateCall.getIRI()); + assertEquals(2, aggregateCall.getArguments().size()); + } finally { + CustomAggregateFunctionRegistry.getInstance().remove(factory); + } + } + + @Test + public void testProjectionHandling_FunctionCallWithoutArgsFails() { + var factory = buildDummyFactory(); + String query = "prefix rj: " + + "SELECT (rj:x() AS ?o1) \n" + + "WHERE {\n" + + " ?s ?p ?o \n" + + "} GROUP BY ?s 
?o"; + try { + CustomAggregateFunctionRegistry.getInstance().add(factory); + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> parser.parseQuery(query, null)) + .withMessageStartingWith("Aggregate function calls must have at least one argument"); } finally { CustomAggregateFunctionRegistry.getInstance().remove(factory); } } + @Test + public void testProjectionHandling_NAryFunctionCallWithTooFewArgsFails() { + var factory = buildDummyNAryFactory(); + String query = "prefix rj: " + + "SELECT (rj:nary(?o) AS ?o1) \n" + + "WHERE {\n" + + " ?s ?p ?o \n" + + "} GROUP BY ?s ?o"; + try { + CustomAggregateNAryFunctionRegistry.getInstance().add(factory); + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> parser.parseQuery(query, null)) + .withMessageStartingWith("Custom n-ary aggregate function calls must have at least 2 arguments"); + } finally { + CustomAggregateNAryFunctionRegistry.getInstance().remove(factory); + } + } + @Test public void testGroupByProjectionHandling_Aggregate_NonSimpleExpr() { String query = "SELECT (COUNT(?s) as ?count) (?o + ?s AS ?o1) \n" @@ -1173,6 +1220,25 @@ public AggregateCollector getCollector() { }; } + private AggregateNAryFunctionFactory buildDummyNAryFactory() { + return new AggregateNAryFunctionFactory() { + @Override + public String getIri() { + return "https://www.rdf4j.org/aggregate#nary"; + } + + @Override + public AggregateNAryFunction buildFunction(BiFunction evaluationStepByIndex) { + return null; + } + + @Override + public AggregateCollector getCollector() { + return null; + } + }; + } + private void verifySerializable(QueryModelNode tupleExpr) { byte[] bytes = objectToBytes(tupleExpr); QueryModelNode parsed = (QueryModelNode) bytesToObject(bytes);