Skip to content

Commit 72c1fa7

Browse files
authored
Merge main into develop (#5715)
2 parents c17420c + 13ca7b6 commit 72c1fa7

13 files changed

Lines changed: 822 additions & 46 deletions

File tree

core/queryalgebra/evaluation/src/main/java/org/eclipse/rdf4j/query/algebra/evaluation/iterator/GroupIterator.java

Lines changed: 138 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
*
99
* SPDX-License-Identifier: BSD-3-Clause
1010
*******************************************************************************/
11+
// Some portions generated by Codex
1112
package org.eclipse.rdf4j.query.algebra.evaluation.iterator;
1213

1314
import java.util.ArrayList;
@@ -19,6 +20,7 @@
1920
import java.util.Random;
2021
import java.util.Set;
2122
import java.util.function.BiConsumer;
23+
import java.util.function.BiFunction;
2224
import java.util.function.Function;
2325
import java.util.function.Predicate;
2426
import java.util.function.Supplier;
@@ -50,6 +52,7 @@
5052
import org.eclipse.rdf4j.query.algebra.MathExpr.MathOp;
5153
import org.eclipse.rdf4j.query.algebra.Max;
5254
import org.eclipse.rdf4j.query.algebra.Min;
55+
import org.eclipse.rdf4j.query.algebra.NAryValueOperator;
5356
import org.eclipse.rdf4j.query.algebra.Sample;
5457
import org.eclipse.rdf4j.query.algebra.Sum;
5558
import org.eclipse.rdf4j.query.algebra.UnaryValueOperator;
@@ -65,7 +68,11 @@
6568
import org.eclipse.rdf4j.query.impl.EmptyBindingSet;
6669
import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateCollector;
6770
import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateFunction;
71+
import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateFunctionFactory;
72+
import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateNAryFunctionFactory;
73+
import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateProcessor;
6874
import org.eclipse.rdf4j.query.parser.sparql.aggregate.CustomAggregateFunctionRegistry;
75+
import org.eclipse.rdf4j.query.parser.sparql.aggregate.CustomAggregateNAryFunctionRegistry;
6976

7077
/**
7178
* @author David Huynh
@@ -349,7 +356,7 @@ private List<Entry> emptySolutionSpecialCase(List<AggregatePredicateCollectorSup
349356
// Even in the case that the Count is of a constant value.
350357
predicates.add(ALWAYS_FALSE_VALUE);
351358
} else {
352-
predicates.add(ALWAYS_TRUE_VALUE);
359+
predicates.add(ag.makePotentialDistinctTest.get());
353360
}
354361
}
355362
final Entry entry = new Entry(null, collectors, predicates);
@@ -409,11 +416,11 @@ public long getSize() {
409416
*/
410417
private static class AggregatePredicateCollectorSupplier<T extends AggregateCollector, D> {
411418
public final String name;
412-
private final AggregateFunction<T, D> agg;
419+
private final AggregateProcessor<T, D> agg;
413420
private final Supplier<Predicate<D>> makePotentialDistinctTest;
414421
private final Supplier<T> makeAggregateCollector;
415422

416-
public AggregatePredicateCollectorSupplier(AggregateFunction<T, D> agg,
423+
public AggregatePredicateCollectorSupplier(AggregateProcessor<T, D> agg,
417424
Supplier<Predicate<D>> makePotentialDistinctTest, Supplier<T> makeAggregateCollector, String name) {
418425
super();
419426
this.agg = agg;
@@ -431,6 +438,7 @@ private void operate(BindingSet bs, Predicate<?> predicate, Object t) {
431438
private static final Predicate<Value> ALWAYS_TRUE_VALUE = t -> true;
432439
private static final Predicate<Value> ALWAYS_FALSE_VALUE = t -> false;
433440
private static final Supplier<Predicate<Value>> ALWAYS_TRUE_VALUE_SUPPLIER = () -> ALWAYS_TRUE_VALUE;
441+
private static final Supplier<Predicate<List<Value>>> ALWAYS_TRUE_TUPLE_VALUE_SUPPLIER = () -> t -> true;
434442

435443
private AggregatePredicateCollectorSupplier<?, ?> create(GroupElem ge, ValueFactory vf)
436444
throws QueryEvaluationException {
@@ -444,57 +452,68 @@ private void operate(BindingSet bs, Predicate<?> predicate, Object t) {
444452
return new AggregatePredicateCollectorSupplier<>(wildCardCountAggregate, potentialDistinctTest,
445453
() -> new CountCollector(vf), ge.getName());
446454
} else {
447-
QueryStepEvaluator f = precompileArg(operator);
455+
QueryStepEvaluator f = precompileUnaryArg(operator);
448456
CountAggregate agg = new CountAggregate(f);
449-
Supplier<Predicate<Value>> predicate = createDistinctValueTest(operator);
457+
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
450458
return new AggregatePredicateCollectorSupplier<>(agg, predicate, () -> new CountCollector(vf),
451459
ge.getName());
452460
}
453461
} else if (operator instanceof Min) {
454-
MinAggregate agg = new MinAggregate(precompileArg(operator), shouldValueComparisonBeStrict());
455-
Supplier<Predicate<Value>> predicate = createDistinctValueTest(operator);
462+
MinAggregate agg = new MinAggregate(precompileUnaryArg(operator), shouldValueComparisonBeStrict());
463+
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
456464
return new AggregatePredicateCollectorSupplier<>(agg, predicate, ValueCollector::new, ge.getName());
457465
} else if (operator instanceof Max) {
458-
MaxAggregate agg = new MaxAggregate(precompileArg(operator), shouldValueComparisonBeStrict());
459-
Supplier<Predicate<Value>> predicate = createDistinctValueTest(operator);
466+
MaxAggregate agg = new MaxAggregate(precompileUnaryArg(operator), shouldValueComparisonBeStrict());
467+
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
460468
return new AggregatePredicateCollectorSupplier<>(agg, predicate, ValueCollector::new, ge.getName());
461469
} else if (operator instanceof Sum) {
462470

463-
SumAggregate agg = new SumAggregate(precompileArg(operator));
464-
Supplier<Predicate<Value>> predicate = createDistinctValueTest(operator);
471+
SumAggregate agg = new SumAggregate(precompileUnaryArg(operator));
472+
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
465473
return new AggregatePredicateCollectorSupplier<>(agg, predicate, () -> new IntegerCollector(vf),
466474
ge.getName());
467475
} else if (operator instanceof Avg) {
468-
AvgAggregate agg = new AvgAggregate(precompileArg(operator));
469-
Supplier<Predicate<Value>> predicate = createDistinctValueTest(operator);
476+
AvgAggregate agg = new AvgAggregate(precompileUnaryArg(operator));
477+
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
470478
return new AggregatePredicateCollectorSupplier<>(agg, predicate, () -> new AvgCollector(vf), ge.getName());
471479
} else if (operator instanceof Sample) {
472-
SampleAggregate agg = new SampleAggregate(precompileArg(operator));
473-
Supplier<Predicate<Value>> predicate = createDistinctValueTest(operator);
480+
SampleAggregate agg = new SampleAggregate(precompileUnaryArg(operator));
481+
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
474482
return new AggregatePredicateCollectorSupplier<>(agg, predicate, SampleCollector::new, ge.getName());
475483
} else if (operator instanceof GroupConcat) {
476484
ValueExpr separatorExpr = ((GroupConcat) operator).getSeparator();
477485
ConcatAggregate agg;
478486
if (separatorExpr != null) {
479487
Value separatorValue = strategy.evaluate(separatorExpr, parentBindings);
480-
agg = new ConcatAggregate(precompileArg(operator), separatorValue.stringValue());
488+
agg = new ConcatAggregate(precompileUnaryArg(operator), separatorValue.stringValue());
481489
} else {
482-
agg = new ConcatAggregate(precompileArg(operator));
490+
agg = new ConcatAggregate(precompileUnaryArg(operator));
483491
}
484-
Supplier<Predicate<Value>> predicate = createDistinctValueTest(operator);
492+
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
485493
return new AggregatePredicateCollectorSupplier<>(agg, predicate, () -> new StringBuilderCollector(vf),
486494
ge.getName());
487495
} else if (operator instanceof AggregateFunctionCall) {
488496
var aggOperator = (AggregateFunctionCall) operator;
489-
Supplier<Predicate<Value>> predicate = createDistinctValueTest(operator);
490-
var factory = CustomAggregateFunctionRegistry.getInstance().get(aggOperator.getIRI());
497+
int argumentCount = aggOperator.getArguments().size();
498+
var nAryFactory = CustomAggregateNAryFunctionRegistry.getInstance().get(aggOperator.getIRI());
499+
var unaryFactory = CustomAggregateFunctionRegistry.getInstance().get(aggOperator.getIRI());
491500

492-
var function = factory.orElseThrow(
493-
() -> new QueryEvaluationException("Unknown aggregate function '" + aggOperator.getIRI() + "'"))
494-
.buildFunction(precompileArg(aggOperator));
495-
return new AggregatePredicateCollectorSupplier<>(function, predicate, () -> factory.get().getCollector(),
496-
ge.getName());
501+
if (argumentCount == 1 && unaryFactory.isPresent()) {
502+
return createUnaryCustomAggregate(ge, operator, aggOperator, unaryFactory.get());
503+
}
504+
if (nAryFactory.isPresent()) {
505+
validateNAryAggregateArity(aggOperator, nAryFactory.get(), argumentCount);
506+
return createNAryCustomAggregate(ge, operator, aggOperator, nAryFactory.get());
507+
}
508+
if (unaryFactory.isPresent()) {
509+
if (argumentCount != 1) {
510+
throw new QueryEvaluationException("Custom unary aggregate function '" + aggOperator.getIRI()
511+
+ "' expects exactly 1 argument, got " + argumentCount);
512+
}
513+
return createUnaryCustomAggregate(ge, operator, aggOperator, unaryFactory.get());
514+
}
497515

516+
throw new QueryEvaluationException("Unknown aggregate function '" + aggOperator.getIRI() + "'");
498517
}
499518

500519
return null;
@@ -508,15 +527,73 @@ private void operate(BindingSet bs, Predicate<?> predicate, Object t) {
508527
* @return a supplier that returns a predicate that tests if the value is distinct, or always returns true if the
509528
* operator is not distinct.
510529
*/
511-
private Supplier<Predicate<Value>> createDistinctValueTest(AggregateOperator operator) {
530+
private Supplier<Predicate<Value>> createDistinctSingleValueTest(AggregateOperator operator) {
512531
return operator.isDistinct() ? DistinctValues::new : ALWAYS_TRUE_VALUE_SUPPLIER;
513532
}
514533

515-
private QueryStepEvaluator precompileArg(AggregateOperator operator) {
534+
/**
535+
* Create a predicate that tests if the tuple of values is distinct (returning true if the tuple was not seen yet),
536+
* or always returns true if the operator is not distinct.
537+
*
538+
* @param operator
539+
* @return a supplier that returns a predicate that tests if the tuple of values is distinct, or always returns true
540+
* if the operator is not distinct.
541+
*/
542+
private Supplier<Predicate<List<Value>>> createDistinctTupleValueTest(AggregateOperator operator) {
543+
return operator.isDistinct() ? DistinctTupleValues::new : ALWAYS_TRUE_TUPLE_VALUE_SUPPLIER;
544+
}
545+
546+
private AggregatePredicateCollectorSupplier<?, ?> createUnaryCustomAggregate(GroupElem ge,
547+
AggregateOperator operator,
548+
AggregateFunctionCall aggOperator, AggregateFunctionFactory factory) {
549+
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
550+
AggregateFunction function = factory.buildFunction(precompileNAryArg(aggOperator).asUnaryEvaluator(0));
551+
return new AggregatePredicateCollectorSupplier<>(function, predicate, factory::getCollector, ge.getName());
552+
}
553+
554+
private AggregatePredicateCollectorSupplier<?, ?> createNAryCustomAggregate(GroupElem ge,
555+
AggregateOperator operator,
556+
AggregateFunctionCall aggOperator, AggregateNAryFunctionFactory factory) {
557+
Supplier<Predicate<List<Value>>> predicate = createDistinctTupleValueTest(operator);
558+
var function = factory.buildFunction(precompileNAryArg(aggOperator));
559+
return new AggregatePredicateCollectorSupplier<>(function, predicate, factory::getCollector, ge.getName());
560+
}
561+
562+
private void validateNAryAggregateArity(AggregateFunctionCall aggregateOperator,
563+
AggregateNAryFunctionFactory factory,
564+
int argumentCount) {
565+
int minimum = factory.getMinNumberOfArguments();
566+
int maximum = factory.getMaxNumberOfArguments();
567+
if (minimum < 0 || maximum < minimum) {
568+
throw new QueryEvaluationException("Custom n-ary aggregate function '" + aggregateOperator.getIRI()
569+
+ "' has invalid arity declaration");
570+
}
571+
if (argumentCount < minimum) {
572+
throw new QueryEvaluationException(
573+
"Custom n-ary aggregate function '" + aggregateOperator.getIRI() + "' expects at least " + minimum
574+
+ " arguments, got " + argumentCount);
575+
}
576+
if (argumentCount > maximum) {
577+
throw new QueryEvaluationException(
578+
"Custom n-ary aggregate function '" + aggregateOperator.getIRI() + "' expects at most " + maximum
579+
+ " arguments, got " + argumentCount);
580+
}
581+
}
582+
583+
private QueryStepEvaluator precompileUnaryArg(AggregateOperator operator) {
516584
ValueExpr ve = ((UnaryValueOperator) operator).getArg();
517585
return new QueryStepEvaluator(strategy.precompile(ve, context));
518586
}
519587

588+
private NAryQueryStepEvaluator precompileNAryArg(AggregateOperator operator) {
589+
List<ValueExpr> args = ((NAryValueOperator) operator).getArguments();
590+
List<QueryValueEvaluationStep> precompiledArgs = new ArrayList<>(args.size());
591+
for (ValueExpr arg : args) {
592+
precompiledArgs.add(strategy.precompile(arg, context));
593+
}
594+
return new NAryQueryStepEvaluator(precompiledArgs::get);
595+
}
596+
520597
private boolean shouldValueComparisonBeStrict() {
521598
return strategy.getQueryEvaluationMode() == QueryEvaluationMode.STRICT;
522599
}
@@ -623,6 +700,19 @@ public boolean test(Value value) {
623700
}
624701
}
625702

703+
private class DistinctTupleValues implements Predicate<List<Value>> {
704+
private final Set<List<Value>> distinctTuples;
705+
706+
public DistinctTupleValues() {
707+
distinctTuples = cf.createSet();
708+
}
709+
710+
@Override
711+
public boolean test(List<Value> valueTuple) {
712+
return distinctTuples.add(valueTuple);
713+
}
714+
}
715+
626716
private class DistinctBindingSets implements Predicate<BindingSet> {
627717
private final Set<BindingSet> distinctValues;
628718

@@ -884,4 +974,25 @@ public Value apply(BindingSet bindings) {
884974
}
885975
}
886976
}
977+
978+
private static class NAryQueryStepEvaluator implements BiFunction<Integer, BindingSet, Value> {
979+
private final Function<Integer, QueryValueEvaluationStep> evaluationStepFunction;
980+
981+
public NAryQueryStepEvaluator(Function<Integer, QueryValueEvaluationStep> evaluationStepFunction) {
982+
this.evaluationStepFunction = evaluationStepFunction;
983+
}
984+
985+
@Override
986+
public Value apply(Integer index, BindingSet bindings) {
987+
try {
988+
return evaluationStepFunction.apply(index).evaluate(bindings);
989+
} catch (ValueExprEvaluationException e) {
990+
return null; // treat missing or invalid expressions as null
991+
}
992+
}
993+
994+
public QueryStepEvaluator asUnaryEvaluator(Integer index) {
995+
return new QueryStepEvaluator(evaluationStepFunction.apply(index));
996+
}
997+
}
887998
}

0 commit comments

Comments
 (0)