From 1b205a9169fbeefacbd898fa483e83c7dd771f33 Mon Sep 17 00:00:00 2001 From: Nathan Mugerwa Date: Wed, 7 Apr 2021 18:25:36 -0400 Subject: [PATCH] Add COUNT DISTINCT Rewriter --- .../rewriter/CountDistinctRewriter.java | 122 ++++++++++++++++++ .../rewriter/TestCountDistinctRewriter.java | 104 +++++++++++++++ 2 files changed, 226 insertions(+) create mode 100644 rewriter/src/main/java/com/facebook/coresql/rewriter/CountDistinctRewriter.java create mode 100644 rewriter/src/test/java/com/facebook/coresql/rewriter/TestCountDistinctRewriter.java diff --git a/rewriter/src/main/java/com/facebook/coresql/rewriter/CountDistinctRewriter.java b/rewriter/src/main/java/com/facebook/coresql/rewriter/CountDistinctRewriter.java new file mode 100644 index 0000000..759eed6 --- /dev/null +++ b/rewriter/src/main/java/com/facebook/coresql/rewriter/CountDistinctRewriter.java @@ -0,0 +1,122 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.coresql.rewriter; + +import com.facebook.coresql.parser.AggregationFunction; +import com.facebook.coresql.parser.AstNode; +import com.facebook.coresql.parser.Comparison; +import com.facebook.coresql.parser.SqlParserDefaultVisitor; +import com.google.common.collect.ImmutableSet; + +import java.util.Optional; +import java.util.Set; + +import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTAGGREGATIONFUNCTION; +import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTIDENTIFIER; +import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTSETQUANTIFIER; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class CountDistinctRewriter + extends Rewriter +{ + private final AstNode root; + private final Set patternMatchedNodes; + private static final String REPLACEMENT = "MIN(%s) IS DISTINCT FROM MAX(%s)"; + private static final String REWRITE_NAME = "Duplicate Check Using COUNT(DISTINCT x) > 1"; + + public CountDistinctRewriter(AstNode root) + { + this.root = requireNonNull(root, "AST passed to rewriter was null"); + this.patternMatchedNodes = new CountDistinctPatternMatcher(root).matchPattern(); + } + + @Override + public Optional rewrite() + { + if (patternMatchedNodes.isEmpty()) { + return Optional.empty(); + } + String rewrittenSql = unparse(root, this); + return Optional.of(new RewriteResult(REWRITE_NAME, rewrittenSql)); + } + + @Override + public void visit(Comparison node, Void data) + { + if (patternMatchedNodes.contains(node)) { + applyCountDistinctRewrite(node); + } + else { + defaultVisit(node, data); + } + } + + private void applyCountDistinctRewrite(Comparison node) + { + // First, unparse up to the node. This ensures we don't miss any special tokens + unparseUpto(node); + // Then, add the rewritten version to the Unparser + String identifier = unparse(node.GetFirstChildOfKind(JJTAGGREGATIONFUNCTION).GetFirstChildOfKind(JJTIDENTIFIER)).trim(); + printToken(format(REPLACEMENT, identifier, identifier)); + // Move to end of this node -- we've already put in a rewritten version of it, so we don't need to unparse it + moveToEndOfNode(node); + } + + private static class CountDistinctPatternMatcher + extends SqlParserDefaultVisitor + { + private final AstNode root; + private final ImmutableSet.Builder builder = ImmutableSet.builder(); + + public CountDistinctPatternMatcher(AstNode root) + { + this.root = requireNonNull(root, "AST passed to pattern matcher was null"); + } + + public Set matchPattern() + { + root.jjtAccept(this, null); + return builder.build(); + } + + private boolean secondArgIsLiteralOne(Comparison node) + { + Optional secondArg = Optional.ofNullable((AstNode) node.jjtGetChild(1)); + return secondArg.isPresent() && unparse(secondArg.get()).trim().equals("1"); + } + + private boolean aggregationHasCountDistinct(AggregationFunction node) + { + Optional setQuantifier = Optional.ofNullable(node.GetFirstChildOfKind(JJTSETQUANTIFIER)); + return setQuantifier.isPresent() && node.beginToken.image.equalsIgnoreCase("COUNT") && unparse(setQuantifier.get()).equalsIgnoreCase("DISTINCT"); + } + + private boolean isUsingCountDistinctComparisonToCheckUniqueness(Comparison node) + { + Optional aggregationFunction = Optional.ofNullable(node.GetFirstChildOfKind(JJTAGGREGATIONFUNCTION)); + return aggregationFunction.isPresent() && aggregationHasCountDistinct((AggregationFunction) aggregationFunction.get()) && secondArgIsLiteralOne(node); + } + + @Override + public void visit(Comparison node, Void data) + { + if (isUsingCountDistinctComparisonToCheckUniqueness(node)) { + builder.add(node); + } + defaultVisit(node, data); + } + } +} diff --git a/rewriter/src/test/java/com/facebook/coresql/rewriter/TestCountDistinctRewriter.java b/rewriter/src/test/java/com/facebook/coresql/rewriter/TestCountDistinctRewriter.java new file mode 100644 index 0000000..7ef2481 --- /dev/null +++ b/rewriter/src/test/java/com/facebook/coresql/rewriter/TestCountDistinctRewriter.java @@ -0,0 +1,104 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.coresql.rewriter; + +import com.facebook.coresql.parser.AstNode; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.coresql.parser.ParserHelper.parseStatement; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +public class TestCountDistinctRewriter +{ + private static final String[] STATEMENT_THAT_DOESNT_NEED_REWRITE = new String[] { + "CREATE TABLE blah AS SELECT * FROM (SELECT * FROM (SELECT foo FROM T ORDER BY x LIMIT 10) ORDER BY y LIMIT 10) ORDER BY z LIMIT 10;", + "SELECT dealer_id, sales OVER (PARTITION BY dealer_id ORDER BY sales);", + "INSERT INTO blah SELECT * FROM (SELECT t.date, t.code, t.qty FROM sales AS t ORDER BY t.date LIMIT 100);", + "SELECT (true or false) and false;", + "SELECT * FROM T ORDER BY y;", + "SELECT * FROM T ORDER BY y LIMIT 10;", + "use a.b;", + " SELECT 1;", + "SELECT a FROM T;", + "SELECT a FROM T WHERE p1 > p2;", + "SELECT a, b, c FROM T WHERE c1 < c2 and c3 < c4;", + "SELECT CASE a WHEN IN ( 1 ) THEN b ELSE c END AS x, b, c FROM T WHERE c1 < c2 and c3 < c4;", + "SELECT T.* FROM T JOIN W ON T.x = W.x;", + "SELECT NULL;", + "SELECT ARRAY[x] FROM T;", + "SELECT TRANSFORM(ARRAY[x], x -> x + 2) AS arra FROM T;", + "CREATE TABLE T AS SELECT TRANSFORM(ARRAY[x], x -> x + 2) AS arra FROM T;", + "INSERT INTO T SELECT TRANSFORM(ARRAY[x], x -> x + 2) AS arra FROM T;", + "SELECT ROW_NUMBER() OVER(PARTITION BY x) FROM T;", + "SELECT x, SUM(y) OVER (PARTITION BY y ORDER BY 1) AS min\n" + + "FROM (values ('b',10), ('a', 10)) AS T(x, y)\n;", + "SELECT\n" + + " CAST(MAP() AS map>) AS \"bool_tensor_features\";", + "SELECT f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f())))))))))))))))))))))))))))));", + "SELECT abs, 2 as abs;", + // False Positive + "SELECT COUNT(DISTINCT x) > 1.5 FROM T;", + "SELECT COUNT(x) > 1 FROM T;" + }; + + private static final String[] STATEMENT_BEFORE_REWRITE = new String[] { + // True Positive + "SELECT COUNT(DISTINCT x) > 1 AS has_dupes FROM T;", + // Subquery + "SELECT * FROM (SELECT COUNT(DISTINCT x) > 1 AS has_dupes FROM T);" + }; + + private static final String[] STATEMENT_AFTER_REWRITE = new String[] { + // False Positive + "SELECT MIN(x) IS DISTINCT FROM MAX(x) AS has_dupes FROM T;", + "SELECT * FROM (SELECT MIN(x) IS DISTINCT FROM MAX(x) AS has_dupes FROM T);" + }; + + private void assertStatementUnchanged(String originalStatement) + { + assertFalse(getRewriteResult(originalStatement).isPresent()); + } + + private void assertStatementRewritten(String originalStatement, String expectedStatement) + { + Optional result = getRewriteResult(originalStatement); + assertTrue(result.isPresent()); + assertEquals(result.get().getRewrittenSql(), expectedStatement); + } + + private Optional getRewriteResult(String originalStatement) + { + AstNode ast = parseStatement(originalStatement); + assertNotNull(ast); + return new CountDistinctRewriter(ast).rewrite(); + } + + @Test + public void rewriteTest() + { + for (int i = 0; i < STATEMENT_BEFORE_REWRITE.length; i++) { + assertStatementRewritten(STATEMENT_BEFORE_REWRITE[i], STATEMENT_AFTER_REWRITE[i]); + } + + for (String sql : STATEMENT_THAT_DOESNT_NEED_REWRITE) { + assertStatementUnchanged(sql); + } + } +}