Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ private TupleExpr processHavingClause(ASTHavingClause havingNode, TupleExpr tupl

// create an extension linking the operator to the variable
// name.
ExtensionElem pe = new ExtensionElem(operator, alias);
ExtensionElem pe = new ExtensionElem(operator.clone(), alias);
extension.addElement(pe);

// add the aggregate operator to the group.
Expand Down Expand Up @@ -494,7 +494,7 @@ private TupleExpr processOrderClause(ASTOrderClause orderNode, TupleExpr tupleEx
// name.
String alias = var.getName();

ExtensionElem pe = new ExtensionElem(operator, alias);
ExtensionElem pe = new ExtensionElem(operator.clone(), alias);
extension.addElement(pe);

// add the aggregate operator to the group.
Expand Down Expand Up @@ -576,7 +576,7 @@ public TupleExpr visit(ASTSelect node, Object data) throws VisitorException {
Extension anonymousExtension = new Extension();
Var anonVar = createAnonVar();
expr.replaceChildNode(operator, anonVar);
anonymousExtension.addElement(new ExtensionElem(operator, anonVar.getName()));
anonymousExtension.addElement(new ExtensionElem(operator.clone(), anonVar.getName()));

anonymousExtension.setArg(result);
result = anonymousExtension;
Expand All @@ -593,7 +593,7 @@ public TupleExpr visit(ASTSelect node, Object data) throws VisitorException {
// SELECT expressions need to be captured as an extension, so that original and alias are
// available for the ORDER BY clause (which gets applied _before_ projection). See GH-4066
// and https://www.w3.org/TR/sparql11-query/#sparqlSolMod .
ExtensionElem extElem = new ExtensionElem(valueExpr, alias);
ExtensionElem extElem = new ExtensionElem(cloneIfAggregate(valueExpr), alias);
extension.addElement(extElem);
elem.setSourceExpression(extElem);
} else if (child instanceof ASTVar) {
Expand Down Expand Up @@ -663,6 +663,13 @@ public TupleExpr visit(ASTSelect node, Object data) throws VisitorException {
return result;
}

private ValueExpr cloneIfAggregate(ValueExpr valueExpr) {
if (valueExpr instanceof AggregateOperator) {
return ((AggregateOperator) valueExpr).clone();
}
return valueExpr;
}

private static boolean isIllegalCombinedWithGroupByExpression(ValueExpr expr, List<ProjectionElem> elements,
Set<String> groupNames) {
if (expr instanceof ValueConstant) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,27 @@
import static org.junit.jupiter.api.Assertions.fail;

import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;

import org.eclipse.rdf4j.model.impl.SimpleValueFactory;
import org.eclipse.rdf4j.query.algebra.AggregateOperator;
import org.eclipse.rdf4j.query.algebra.Count;
import org.eclipse.rdf4j.query.algebra.Extension;
import org.eclipse.rdf4j.query.algebra.ExtensionElem;
import org.eclipse.rdf4j.query.algebra.Filter;
import org.eclipse.rdf4j.query.algebra.Group;
import org.eclipse.rdf4j.query.algebra.GroupElem;
import org.eclipse.rdf4j.query.algebra.Order;
import org.eclipse.rdf4j.query.algebra.Projection;
import org.eclipse.rdf4j.query.algebra.Service;
import org.eclipse.rdf4j.query.algebra.SingletonSet;
import org.eclipse.rdf4j.query.algebra.Slice;
import org.eclipse.rdf4j.query.algebra.TupleExpr;
import org.eclipse.rdf4j.query.algebra.ValueExpr;
import org.eclipse.rdf4j.query.algebra.Var;
import org.eclipse.rdf4j.query.algebra.helpers.AbstractQueryModelVisitor;
import org.eclipse.rdf4j.query.parser.sparql.ast.ASTQueryContainer;
import org.eclipse.rdf4j.query.parser.sparql.ast.ASTServiceGraphPattern;
import org.eclipse.rdf4j.query.parser.sparql.ast.ASTUpdateSequence;
Expand Down Expand Up @@ -74,6 +84,113 @@ public void testSimpleAliasHandling() {
}
}

@Test
public void testAggregateProjectionParentReferences() throws Exception {
String query = "SELECT (COUNT(?s) AS ?count) WHERE { ?s ?p ?o }";

ASTQueryContainer qc = SyntaxTreeBuilder.parseQuery(query);
TupleExpr tupleExpr = builder.visit(qc, null);

assertThat(tupleExpr).isInstanceOf(Projection.class);
Projection projection = (Projection) tupleExpr;

assertThat(projection.getArg()).isInstanceOf(Extension.class);
Extension extension = (Extension) projection.getArg();

assertThat(extension.getArg()).isInstanceOf(Group.class);
Group group = (Group) extension.getArg();

assertThat(group.getGroupElements()).hasSize(1);
GroupElem groupElem = group.getGroupElements().get(0);
AggregateOperator operator = groupElem.getOperator();

assertThat(operator).isInstanceOf(Count.class);
assertThat(operator.getParentNode()).isSameAs(groupElem);

assertThat(extension.getElements()).hasSize(1);
ExtensionElem extensionElem = extension.getElements().get(0);
ValueExpr extExpr = extensionElem.getExpr();

assertThat(extExpr).isInstanceOf(Count.class);
assertThat(extExpr.getParentNode()).isSameAs(extensionElem);
assertThat(extExpr).isNotSameAs(operator);
}

@Test
public void testAggregateOrderByParentReferences() throws Exception {
String query = "SELECT (COUNT(?s) AS ?count) WHERE { ?s ?p ?o } ORDER BY (COUNT(?s))";

ASTQueryContainer qc = SyntaxTreeBuilder.parseQuery(query);
TupleExpr tupleExpr = builder.visit(qc, null);

AggregateOperatorContext context = collectAggregateOperators(tupleExpr);

assertThat(context.groupOperators).isNotEmpty();
assertThat(context.extensionOperators).hasSizeGreaterThanOrEqualTo(2);

for (AggregateOperator operator : context.groupOperators) {
assertThat(operator.getParentNode()).isInstanceOf(GroupElem.class);
}

for (AggregateOperator operator : context.extensionOperators) {
assertThat(operator.getParentNode()).isInstanceOf(ExtensionElem.class);
assertThat(context.containsSameInstanceInGroup(operator)).isFalse();
}
}

@Test
public void testAggregateHavingParentReferences() throws Exception {
String query = "SELECT (COUNT(?s) AS ?count) WHERE { ?s ?p ?o } HAVING (COUNT(?s) > 1)";

ASTQueryContainer qc = SyntaxTreeBuilder.parseQuery(query);
TupleExpr tupleExpr = builder.visit(qc, null);

AggregateOperatorContext context = collectAggregateOperators(tupleExpr);

assertThat(context.groupOperators).isNotEmpty();
assertThat(context.extensionOperators).hasSizeGreaterThanOrEqualTo(2);

for (AggregateOperator operator : context.groupOperators) {
assertThat(operator.getParentNode()).isInstanceOf(GroupElem.class);
}

for (AggregateOperator operator : context.extensionOperators) {
assertThat(operator.getParentNode()).isInstanceOf(ExtensionElem.class);
assertThat(context.containsSameInstanceInGroup(operator)).isFalse();
}

Filter filter = findNode(tupleExpr, Filter.class);
assertThat(filter).isNotNull();
}

@Test
public void testAggregateGroupConditionParentReferences() throws Exception {
String query = "SELECT (COUNT(?s) AS ?count) WHERE { ?s ?p ?o } GROUP BY (COUNT(?s) AS ?groupCount)";

ASTQueryContainer qc = SyntaxTreeBuilder.parseQuery(query);
TupleExpr tupleExpr = builder.visit(qc, null);

AggregateOperatorContext context = collectAggregateOperators(tupleExpr);
assertThat(context.groupOperators).isNotEmpty();
assertThat(context.extensionOperators).isNotEmpty();

for (AggregateOperator operator : context.groupOperators) {
assertThat(operator.getParentNode()).isInstanceOf(GroupElem.class);
}

for (AggregateOperator operator : context.extensionOperators) {
assertThat(operator.getParentNode()).isInstanceOf(ExtensionElem.class);
assertThat(context.containsSameInstanceInGroup(operator)).isFalse();
}

ExtensionElem groupAliasExtension = findExtensionElem(tupleExpr, "groupCount");
assertThat(groupAliasExtension).isNotNull();
assertThat(groupAliasExtension.getExpr()).isInstanceOf(AggregateOperator.class);
AggregateOperator groupAliasOperator = (AggregateOperator) groupAliasExtension.getExpr();
assertThat(groupAliasOperator.getParentNode()).isSameAs(groupAliasExtension);
assertThat(context.containsSameInstanceInGroup(groupAliasOperator)).isFalse();
}

@Test
public void testBindVarReuseHandling() {
String query = "SELECT * WHERE { ?s ?p ?o. BIND(<foo:bar> as ?o) }";
Expand Down Expand Up @@ -315,4 +432,78 @@ public List<String> getGraphPatterns() {
return graphPatterns;
}
}

private AggregateOperatorContext collectAggregateOperators(TupleExpr tupleExpr) {
AggregateOperatorContext context = new AggregateOperatorContext();
tupleExpr.visit(new AbstractQueryModelVisitor<RuntimeException>() {
@Override
public void meet(GroupElem node) {
AggregateOperator operator = node.getOperator();
context.groupOperators.add(operator);
context.groupIdentities.put(operator, Boolean.TRUE);
super.meet(node);
}

@Override
public void meet(ExtensionElem node) {
ValueExpr expr = node.getExpr();
if (expr instanceof AggregateOperator) {
context.extensionOperators.add((AggregateOperator) expr);
}
super.meet(node);
}
});
return context;
}

private <T> T findNode(TupleExpr tupleExpr, Class<T> type) {
class Finder extends AbstractQueryModelVisitor<RuntimeException> {
private T result;

@Override
protected void meetNode(org.eclipse.rdf4j.query.algebra.QueryModelNode node) {
if (result != null) {
return;
}
if (type.isInstance(node)) {
result = type.cast(node);
} else {
super.meetNode(node);
}
}
}

Finder finder = new Finder();
tupleExpr.visit(finder);
return finder.result;
}

private ExtensionElem findExtensionElem(TupleExpr tupleExpr, String name) {
class Finder extends AbstractQueryModelVisitor<RuntimeException> {
private ExtensionElem result;

@Override
public void meet(ExtensionElem node) {
if (result == null && name.equals(node.getName())) {
result = node;
return;
}
super.meet(node);
}
}

Finder finder = new Finder();
tupleExpr.visit(finder);
return finder.result;
}

private static final class AggregateOperatorContext {
private final List<AggregateOperator> groupOperators = new ArrayList<>();
private final Map<AggregateOperator, Boolean> groupIdentities = new IdentityHashMap<>();
private final List<AggregateOperator> extensionOperators = new ArrayList<>();

boolean containsSameInstanceInGroup(AggregateOperator operator) {
return groupIdentities.containsKey(operator);
}
}
}
Loading