diff --git a/core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/TupleExprBuilder.java b/core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/TupleExprBuilder.java index a52d72d8871..0a4a50a5ea7 100644 --- a/core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/TupleExprBuilder.java +++ b/core/queryparser/sparql/src/main/java/org/eclipse/rdf4j/query/parser/sparql/TupleExprBuilder.java @@ -449,7 +449,7 @@ private TupleExpr processHavingClause(ASTHavingClause havingNode, TupleExpr tupl // create an extension linking the operator to the variable // name. - ExtensionElem pe = new ExtensionElem(operator, alias); + ExtensionElem pe = new ExtensionElem(operator.clone(), alias); extension.addElement(pe); // add the aggregate operator to the group. @@ -494,7 +494,7 @@ private TupleExpr processOrderClause(ASTOrderClause orderNode, TupleExpr tupleEx // name. String alias = var.getName(); - ExtensionElem pe = new ExtensionElem(operator, alias); + ExtensionElem pe = new ExtensionElem(operator.clone(), alias); extension.addElement(pe); // add the aggregate operator to the group. @@ -576,7 +576,7 @@ public TupleExpr visit(ASTSelect node, Object data) throws VisitorException { Extension anonymousExtension = new Extension(); Var anonVar = createAnonVar(); expr.replaceChildNode(operator, anonVar); - anonymousExtension.addElement(new ExtensionElem(operator, anonVar.getName())); + anonymousExtension.addElement(new ExtensionElem(operator.clone(), anonVar.getName())); anonymousExtension.setArg(result); result = anonymousExtension; @@ -593,7 +593,7 @@ public TupleExpr visit(ASTSelect node, Object data) throws VisitorException { // SELECT expressions need to be captured as an extension, so that original and alias are // available for the ORDER BY clause (which gets applied _before_ projection). See GH-4066 // and https://www.w3.org/TR/sparql11-query/#sparqlSolMod . - ExtensionElem extElem = new ExtensionElem(valueExpr, alias); + ExtensionElem extElem = new ExtensionElem(cloneIfAggregate(valueExpr), alias); extension.addElement(extElem); elem.setSourceExpression(extElem); } else if (child instanceof ASTVar) { @@ -663,6 +663,13 @@ public TupleExpr visit(ASTSelect node, Object data) throws VisitorException { return result; } + private ValueExpr cloneIfAggregate(ValueExpr valueExpr) { + if (valueExpr instanceof AggregateOperator) { + return ((AggregateOperator) valueExpr).clone(); + } + return valueExpr; + } + private static boolean isIllegalCombinedWithGroupByExpression(ValueExpr expr, List elements, Set groupNames) { if (expr instanceof ValueConstant) { diff --git a/core/queryparser/sparql/src/test/java/org/eclipse/rdf4j/query/parser/sparql/TupleExprBuilderTest.java b/core/queryparser/sparql/src/test/java/org/eclipse/rdf4j/query/parser/sparql/TupleExprBuilderTest.java index 0ebea2524dc..b40b808fd57 100644 --- a/core/queryparser/sparql/src/test/java/org/eclipse/rdf4j/query/parser/sparql/TupleExprBuilderTest.java +++ b/core/queryparser/sparql/src/test/java/org/eclipse/rdf4j/query/parser/sparql/TupleExprBuilderTest.java @@ -17,17 +17,27 @@ import static org.junit.jupiter.api.Assertions.fail; import java.util.ArrayList; +import java.util.IdentityHashMap; import java.util.List; +import java.util.Map; import org.eclipse.rdf4j.model.impl.SimpleValueFactory; +import org.eclipse.rdf4j.query.algebra.AggregateOperator; +import org.eclipse.rdf4j.query.algebra.Count; import org.eclipse.rdf4j.query.algebra.Extension; +import org.eclipse.rdf4j.query.algebra.ExtensionElem; +import org.eclipse.rdf4j.query.algebra.Filter; +import org.eclipse.rdf4j.query.algebra.Group; +import org.eclipse.rdf4j.query.algebra.GroupElem; import org.eclipse.rdf4j.query.algebra.Order; import org.eclipse.rdf4j.query.algebra.Projection; import org.eclipse.rdf4j.query.algebra.Service; import org.eclipse.rdf4j.query.algebra.SingletonSet; import org.eclipse.rdf4j.query.algebra.Slice; import org.eclipse.rdf4j.query.algebra.TupleExpr; +import org.eclipse.rdf4j.query.algebra.ValueExpr; import org.eclipse.rdf4j.query.algebra.Var; +import org.eclipse.rdf4j.query.algebra.helpers.AbstractQueryModelVisitor; import org.eclipse.rdf4j.query.parser.sparql.ast.ASTQueryContainer; import org.eclipse.rdf4j.query.parser.sparql.ast.ASTServiceGraphPattern; import org.eclipse.rdf4j.query.parser.sparql.ast.ASTUpdateSequence; @@ -74,6 +84,113 @@ public void testSimpleAliasHandling() { } } + @Test + public void testAggregateProjectionParentReferences() throws Exception { + String query = "SELECT (COUNT(?s) AS ?count) WHERE { ?s ?p ?o }"; + + ASTQueryContainer qc = SyntaxTreeBuilder.parseQuery(query); + TupleExpr tupleExpr = builder.visit(qc, null); + + assertThat(tupleExpr).isInstanceOf(Projection.class); + Projection projection = (Projection) tupleExpr; + + assertThat(projection.getArg()).isInstanceOf(Extension.class); + Extension extension = (Extension) projection.getArg(); + + assertThat(extension.getArg()).isInstanceOf(Group.class); + Group group = (Group) extension.getArg(); + + assertThat(group.getGroupElements()).hasSize(1); + GroupElem groupElem = group.getGroupElements().get(0); + AggregateOperator operator = groupElem.getOperator(); + + assertThat(operator).isInstanceOf(Count.class); + assertThat(operator.getParentNode()).isSameAs(groupElem); + + assertThat(extension.getElements()).hasSize(1); + ExtensionElem extensionElem = extension.getElements().get(0); + ValueExpr extExpr = extensionElem.getExpr(); + + assertThat(extExpr).isInstanceOf(Count.class); + assertThat(extExpr.getParentNode()).isSameAs(extensionElem); + assertThat(extExpr).isNotSameAs(operator); + } + + @Test + public void testAggregateOrderByParentReferences() throws Exception { + String query = "SELECT (COUNT(?s) AS ?count) WHERE { ?s ?p ?o } ORDER BY (COUNT(?s))"; + + ASTQueryContainer qc = SyntaxTreeBuilder.parseQuery(query); + TupleExpr tupleExpr = builder.visit(qc, null); + + AggregateOperatorContext context = collectAggregateOperators(tupleExpr); + + assertThat(context.groupOperators).isNotEmpty(); + assertThat(context.extensionOperators).hasSizeGreaterThanOrEqualTo(2); + + for (AggregateOperator operator : context.groupOperators) { + assertThat(operator.getParentNode()).isInstanceOf(GroupElem.class); + } + + for (AggregateOperator operator : context.extensionOperators) { + assertThat(operator.getParentNode()).isInstanceOf(ExtensionElem.class); + assertThat(context.containsSameInstanceInGroup(operator)).isFalse(); + } + } + + @Test + public void testAggregateHavingParentReferences() throws Exception { + String query = "SELECT (COUNT(?s) AS ?count) WHERE { ?s ?p ?o } HAVING (COUNT(?s) > 1)"; + + ASTQueryContainer qc = SyntaxTreeBuilder.parseQuery(query); + TupleExpr tupleExpr = builder.visit(qc, null); + + AggregateOperatorContext context = collectAggregateOperators(tupleExpr); + + assertThat(context.groupOperators).isNotEmpty(); + assertThat(context.extensionOperators).hasSizeGreaterThanOrEqualTo(2); + + for (AggregateOperator operator : context.groupOperators) { + assertThat(operator.getParentNode()).isInstanceOf(GroupElem.class); + } + + for (AggregateOperator operator : context.extensionOperators) { + assertThat(operator.getParentNode()).isInstanceOf(ExtensionElem.class); + assertThat(context.containsSameInstanceInGroup(operator)).isFalse(); + } + + Filter filter = findNode(tupleExpr, Filter.class); + assertThat(filter).isNotNull(); + } + + @Test + public void testAggregateGroupConditionParentReferences() throws Exception { + String query = "SELECT (COUNT(?s) AS ?count) WHERE { ?s ?p ?o } GROUP BY (COUNT(?s) AS ?groupCount)"; + + ASTQueryContainer qc = SyntaxTreeBuilder.parseQuery(query); + TupleExpr tupleExpr = builder.visit(qc, null); + + AggregateOperatorContext context = collectAggregateOperators(tupleExpr); + assertThat(context.groupOperators).isNotEmpty(); + assertThat(context.extensionOperators).isNotEmpty(); + + for (AggregateOperator operator : context.groupOperators) { + assertThat(operator.getParentNode()).isInstanceOf(GroupElem.class); + } + + for (AggregateOperator operator : context.extensionOperators) { + assertThat(operator.getParentNode()).isInstanceOf(ExtensionElem.class); + assertThat(context.containsSameInstanceInGroup(operator)).isFalse(); + } + + ExtensionElem groupAliasExtension = findExtensionElem(tupleExpr, "groupCount"); + assertThat(groupAliasExtension).isNotNull(); + assertThat(groupAliasExtension.getExpr()).isInstanceOf(AggregateOperator.class); + AggregateOperator groupAliasOperator = (AggregateOperator) groupAliasExtension.getExpr(); + assertThat(groupAliasOperator.getParentNode()).isSameAs(groupAliasExtension); + assertThat(context.containsSameInstanceInGroup(groupAliasOperator)).isFalse(); + } + @Test public void testBindVarReuseHandling() { String query = "SELECT * WHERE { ?s ?p ?o. BIND( as ?o) }"; @@ -315,4 +432,78 @@ public List getGraphPatterns() { return graphPatterns; } } + + private AggregateOperatorContext collectAggregateOperators(TupleExpr tupleExpr) { + AggregateOperatorContext context = new AggregateOperatorContext(); + tupleExpr.visit(new AbstractQueryModelVisitor() { + @Override + public void meet(GroupElem node) { + AggregateOperator operator = node.getOperator(); + context.groupOperators.add(operator); + context.groupIdentities.put(operator, Boolean.TRUE); + super.meet(node); + } + + @Override + public void meet(ExtensionElem node) { + ValueExpr expr = node.getExpr(); + if (expr instanceof AggregateOperator) { + context.extensionOperators.add((AggregateOperator) expr); + } + super.meet(node); + } + }); + return context; + } + + private T findNode(TupleExpr tupleExpr, Class type) { + class Finder extends AbstractQueryModelVisitor { + private T result; + + @Override + protected void meetNode(org.eclipse.rdf4j.query.algebra.QueryModelNode node) { + if (result != null) { + return; + } + if (type.isInstance(node)) { + result = type.cast(node); + } else { + super.meetNode(node); + } + } + } + + Finder finder = new Finder(); + tupleExpr.visit(finder); + return finder.result; + } + + private ExtensionElem findExtensionElem(TupleExpr tupleExpr, String name) { + class Finder extends AbstractQueryModelVisitor { + private ExtensionElem result; + + @Override + public void meet(ExtensionElem node) { + if (result == null && name.equals(node.getName())) { + result = node; + return; + } + super.meet(node); + } + } + + Finder finder = new Finder(); + tupleExpr.visit(finder); + return finder.result; + } + + private static final class AggregateOperatorContext { + private final List groupOperators = new ArrayList<>(); + private final Map groupIdentities = new IdentityHashMap<>(); + private final List extensionOperators = new ArrayList<>(); + + boolean containsSameInstanceInGroup(AggregateOperator operator) { + return groupIdentities.containsKey(operator); + } + } }