Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*
* SPDX-License-Identifier: BSD-3-Clause
*******************************************************************************/
// Some portions generated by Codex
package org.eclipse.rdf4j.query.algebra.evaluation.iterator;

import java.util.ArrayList;
Expand All @@ -19,6 +20,7 @@
import java.util.Random;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
Expand Down Expand Up @@ -50,6 +52,7 @@
import org.eclipse.rdf4j.query.algebra.MathExpr.MathOp;
import org.eclipse.rdf4j.query.algebra.Max;
import org.eclipse.rdf4j.query.algebra.Min;
import org.eclipse.rdf4j.query.algebra.NAryValueOperator;
import org.eclipse.rdf4j.query.algebra.Sample;
import org.eclipse.rdf4j.query.algebra.Sum;
import org.eclipse.rdf4j.query.algebra.UnaryValueOperator;
Expand All @@ -65,7 +68,11 @@
import org.eclipse.rdf4j.query.impl.EmptyBindingSet;
import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateCollector;
import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateFunction;
import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateFunctionFactory;
import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateNAryFunctionFactory;
import org.eclipse.rdf4j.query.parser.sparql.aggregate.AggregateProcessor;
import org.eclipse.rdf4j.query.parser.sparql.aggregate.CustomAggregateFunctionRegistry;
import org.eclipse.rdf4j.query.parser.sparql.aggregate.CustomAggregateNAryFunctionRegistry;

/**
* @author David Huynh
Expand Down Expand Up @@ -349,7 +356,7 @@ private List<Entry> emptySolutionSpecialCase(List<AggregatePredicateCollectorSup
// Even in the case that the Count is of a constant value.
predicates.add(ALWAYS_FALSE_VALUE);
} else {
predicates.add(ALWAYS_TRUE_VALUE);
predicates.add(ag.makePotentialDistinctTest.get());
}
}
final Entry entry = new Entry(null, collectors, predicates);
Expand Down Expand Up @@ -409,11 +416,11 @@ public long getSize() {
*/
private static class AggregatePredicateCollectorSupplier<T extends AggregateCollector, D> {
public final String name;
private final AggregateFunction<T, D> agg;
private final AggregateProcessor<T, D> agg;
private final Supplier<Predicate<D>> makePotentialDistinctTest;
private final Supplier<T> makeAggregateCollector;

public AggregatePredicateCollectorSupplier(AggregateFunction<T, D> agg,
public AggregatePredicateCollectorSupplier(AggregateProcessor<T, D> agg,
Supplier<Predicate<D>> makePotentialDistinctTest, Supplier<T> makeAggregateCollector, String name) {
super();
this.agg = agg;
Expand All @@ -431,6 +438,7 @@ private void operate(BindingSet bs, Predicate<?> predicate, Object t) {
private static final Predicate<Value> ALWAYS_TRUE_VALUE = t -> true;
private static final Predicate<Value> ALWAYS_FALSE_VALUE = t -> false;
private static final Supplier<Predicate<Value>> ALWAYS_TRUE_VALUE_SUPPLIER = () -> ALWAYS_TRUE_VALUE;
private static final Supplier<Predicate<List<Value>>> ALWAYS_TRUE_TUPLE_VALUE_SUPPLIER = () -> t -> true;

private AggregatePredicateCollectorSupplier<?, ?> create(GroupElem ge, ValueFactory vf)
throws QueryEvaluationException {
Expand All @@ -444,57 +452,68 @@ private void operate(BindingSet bs, Predicate<?> predicate, Object t) {
return new AggregatePredicateCollectorSupplier<>(wildCardCountAggregate, potentialDistinctTest,
() -> new CountCollector(vf), ge.getName());
} else {
QueryStepEvaluator f = precompileArg(operator);
QueryStepEvaluator f = precompileUnaryArg(operator);
CountAggregate agg = new CountAggregate(f);
Supplier<Predicate<Value>> predicate = createDistinctValueTest(operator);
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
return new AggregatePredicateCollectorSupplier<>(agg, predicate, () -> new CountCollector(vf),
ge.getName());
}
} else if (operator instanceof Min) {
MinAggregate agg = new MinAggregate(precompileArg(operator), shouldValueComparisonBeStrict());
Supplier<Predicate<Value>> predicate = createDistinctValueTest(operator);
MinAggregate agg = new MinAggregate(precompileUnaryArg(operator), shouldValueComparisonBeStrict());
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
return new AggregatePredicateCollectorSupplier<>(agg, predicate, ValueCollector::new, ge.getName());
} else if (operator instanceof Max) {
MaxAggregate agg = new MaxAggregate(precompileArg(operator), shouldValueComparisonBeStrict());
Supplier<Predicate<Value>> predicate = createDistinctValueTest(operator);
MaxAggregate agg = new MaxAggregate(precompileUnaryArg(operator), shouldValueComparisonBeStrict());
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
return new AggregatePredicateCollectorSupplier<>(agg, predicate, ValueCollector::new, ge.getName());
} else if (operator instanceof Sum) {

SumAggregate agg = new SumAggregate(precompileArg(operator));
Supplier<Predicate<Value>> predicate = createDistinctValueTest(operator);
SumAggregate agg = new SumAggregate(precompileUnaryArg(operator));
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
return new AggregatePredicateCollectorSupplier<>(agg, predicate, () -> new IntegerCollector(vf),
ge.getName());
} else if (operator instanceof Avg) {
AvgAggregate agg = new AvgAggregate(precompileArg(operator));
Supplier<Predicate<Value>> predicate = createDistinctValueTest(operator);
AvgAggregate agg = new AvgAggregate(precompileUnaryArg(operator));
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
return new AggregatePredicateCollectorSupplier<>(agg, predicate, () -> new AvgCollector(vf), ge.getName());
} else if (operator instanceof Sample) {
SampleAggregate agg = new SampleAggregate(precompileArg(operator));
Supplier<Predicate<Value>> predicate = createDistinctValueTest(operator);
SampleAggregate agg = new SampleAggregate(precompileUnaryArg(operator));
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
return new AggregatePredicateCollectorSupplier<>(agg, predicate, SampleCollector::new, ge.getName());
} else if (operator instanceof GroupConcat) {
ValueExpr separatorExpr = ((GroupConcat) operator).getSeparator();
ConcatAggregate agg;
if (separatorExpr != null) {
Value separatorValue = strategy.evaluate(separatorExpr, parentBindings);
agg = new ConcatAggregate(precompileArg(operator), separatorValue.stringValue());
agg = new ConcatAggregate(precompileUnaryArg(operator), separatorValue.stringValue());
} else {
agg = new ConcatAggregate(precompileArg(operator));
agg = new ConcatAggregate(precompileUnaryArg(operator));
}
Supplier<Predicate<Value>> predicate = createDistinctValueTest(operator);
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
return new AggregatePredicateCollectorSupplier<>(agg, predicate, () -> new StringBuilderCollector(vf),
ge.getName());
} else if (operator instanceof AggregateFunctionCall) {
var aggOperator = (AggregateFunctionCall) operator;
Supplier<Predicate<Value>> predicate = createDistinctValueTest(operator);
var factory = CustomAggregateFunctionRegistry.getInstance().get(aggOperator.getIRI());
int argumentCount = aggOperator.getArguments().size();
var nAryFactory = CustomAggregateNAryFunctionRegistry.getInstance().get(aggOperator.getIRI());
var unaryFactory = CustomAggregateFunctionRegistry.getInstance().get(aggOperator.getIRI());

var function = factory.orElseThrow(
() -> new QueryEvaluationException("Unknown aggregate function '" + aggOperator.getIRI() + "'"))
.buildFunction(precompileArg(aggOperator));
return new AggregatePredicateCollectorSupplier<>(function, predicate, () -> factory.get().getCollector(),
ge.getName());
if (argumentCount == 1 && unaryFactory.isPresent()) {
return createUnaryCustomAggregate(ge, operator, aggOperator, unaryFactory.get());
}
if (nAryFactory.isPresent()) {
validateNAryAggregateArity(aggOperator, nAryFactory.get(), argumentCount);
return createNAryCustomAggregate(ge, operator, aggOperator, nAryFactory.get());
}
if (unaryFactory.isPresent()) {
if (argumentCount != 1) {
throw new QueryEvaluationException("Custom unary aggregate function '" + aggOperator.getIRI()
+ "' expects exactly 1 argument, got " + argumentCount);
}
return createUnaryCustomAggregate(ge, operator, aggOperator, unaryFactory.get());
}

throw new QueryEvaluationException("Unknown aggregate function '" + aggOperator.getIRI() + "'");
}

return null;
Expand All @@ -508,15 +527,73 @@ private void operate(BindingSet bs, Predicate<?> predicate, Object t) {
* @return a supplier that returns a predicate that tests if the value is distinct, or always returns true if the
* operator is not distinct.
*/
private Supplier<Predicate<Value>> createDistinctValueTest(AggregateOperator operator) {
private Supplier<Predicate<Value>> createDistinctSingleValueTest(AggregateOperator operator) {
return operator.isDistinct() ? DistinctValues::new : ALWAYS_TRUE_VALUE_SUPPLIER;
}

private QueryStepEvaluator precompileArg(AggregateOperator operator) {
/**
* Create a predicate that tests if the tuple of values is distinct (returning true if the tuple was not seen yet),
* or always returns true if the operator is not distinct.
*
* @param operator
* @return a supplier that returns a predicate that tests if the tuple of values is distinct, or always returns true
* if the operator is not distinct.
*/
private Supplier<Predicate<List<Value>>> createDistinctTupleValueTest(AggregateOperator operator) {
return operator.isDistinct() ? DistinctTupleValues::new : ALWAYS_TRUE_TUPLE_VALUE_SUPPLIER;
}

private AggregatePredicateCollectorSupplier<?, ?> createUnaryCustomAggregate(GroupElem ge,
AggregateOperator operator,
AggregateFunctionCall aggOperator, AggregateFunctionFactory factory) {
Supplier<Predicate<Value>> predicate = createDistinctSingleValueTest(operator);
AggregateFunction function = factory.buildFunction(precompileNAryArg(aggOperator).asUnaryEvaluator(0));
return new AggregatePredicateCollectorSupplier<>(function, predicate, factory::getCollector, ge.getName());
}

private AggregatePredicateCollectorSupplier<?, ?> createNAryCustomAggregate(GroupElem ge,
AggregateOperator operator,
AggregateFunctionCall aggOperator, AggregateNAryFunctionFactory factory) {
Supplier<Predicate<List<Value>>> predicate = createDistinctTupleValueTest(operator);
var function = factory.buildFunction(precompileNAryArg(aggOperator));
return new AggregatePredicateCollectorSupplier<>(function, predicate, factory::getCollector, ge.getName());
}

private void validateNAryAggregateArity(AggregateFunctionCall aggregateOperator,
AggregateNAryFunctionFactory factory,
int argumentCount) {
int minimum = factory.getMinNumberOfArguments();
int maximum = factory.getMaxNumberOfArguments();
if (minimum < 0 || maximum < minimum) {
throw new QueryEvaluationException("Custom n-ary aggregate function '" + aggregateOperator.getIRI()
+ "' has invalid arity declaration");
}
if (argumentCount < minimum) {
throw new QueryEvaluationException(
"Custom n-ary aggregate function '" + aggregateOperator.getIRI() + "' expects at least " + minimum
+ " arguments, got " + argumentCount);
}
if (argumentCount > maximum) {
throw new QueryEvaluationException(
"Custom n-ary aggregate function '" + aggregateOperator.getIRI() + "' expects at most " + maximum
+ " arguments, got " + argumentCount);
}
}

private QueryStepEvaluator precompileUnaryArg(AggregateOperator operator) {
ValueExpr ve = ((UnaryValueOperator) operator).getArg();
return new QueryStepEvaluator(strategy.precompile(ve, context));
}

private NAryQueryStepEvaluator precompileNAryArg(AggregateOperator operator) {
List<ValueExpr> args = ((NAryValueOperator) operator).getArguments();
List<QueryValueEvaluationStep> precompiledArgs = new ArrayList<>(args.size());
for (ValueExpr arg : args) {
precompiledArgs.add(strategy.precompile(arg, context));
}
return new NAryQueryStepEvaluator(precompiledArgs::get);
}

private boolean shouldValueComparisonBeStrict() {
return strategy.getQueryEvaluationMode() == QueryEvaluationMode.STRICT;
}
Expand Down Expand Up @@ -623,6 +700,19 @@ public boolean test(Value value) {
}
}

private class DistinctTupleValues implements Predicate<List<Value>> {
private final Set<List<Value>> distinctTuples;

public DistinctTupleValues() {
distinctTuples = cf.createSet();
}

@Override
public boolean test(List<Value> valueTuple) {
return distinctTuples.add(valueTuple);
}
}

private class DistinctBindingSets implements Predicate<BindingSet> {
private final Set<BindingSet> distinctValues;

Expand Down Expand Up @@ -884,4 +974,25 @@ public Value apply(BindingSet bindings) {
}
}
}

private static class NAryQueryStepEvaluator implements BiFunction<Integer, BindingSet, Value> {
private final Function<Integer, QueryValueEvaluationStep> evaluationStepFunction;

public NAryQueryStepEvaluator(Function<Integer, QueryValueEvaluationStep> evaluationStepFunction) {
this.evaluationStepFunction = evaluationStepFunction;
}

@Override
public Value apply(Integer index, BindingSet bindings) {
try {
return evaluationStepFunction.apply(index).evaluate(bindings);
} catch (ValueExprEvaluationException e) {
return null; // treat missing or invalid expressions as null
}
}

public QueryStepEvaluator asUnaryEvaluator(Integer index) {
return new QueryStepEvaluator(evaluationStepFunction.apply(index));
}
}
}
Loading
Loading