Skip to content

Commit bcc695a

Browse files
nk2IsHerehmottestad
authored andcommitted
GH-5626 N-ary aggregate function implementation
1 parent a3ad055 commit bcc695a

11 files changed

Lines changed: 506 additions & 44 deletions

File tree

core/queryalgebra/evaluation/src/main/java/org/eclipse/rdf4j/query/algebra/evaluation/iterator/GroupIterator.java

Lines changed: 100 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.util.Random;
2020
import java.util.Set;
2121
import java.util.function.BiConsumer;
22+
import java.util.function.BiFunction;
2223
import java.util.function.Function;
2324
import java.util.function.Predicate;
2425
import java.util.function.Supplier;
@@ -50,6 +51,7 @@
5051
import org.eclipse.rdf4j.query.algebra.MathExpr.MathOp;
5152
import org.eclipse.rdf4j.query.algebra.Max;
5253
import org.eclipse.rdf4j.query.algebra.Min;
54+
import org.eclipse.rdf4j.query.algebra.NAryValueOperator;
5355
import org.eclipse.rdf4j.query.algebra.Sample;
5456
import org.eclipse.rdf4j.query.algebra.Sum;
5557
import org.eclipse.rdf4j.query.algebra.UnaryValueOperator;
@@ -64,7 +66,9 @@
6466
import org.eclipse.rdf4j.query.impl.EmptyBindingSet;
6567
import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateCollector;
6668
import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateFunction;
69+
import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateProcessor;
6770
import org.eclipse.rdf4j.query.parser.sparql.aggregate.CustomAggregateFunctionRegistry;
71+
import org.eclipse.rdf4j.query.parser.sparql.aggregate.CustomAggregateNAryFunctionRegistry;
6872

6973
/**
7074
* @author David Huynh
@@ -382,11 +386,11 @@ public BindingSet getPrototype() {
382386
*/
383387
private static class AggregatePredicateCollectorSupplier<T extends AggregateCollector, D> {
384388
public final String name;
385-
private final AggregateFunction<T, D> agg;
389+
private final AggregateProcessor<T, D> agg;
386390
private final Supplier<Predicate<D>> makePotentialDistinctTest;
387391
private final Supplier<T> makeAggregateCollector;
388392

389-
public AggregatePredicateCollectorSupplier(AggregateFunction<T, D> agg,
393+
public AggregatePredicateCollectorSupplier(AggregateProcessor<T, D> agg,
390394
Supplier<Predicate<D>> makePotentialDistinctTest, Supplier<T> makeAggregateCollector, String name) {
391395
super();
392396
this.agg = agg;
@@ -404,6 +408,7 @@ private void operate(BindingSet bs, Predicate<?> predicate, Object t) {
404408
private static final Predicate<Value> ALWAYS_TRUE_VALUE = t -> true;
405409
private static final Predicate<Value> ALWAYS_FALSE_VALUE = t -> false;
406410
private static final Supplier<Predicate<Value>> ALWAYS_TRUE_VALUE_SUPPLIER = () -> ALWAYS_TRUE_VALUE;
411+
private static final Supplier<Predicate<List<Value>>> ALWAYS_TRUE_TUPLE_VALUE_SUPPLIER = () -> t -> true;
407412

408413
private AggregatePredicateCollectorSupplier<?, ?> create(GroupElem ge, ValueFactory vf)
409414
throws QueryEvaluationException {
@@ -417,57 +422,71 @@ private void operate(BindingSet bs, Predicate<?> predicate, Object t) {
417422
return new AggregatePredicateCollectorSupplier<>(wildCardCountAggregate, potentialDistinctTest,
418423
() -> new CountCollector(vf), ge.getName());
419424
} else {
420-
QueryStepEvaluator f = precompileArg(operator);
425+
QueryStepEvaluator f = precompileUnaryArg(operator);
421426
CountAggregate agg = new CountAggregate(f);
422-
Supplier<Predicate<Value>> predicate = createDistinctValueTest(operator);
427+
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
423428
return new AggregatePredicateCollectorSupplier<>(agg, predicate, () -> new CountCollector(vf),
424429
ge.getName());
425430
}
426431
} else if (operator instanceof Min) {
427-
MinAggregate agg = new MinAggregate(precompileArg(operator), shouldValueComparisonBeStrict());
428-
Supplier<Predicate<Value>> predicate = createDistinctValueTest(operator);
432+
MinAggregate agg = new MinAggregate(precompileUnaryArg(operator), shouldValueComparisonBeStrict());
433+
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
429434
return new AggregatePredicateCollectorSupplier<>(agg, predicate, ValueCollector::new, ge.getName());
430435
} else if (operator instanceof Max) {
431-
MaxAggregate agg = new MaxAggregate(precompileArg(operator), shouldValueComparisonBeStrict());
432-
Supplier<Predicate<Value>> predicate = createDistinctValueTest(operator);
436+
MaxAggregate agg = new MaxAggregate(precompileUnaryArg(operator), shouldValueComparisonBeStrict());
437+
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
433438
return new AggregatePredicateCollectorSupplier<>(agg, predicate, ValueCollector::new, ge.getName());
434439
} else if (operator instanceof Sum) {
435440

436-
SumAggregate agg = new SumAggregate(precompileArg(operator));
437-
Supplier<Predicate<Value>> predicate = createDistinctValueTest(operator);
441+
SumAggregate agg = new SumAggregate(precompileUnaryArg(operator));
442+
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
438443
return new AggregatePredicateCollectorSupplier<>(agg, predicate, () -> new IntegerCollector(vf),
439444
ge.getName());
440445
} else if (operator instanceof Avg) {
441-
AvgAggregate agg = new AvgAggregate(precompileArg(operator));
442-
Supplier<Predicate<Value>> predicate = createDistinctValueTest(operator);
446+
AvgAggregate agg = new AvgAggregate(precompileUnaryArg(operator));
447+
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
443448
return new AggregatePredicateCollectorSupplier<>(agg, predicate, () -> new AvgCollector(vf), ge.getName());
444449
} else if (operator instanceof Sample) {
445-
SampleAggregate agg = new SampleAggregate(precompileArg(operator));
446-
Supplier<Predicate<Value>> predicate = createDistinctValueTest(operator);
450+
SampleAggregate agg = new SampleAggregate(precompileUnaryArg(operator));
451+
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
447452
return new AggregatePredicateCollectorSupplier<>(agg, predicate, SampleCollector::new, ge.getName());
448453
} else if (operator instanceof GroupConcat) {
449454
ValueExpr separatorExpr = ((GroupConcat) operator).getSeparator();
450455
ConcatAggregate agg;
451456
if (separatorExpr != null) {
452457
Value separatorValue = strategy.evaluate(separatorExpr, parentBindings);
453-
agg = new ConcatAggregate(precompileArg(operator), separatorValue.stringValue());
458+
agg = new ConcatAggregate(precompileUnaryArg(operator), separatorValue.stringValue());
454459
} else {
455-
agg = new ConcatAggregate(precompileArg(operator));
460+
agg = new ConcatAggregate(precompileUnaryArg(operator));
456461
}
457-
Supplier<Predicate<Value>> predicate = createDistinctValueTest(operator);
462+
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
458463
return new AggregatePredicateCollectorSupplier<>(agg, predicate, () -> new StringBuilderCollector(vf),
459464
ge.getName());
460465
} else if (operator instanceof AggregateFunctionCall) {
461466
var aggOperator = (AggregateFunctionCall) operator;
462-
Supplier<Predicate<Value>> predicate = createDistinctValueTest(operator);
463-
var factory = CustomAggregateFunctionRegistry.getInstance().get(aggOperator.getIRI());
464467

465-
var function = factory.orElseThrow(
466-
() -> new QueryEvaluationException("Unknown aggregate function '" + aggOperator.getIRI() + "'"))
467-
.buildFunction(precompileArg(aggOperator));
468-
return new AggregatePredicateCollectorSupplier<>(function, predicate, () -> factory.get().getCollector(),
469-
ge.getName());
468+
if (aggOperator.getArguments().size() > 1) {
469+
Supplier<Predicate<List<Value>>> predicate = createDistinctTupleValueTest(operator);
470+
var factory = CustomAggregateNAryFunctionRegistry.getInstance().get(aggOperator.getIRI());
470471

472+
var function = factory.orElseThrow(
473+
() -> new QueryEvaluationException(
474+
"Unknown n-ary aggregate function '" + aggOperator.getIRI() + "'"))
475+
.buildFunction(precompileNAryArg(aggOperator));
476+
return new AggregatePredicateCollectorSupplier<>(function, predicate,
477+
() -> factory.get().getCollector(),
478+
ge.getName());
479+
} else {
480+
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
481+
var factory = CustomAggregateFunctionRegistry.getInstance().get(aggOperator.getIRI());
482+
483+
var function = factory.orElseThrow(
484+
() -> new QueryEvaluationException("Unknown aggregate function '" + aggOperator.getIRI() + "'"))
485+
.buildFunction(precompileNAryArg(aggOperator).asUnaryEvaluator(0));
486+
return new AggregatePredicateCollectorSupplier<>(function, predicate,
487+
() -> factory.get().getCollector(),
488+
ge.getName());
489+
}
471490
}
472491

473492
return null;
@@ -481,15 +500,36 @@ private void operate(BindingSet bs, Predicate<?> predicate, Object t) {
481500
* @return a supplier that returns a predicate that tests if the value is distinct, or always returns true if the
482501
* operator is not distinct.
483502
*/
484-
private Supplier<Predicate<Value>> createDistinctValueTest(AggregateOperator operator) {
503+
private Supplier<Predicate<Value>> createDistinctSingleValueTest(AggregateOperator operator) {
485504
return operator.isDistinct() ? DistinctValues::new : ALWAYS_TRUE_VALUE_SUPPLIER;
486505
}
487506

488-
private QueryStepEvaluator precompileArg(AggregateOperator operator) {
507+
/**
508+
* Create a predicate that tests if the tuple of values is distinct (returning true if the tuple was not seen yet),
509+
* or always returns true if the operator is not distinct.
510+
*
511+
* @param operator
512+
* @return a supplier that returns a predicate that tests if the tuple of values is distinct, or always returns true
513+
* if the operator is not distinct.
514+
*/
515+
private Supplier<Predicate<List<Value>>> createDistinctTupleValueTest(AggregateOperator operator) {
516+
return operator.isDistinct() ? DistinctTupleValues::new : ALWAYS_TRUE_TUPLE_VALUE_SUPPLIER;
517+
}
518+
519+
private QueryStepEvaluator precompileUnaryArg(AggregateOperator operator) {
489520
ValueExpr ve = ((UnaryValueOperator) operator).getArg();
490521
return new QueryStepEvaluator(strategy.precompile(ve, context));
491522
}
492523

524+
private NAryQueryStepEvaluator precompileNAryArg(AggregateOperator operator) {
525+
List<ValueExpr> args = ((NAryValueOperator) operator).getArguments();
526+
List<QueryValueEvaluationStep> precompiledArgs = new ArrayList<>(args.size());
527+
for (ValueExpr arg : args) {
528+
precompiledArgs.add(strategy.precompile(arg, context));
529+
}
530+
return new NAryQueryStepEvaluator(precompiledArgs::get);
531+
}
532+
493533
private boolean shouldValueComparisonBeStrict() {
494534
return strategy.getQueryEvaluationMode() == QueryEvaluationMode.STRICT;
495535
}
@@ -596,6 +636,19 @@ public boolean test(Value value) {
596636
}
597637
}
598638

639+
private class DistinctTupleValues implements Predicate<List<Value>> {
640+
private final Set<List<Value>> distinctTuples;
641+
642+
public DistinctTupleValues() {
643+
distinctTuples = cf.createSet();
644+
}
645+
646+
@Override
647+
public boolean test(List<Value> valueTuple) {
648+
return distinctTuples.add(valueTuple);
649+
}
650+
}
651+
599652
private class DistinctBindingSets implements Predicate<BindingSet> {
600653
private final Set<BindingSet> distinctValues;
601654

@@ -857,4 +910,25 @@ public Value apply(BindingSet bindings) {
857910
}
858911
}
859912
}
913+
914+
private static class NAryQueryStepEvaluator implements BiFunction<Integer, BindingSet, Value> {
915+
private final Function<Integer, QueryValueEvaluationStep> evaluationStepFunction;
916+
917+
public NAryQueryStepEvaluator(Function<Integer, QueryValueEvaluationStep> evaluationStepFunction) {
918+
this.evaluationStepFunction = evaluationStepFunction;
919+
}
920+
921+
@Override
922+
public Value apply(Integer index, BindingSet bindings) {
923+
try {
924+
return evaluationStepFunction.apply(index).evaluate(bindings);
925+
} catch (ValueExprEvaluationException e) {
926+
return null; // treat missing or invalid expressions as null
927+
}
928+
}
929+
930+
public QueryStepEvaluator asUnaryEvaluator(Integer index) {
931+
return new QueryStepEvaluator(evaluationStepFunction.apply(index));
932+
}
933+
}
860934
}

0 commit comments

Comments
 (0)