88 *
99 * SPDX-License-Identifier: BSD-3-Clause
1010 *******************************************************************************/
11+ // Some portions generated by Codex
1112package org .eclipse .rdf4j .query .algebra .evaluation .iterator ;
1213
1314import java .util .ArrayList ;
1920import java .util .Random ;
2021import java .util .Set ;
2122import java .util .function .BiConsumer ;
23+ import java .util .function .BiFunction ;
2224import java .util .function .Function ;
2325import java .util .function .Predicate ;
2426import java .util .function .Supplier ;
5052import org .eclipse .rdf4j .query .algebra .MathExpr .MathOp ;
5153import org .eclipse .rdf4j .query .algebra .Max ;
5254import org .eclipse .rdf4j .query .algebra .Min ;
55+ import org .eclipse .rdf4j .query .algebra .NAryValueOperator ;
5356import org .eclipse .rdf4j .query .algebra .Sample ;
5457import org .eclipse .rdf4j .query .algebra .Sum ;
5558import org .eclipse .rdf4j .query .algebra .UnaryValueOperator ;
6568import org .eclipse .rdf4j .query .impl .EmptyBindingSet ;
6669import org .eclipse .rdf4j .query .parser .sparql .aggregate .AggregateCollector ;
6770import org .eclipse .rdf4j .query .parser .sparql .aggregate .AggregateFunction ;
71+ import org .eclipse .rdf4j .query .parser .sparql .aggregate .AggregateFunctionFactory ;
72+ import org .eclipse .rdf4j .query .parser .sparql .aggregate .AggregateNAryFunctionFactory ;
73+ import org .eclipse .rdf4j .query .parser .sparql .aggregate .AggregateProcessor ;
6874import org .eclipse .rdf4j .query .parser .sparql .aggregate .CustomAggregateFunctionRegistry ;
75+ import org .eclipse .rdf4j .query .parser .sparql .aggregate .CustomAggregateNAryFunctionRegistry ;
6976
7077/**
7178 * @author David Huynh
@@ -349,7 +356,7 @@ private List<Entry> emptySolutionSpecialCase(List<AggregatePredicateCollectorSup
349356 // Even in the case that the Count is of a constant value.
350357 predicates .add (ALWAYS_FALSE_VALUE );
351358 } else {
352- predicates .add (ALWAYS_TRUE_VALUE );
359+ predicates .add (ag . makePotentialDistinctTest . get () );
353360 }
354361 }
355362 final Entry entry = new Entry (null , collectors , predicates );
@@ -409,11 +416,11 @@ public long getSize() {
409416 */
410417 private static class AggregatePredicateCollectorSupplier <T extends AggregateCollector , D > {
411418 public final String name ;
412- private final AggregateFunction <T , D > agg ;
419+ private final AggregateProcessor <T , D > agg ;
413420 private final Supplier <Predicate <D >> makePotentialDistinctTest ;
414421 private final Supplier <T > makeAggregateCollector ;
415422
416- public AggregatePredicateCollectorSupplier (AggregateFunction <T , D > agg ,
423+ public AggregatePredicateCollectorSupplier (AggregateProcessor <T , D > agg ,
417424 Supplier <Predicate <D >> makePotentialDistinctTest , Supplier <T > makeAggregateCollector , String name ) {
418425 super ();
419426 this .agg = agg ;
@@ -431,6 +438,7 @@ private void operate(BindingSet bs, Predicate<?> predicate, Object t) {
431438 private static final Predicate <Value > ALWAYS_TRUE_VALUE = t -> true ;
432439 private static final Predicate <Value > ALWAYS_FALSE_VALUE = t -> false ;
433440 private static final Supplier <Predicate <Value >> ALWAYS_TRUE_VALUE_SUPPLIER = () -> ALWAYS_TRUE_VALUE ;
441+ private static final Supplier <Predicate <List <Value >>> ALWAYS_TRUE_TUPLE_VALUE_SUPPLIER = () -> t -> true ;
434442
435443 private AggregatePredicateCollectorSupplier <?, ?> create (GroupElem ge , ValueFactory vf )
436444 throws QueryEvaluationException {
@@ -444,57 +452,68 @@ private void operate(BindingSet bs, Predicate<?> predicate, Object t) {
444452 return new AggregatePredicateCollectorSupplier <>(wildCardCountAggregate , potentialDistinctTest ,
445453 () -> new CountCollector (vf ), ge .getName ());
446454 } else {
447- QueryStepEvaluator f = precompileArg (operator );
455+ QueryStepEvaluator f = precompileUnaryArg (operator );
448456 CountAggregate agg = new CountAggregate (f );
449- Supplier <Predicate <Value >> predicate = createDistinctValueTest (operator );
457+ Supplier <Predicate <Value >> predicate = createDistinctSingleValueTest (operator );
450458 return new AggregatePredicateCollectorSupplier <>(agg , predicate , () -> new CountCollector (vf ),
451459 ge .getName ());
452460 }
453461 } else if (operator instanceof Min ) {
454- MinAggregate agg = new MinAggregate (precompileArg (operator ), shouldValueComparisonBeStrict ());
455- Supplier <Predicate <Value >> predicate = createDistinctValueTest (operator );
462+ MinAggregate agg = new MinAggregate (precompileUnaryArg (operator ), shouldValueComparisonBeStrict ());
463+ Supplier <Predicate <Value >> predicate = createDistinctSingleValueTest (operator );
456464 return new AggregatePredicateCollectorSupplier <>(agg , predicate , ValueCollector ::new , ge .getName ());
457465 } else if (operator instanceof Max ) {
458- MaxAggregate agg = new MaxAggregate (precompileArg (operator ), shouldValueComparisonBeStrict ());
459- Supplier <Predicate <Value >> predicate = createDistinctValueTest (operator );
466+ MaxAggregate agg = new MaxAggregate (precompileUnaryArg (operator ), shouldValueComparisonBeStrict ());
467+ Supplier <Predicate <Value >> predicate = createDistinctSingleValueTest (operator );
460468 return new AggregatePredicateCollectorSupplier <>(agg , predicate , ValueCollector ::new , ge .getName ());
461469 } else if (operator instanceof Sum ) {
462470
463- SumAggregate agg = new SumAggregate (precompileArg (operator ));
464- Supplier <Predicate <Value >> predicate = createDistinctValueTest (operator );
471+ SumAggregate agg = new SumAggregate (precompileUnaryArg (operator ));
472+ Supplier <Predicate <Value >> predicate = createDistinctSingleValueTest (operator );
465473 return new AggregatePredicateCollectorSupplier <>(agg , predicate , () -> new IntegerCollector (vf ),
466474 ge .getName ());
467475 } else if (operator instanceof Avg ) {
468- AvgAggregate agg = new AvgAggregate (precompileArg (operator ));
469- Supplier <Predicate <Value >> predicate = createDistinctValueTest (operator );
476+ AvgAggregate agg = new AvgAggregate (precompileUnaryArg (operator ));
477+ Supplier <Predicate <Value >> predicate = createDistinctSingleValueTest (operator );
470478 return new AggregatePredicateCollectorSupplier <>(agg , predicate , () -> new AvgCollector (vf ), ge .getName ());
471479 } else if (operator instanceof Sample ) {
472- SampleAggregate agg = new SampleAggregate (precompileArg (operator ));
473- Supplier <Predicate <Value >> predicate = createDistinctValueTest (operator );
480+ SampleAggregate agg = new SampleAggregate (precompileUnaryArg (operator ));
481+ Supplier <Predicate <Value >> predicate = createDistinctSingleValueTest (operator );
474482 return new AggregatePredicateCollectorSupplier <>(agg , predicate , SampleCollector ::new , ge .getName ());
475483 } else if (operator instanceof GroupConcat ) {
476484 ValueExpr separatorExpr = ((GroupConcat ) operator ).getSeparator ();
477485 ConcatAggregate agg ;
478486 if (separatorExpr != null ) {
479487 Value separatorValue = strategy .evaluate (separatorExpr , parentBindings );
480- agg = new ConcatAggregate (precompileArg (operator ), separatorValue .stringValue ());
488+ agg = new ConcatAggregate (precompileUnaryArg (operator ), separatorValue .stringValue ());
481489 } else {
482- agg = new ConcatAggregate (precompileArg (operator ));
490+ agg = new ConcatAggregate (precompileUnaryArg (operator ));
483491 }
484- Supplier <Predicate <Value >> predicate = createDistinctValueTest (operator );
492+ Supplier <Predicate <Value >> predicate = createDistinctSingleValueTest (operator );
485493 return new AggregatePredicateCollectorSupplier <>(agg , predicate , () -> new StringBuilderCollector (vf ),
486494 ge .getName ());
487495 } else if (operator instanceof AggregateFunctionCall ) {
488496 var aggOperator = (AggregateFunctionCall ) operator ;
489- Supplier <Predicate <Value >> predicate = createDistinctValueTest (operator );
490- var factory = CustomAggregateFunctionRegistry .getInstance ().get (aggOperator .getIRI ());
497+ int argumentCount = aggOperator .getArguments ().size ();
498+ var nAryFactory = CustomAggregateNAryFunctionRegistry .getInstance ().get (aggOperator .getIRI ());
499+ var unaryFactory = CustomAggregateFunctionRegistry .getInstance ().get (aggOperator .getIRI ());
491500
492- var function = factory .orElseThrow (
493- () -> new QueryEvaluationException ("Unknown aggregate function '" + aggOperator .getIRI () + "'" ))
494- .buildFunction (precompileArg (aggOperator ));
495- return new AggregatePredicateCollectorSupplier <>(function , predicate , () -> factory .get ().getCollector (),
496- ge .getName ());
501+ if (argumentCount == 1 && unaryFactory .isPresent ()) {
502+ return createUnaryCustomAggregate (ge , operator , aggOperator , unaryFactory .get ());
503+ }
504+ if (nAryFactory .isPresent ()) {
505+ validateNAryAggregateArity (aggOperator , nAryFactory .get (), argumentCount );
506+ return createNAryCustomAggregate (ge , operator , aggOperator , nAryFactory .get ());
507+ }
508+ if (unaryFactory .isPresent ()) {
509+ if (argumentCount != 1 ) {
510+ throw new QueryEvaluationException ("Custom unary aggregate function '" + aggOperator .getIRI ()
511+ + "' expects exactly 1 argument, got " + argumentCount );
512+ }
513+ return createUnaryCustomAggregate (ge , operator , aggOperator , unaryFactory .get ());
514+ }
497515
516+ throw new QueryEvaluationException ("Unknown aggregate function '" + aggOperator .getIRI () + "'" );
498517 }
499518
500519 return null ;
@@ -508,15 +527,73 @@ private void operate(BindingSet bs, Predicate<?> predicate, Object t) {
508527 * @return a supplier that returns a predicate that tests if the value is distinct, or always returns true if the
509528 * operator is not distinct.
510529 */
511- private Supplier <Predicate <Value >> createDistinctValueTest (AggregateOperator operator ) {
530+ private Supplier <Predicate <Value >> createDistinctSingleValueTest (AggregateOperator operator ) {
512531 return operator .isDistinct () ? DistinctValues ::new : ALWAYS_TRUE_VALUE_SUPPLIER ;
513532 }
514533
515- private QueryStepEvaluator precompileArg (AggregateOperator operator ) {
534+ /**
535+ * Create a predicate that tests if the tuple of values is distinct (returning true if the tuple was not seen yet),
536+ * or always returns true if the operator is not distinct.
537+ *
538+ * @param operator
539+ * @return a supplier that returns a predicate that tests if the tuple of values is distinct, or always returns true
540+ * if the operator is not distinct.
541+ */
542+ private Supplier <Predicate <List <Value >>> createDistinctTupleValueTest (AggregateOperator operator ) {
543+ return operator .isDistinct () ? DistinctTupleValues ::new : ALWAYS_TRUE_TUPLE_VALUE_SUPPLIER ;
544+ }
545+
546+ private AggregatePredicateCollectorSupplier <?, ?> createUnaryCustomAggregate (GroupElem ge ,
547+ AggregateOperator operator ,
548+ AggregateFunctionCall aggOperator , AggregateFunctionFactory factory ) {
549+ Supplier <Predicate <Value >> predicate = createDistinctSingleValueTest (operator );
550+ AggregateFunction function = factory .buildFunction (precompileNAryArg (aggOperator ).asUnaryEvaluator (0 ));
551+ return new AggregatePredicateCollectorSupplier <>(function , predicate , factory ::getCollector , ge .getName ());
552+ }
553+
554+ private AggregatePredicateCollectorSupplier <?, ?> createNAryCustomAggregate (GroupElem ge ,
555+ AggregateOperator operator ,
556+ AggregateFunctionCall aggOperator , AggregateNAryFunctionFactory factory ) {
557+ Supplier <Predicate <List <Value >>> predicate = createDistinctTupleValueTest (operator );
558+ var function = factory .buildFunction (precompileNAryArg (aggOperator ));
559+ return new AggregatePredicateCollectorSupplier <>(function , predicate , factory ::getCollector , ge .getName ());
560+ }
561+
562+ private void validateNAryAggregateArity (AggregateFunctionCall aggregateOperator ,
563+ AggregateNAryFunctionFactory factory ,
564+ int argumentCount ) {
565+ int minimum = factory .getMinNumberOfArguments ();
566+ int maximum = factory .getMaxNumberOfArguments ();
567+ if (minimum < 0 || maximum < minimum ) {
568+ throw new QueryEvaluationException ("Custom n-ary aggregate function '" + aggregateOperator .getIRI ()
569+ + "' has invalid arity declaration" );
570+ }
571+ if (argumentCount < minimum ) {
572+ throw new QueryEvaluationException (
573+ "Custom n-ary aggregate function '" + aggregateOperator .getIRI () + "' expects at least " + minimum
574+ + " arguments, got " + argumentCount );
575+ }
576+ if (argumentCount > maximum ) {
577+ throw new QueryEvaluationException (
578+ "Custom n-ary aggregate function '" + aggregateOperator .getIRI () + "' expects at most " + maximum
579+ + " arguments, got " + argumentCount );
580+ }
581+ }
582+
583+ private QueryStepEvaluator precompileUnaryArg (AggregateOperator operator ) {
516584 ValueExpr ve = ((UnaryValueOperator ) operator ).getArg ();
517585 return new QueryStepEvaluator (strategy .precompile (ve , context ));
518586 }
519587
588+ private NAryQueryStepEvaluator precompileNAryArg (AggregateOperator operator ) {
589+ List <ValueExpr > args = ((NAryValueOperator ) operator ).getArguments ();
590+ List <QueryValueEvaluationStep > precompiledArgs = new ArrayList <>(args .size ());
591+ for (ValueExpr arg : args ) {
592+ precompiledArgs .add (strategy .precompile (arg , context ));
593+ }
594+ return new NAryQueryStepEvaluator (precompiledArgs ::get );
595+ }
596+
520597 private boolean shouldValueComparisonBeStrict () {
521598 return strategy .getQueryEvaluationMode () == QueryEvaluationMode .STRICT ;
522599 }
@@ -623,6 +700,19 @@ public boolean test(Value value) {
623700 }
624701 }
625702
703+ private class DistinctTupleValues implements Predicate <List <Value >> {
704+ private final Set <List <Value >> distinctTuples ;
705+
706+ public DistinctTupleValues () {
707+ distinctTuples = cf .createSet ();
708+ }
709+
710+ @ Override
711+ public boolean test (List <Value > valueTuple ) {
712+ return distinctTuples .add (valueTuple );
713+ }
714+ }
715+
626716 private class DistinctBindingSets implements Predicate <BindingSet > {
627717 private final Set <BindingSet > distinctValues ;
628718
@@ -884,4 +974,25 @@ public Value apply(BindingSet bindings) {
884974 }
885975 }
886976 }
977+
978+ private static class NAryQueryStepEvaluator implements BiFunction <Integer , BindingSet , Value > {
979+ private final Function <Integer , QueryValueEvaluationStep > evaluationStepFunction ;
980+
981+ public NAryQueryStepEvaluator (Function <Integer , QueryValueEvaluationStep > evaluationStepFunction ) {
982+ this .evaluationStepFunction = evaluationStepFunction ;
983+ }
984+
985+ @ Override
986+ public Value apply (Integer index , BindingSet bindings ) {
987+ try {
988+ return evaluationStepFunction .apply (index ).evaluate (bindings );
989+ } catch (ValueExprEvaluationException e ) {
990+ return null ; // treat missing or invalid expressions as null
991+ }
992+ }
993+
994+ public QueryStepEvaluator asUnaryEvaluator (Integer index ) {
995+ return new QueryStepEvaluator (evaluationStepFunction .apply (index ));
996+ }
997+ }
887998}
0 commit comments