diff --git a/pom.xml b/pom.xml index fad61bc..6fc29f3 100644 --- a/pom.xml +++ b/pom.xml @@ -19,6 +19,7 @@ ${project.parent.basedir} + 0.198 @@ -27,6 +28,28 @@ rewriter + + + + com.facebook.airlift + configuration + ${dep.airlift.version} + + + + com.facebook.airlift + log + ${dep.airlift.version} + + + + com.facebook.airlift + bootstrap + ${dep.airlift.version} + + + + diff --git a/rewriter/pom.xml b/rewriter/pom.xml index c04330c..c8ada2e 100644 --- a/rewriter/pom.xml +++ b/rewriter/pom.xml @@ -46,6 +46,32 @@ com.google.guava guava + + + javax.validation + validation-api + 2.0.1.Final + + + + com.facebook.airlift + configuration + + + + com.facebook.airlift + log + + + + com.facebook.airlift + bootstrap + + + + com.google.inject + guice + diff --git a/rewriter/src/main/java/com/facebook/coresql/rewriter/RewriteDriver.java b/rewriter/src/main/java/com/facebook/coresql/rewriter/RewriteDriver.java new file mode 100644 index 0000000..a66914d --- /dev/null +++ b/rewriter/src/main/java/com/facebook/coresql/rewriter/RewriteDriver.java @@ -0,0 +1,69 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.coresql.rewriter; + +import com.facebook.airlift.bootstrap.Bootstrap; +import com.facebook.airlift.log.Logger; +import com.facebook.coresql.parser.AstNode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import java.util.List; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.coresql.parser.ParserHelper.parseStatement; +import static java.util.Objects.requireNonNull; + +public class RewriteDriver +{ + private static final Logger LOG = Logger.get(RewriteDriver.class); + private static final Set> KNOWN_REWRITERS = ImmutableSet.of(OrderByRewriter.class); + private final Set> userEnabledRewriters; + private AstNode root; + + public RewriteDriver(RewriteDriverConfig config, AstNode root) + { + this.root = requireNonNull(root, "AST given to driver was null"); + this.userEnabledRewriters = config.getUserEnabledRewriters(); + new Bootstrap(new RewriteDriverModule()).doNotInitializeLogging().quiet().initialize(); + } + + public Optional applyRewriters() + { + ImmutableList.Builder builder = ImmutableList.builder(); + for (Class rewriter : userEnabledRewriters) { + if (!KNOWN_REWRITERS.contains(rewriter)) { + LOG.error("An unknown rewriter was passed to rewrite driver: %s", rewriter.getName()); + return Optional.empty(); + } + Rewriter rewriterInstance; + try { + rewriterInstance = rewriter.getConstructor(AstNode.class).newInstance(root); + } + catch (Exception e) { + LOG.error(e, "Exception thrown while creating an instance of this rewriter: %s", rewriter.getName()); + continue; + } + Optional result = rewriterInstance.rewrite(); + if (result.isPresent()) { + builder.add(result.get()); + root = parseStatement(result.get().getRewrittenSql()); + } + } + List results = builder.build(); + return results.isEmpty() ? Optional.empty() : Optional.of(new RewriteDriverResult(results, root)); + } +} diff --git a/rewriter/src/main/java/com/facebook/coresql/rewriter/RewriteDriverConfig.java b/rewriter/src/main/java/com/facebook/coresql/rewriter/RewriteDriverConfig.java new file mode 100644 index 0000000..2db396e --- /dev/null +++ b/rewriter/src/main/java/com/facebook/coresql/rewriter/RewriteDriverConfig.java @@ -0,0 +1,42 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.coresql.rewriter; + +import com.facebook.airlift.configuration.Config; +import com.google.common.collect.ImmutableSet; + +import javax.validation.constraints.NotNull; + +import java.util.Set; + +import static java.util.Collections.emptySet; + +public class RewriteDriverConfig +{ + private Set> userEnabledRewriters = emptySet(); + + @Config("user-enabled-rewriters") + public RewriteDriverConfig setUserEnabledRewriters(Set> userEnabledRewriters) + { + this.userEnabledRewriters = ImmutableSet.copyOf(userEnabledRewriters); + return this; + } + + @NotNull + public Set> getUserEnabledRewriters() + { + return userEnabledRewriters; + } +} diff --git a/rewriter/src/main/java/com/facebook/coresql/rewriter/RewriteDriverModule.java b/rewriter/src/main/java/com/facebook/coresql/rewriter/RewriteDriverModule.java new file mode 100644 index 0000000..47345d1 --- /dev/null +++ b/rewriter/src/main/java/com/facebook/coresql/rewriter/RewriteDriverModule.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.coresql.rewriter; + +import com.google.inject.Binder; +import com.google.inject.Module; + +import static com.facebook.airlift.configuration.ConfigBinder.configBinder; + +public class RewriteDriverModule + implements Module +{ + @Override + public void configure(Binder binder) + { + configBinder(binder).bindConfig(RewriteDriverConfig.class); + } +} diff --git a/rewriter/src/main/java/com/facebook/coresql/rewriter/RewriteDriverResult.java b/rewriter/src/main/java/com/facebook/coresql/rewriter/RewriteDriverResult.java new file mode 100644 index 0000000..28bbf6d --- /dev/null +++ b/rewriter/src/main/java/com/facebook/coresql/rewriter/RewriteDriverResult.java @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.coresql.rewriter; + +import com.facebook.coresql.parser.AstNode; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public class RewriteDriverResult +{ + private final List rewriteResults; + private final AstNode rewrittenSqlAst; + + public RewriteDriverResult(List rewriteResults, AstNode rewrittenSqlAsAst) + { + this.rewriteResults = requireNonNull(rewriteResults, "list of rewrite results is null"); + this.rewrittenSqlAst = requireNonNull(rewrittenSqlAsAst, "rewritten sql ast is null"); + } + + public List getRewriteResults() + { + return ImmutableList.copyOf(rewriteResults); + } + + public AstNode getRewrittenSqlAsAst() + { + return rewrittenSqlAst; + } +} diff --git a/rewriter/src/test/java/com/facebook/coresql/rewriter/TestRewriteDriver.java b/rewriter/src/test/java/com/facebook/coresql/rewriter/TestRewriteDriver.java new file mode 100644 index 0000000..7e486d8 --- /dev/null +++ b/rewriter/src/test/java/com/facebook/coresql/rewriter/TestRewriteDriver.java @@ -0,0 +1,127 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.coresql.rewriter; + +import com.facebook.coresql.parser.AstNode; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.coresql.parser.ParserHelper.parseStatement; +import static com.facebook.coresql.parser.Unparser.unparse; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +public class TestRewriteDriver +{ + private static final String[] STATEMENT_THAT_DOESNT_NEED_REWRITE = new String[] { + "CREATE TABLE blah AS SELECT * FROM (SELECT * FROM (SELECT foo FROM T ORDER BY x LIMIT 10) ORDER BY y LIMIT 10) ORDER BY z LIMIT 10;", + "SELECT dealer_id, sales OVER (PARTITION BY dealer_id ORDER BY sales);", + "INSERT INTO blah SELECT * FROM (SELECT t.date, t.code, t.qty FROM sales AS t ORDER BY t.date LIMIT 100);", + "SELECT (true or false) and false;", + "SELECT * FROM T ORDER BY y;", + "SELECT * FROM T ORDER BY y LIMIT 10;", + "use a.b;", + "SELECT 1;", + "SELECT a FROM T;", + "SELECT a FROM T WHERE p1 > p2;", + "SELECT a, b, c FROM T WHERE c1 < c2 and c3 < c4;", + "SELECT CASE a WHEN IN ( 1 ) THEN b ELSE c END AS x, b, c FROM T WHERE c1 < c2 and c3 < c4;", + "SELECT T.* FROM T JOIN W ON T.x = W.x;", + "SELECT NULL;", + "SELECT ARRAY[x] FROM T;", + "SELECT TRANSFORM(ARRAY[x], x -> x + 2) AS arra FROM T;", + "CREATE TABLE T AS SELECT TRANSFORM(ARRAY[x], x -> x + 2) AS arra FROM T;", + "INSERT INTO T SELECT TRANSFORM(ARRAY[x], x -> x + 2) AS arra FROM T;", + "SELECT ROW_NUMBER() OVER(PARTITION BY x) FROM T;", + "SELECT x, SUM(y) OVER (PARTITION BY y ORDER BY 1) AS min\n" + + "FROM (values ('b',10), ('a', 10)) AS T(x, y)\n;", + "SELECT\n" + + " CAST(MAP() AS map>) AS \"bool_tensor_features\";", + "SELECT f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f())))))))))))))))))))))))))))));", + "SELECT abs, 2 as abs;" + }; + + private static final String[] STATEMENT_BEFORE_REWRITE = new String[] { + // ORDER BY Anti-Pattern + "CREATE TABLE blah AS SELECT * FROM T ORDER BY y;", + "INSERT INTO blah SELECT * FROM (SELECT t.date, t.code, t.qty FROM sales AS t ORDER BY t.date) LIMIT 10;", + "CREATE TABLE blah AS SELECT * FROM (SELECT * FROM (SELECT foo FROM T ORDER BY x LIMIT 10) ORDER BY y) ORDER BY z LIMIT 10;", + }; + + private static final String[] STATEMENT_AFTER_REWRITE = new String[] { + // ORDER BY Anti-Pattern + "CREATE TABLE blah AS SELECT * FROM T;", + "INSERT INTO blah SELECT * FROM (SELECT t.date, t.code, t.qty FROM sales AS t) LIMIT 10;", + "CREATE TABLE blah AS SELECT * FROM (SELECT * FROM (SELECT foo FROM T ORDER BY x LIMIT 10)) ORDER BY z LIMIT 10;", + }; + + private static final RewriteDriverConfig USE_ALL_REWRITERS_CONFIG = new RewriteDriverConfig().setUserEnabledRewriters(ImmutableSet.of(OrderByRewriter.class)); + private static final int EXPECTED_SIZE_OF_REWRITE_RESULT_LIST = 1; + + private void assertStatementUnchanged(String originalStatement) + { + Optional result = getRewriteDriverResult(originalStatement, USE_ALL_REWRITERS_CONFIG); + assertFalse(result.isPresent()); + } + + private void assertStatementRewritten(String originalStatement, String expectedStatement) + { + Optional result = getRewriteDriverResult(originalStatement, USE_ALL_REWRITERS_CONFIG); + assertTrue(result.isPresent()); + assertEquals(result.get().getRewriteResults().size(), EXPECTED_SIZE_OF_REWRITE_RESULT_LIST); + AstNode rewrittenAst = result.get().getRewrittenSqlAsAst(); + String actualStatement = unparse(rewrittenAst).trim(); + assertEquals(actualStatement, expectedStatement); + } + + private Optional getRewriteDriverResult(String originalStatement, RewriteDriverConfig config) + { + AstNode ast = parseStatement(originalStatement); + assertNotNull(ast); + return new RewriteDriver(config, ast).applyRewriters(); + } + + @Test + public void applyAllRewritersTest() + { + for (int i = 0; i < STATEMENT_BEFORE_REWRITE.length; i++) { + assertStatementRewritten(STATEMENT_BEFORE_REWRITE[i], STATEMENT_AFTER_REWRITE[i]); + } + + for (String sql : STATEMENT_THAT_DOESNT_NEED_REWRITE) { + assertStatementUnchanged(sql); + } + } + + @Test + public void applyUnknownRewriterTest() + { + Rewriter unknownRewriter = new Rewriter() + { + @Override + public Optional rewrite() + { + return Optional.empty(); + } + }; + RewriteDriverConfig invalidConfig = new RewriteDriverConfig().setUserEnabledRewriters(ImmutableSet.of(unknownRewriter.getClass())); + Optional rewriteResult = getRewriteDriverResult(STATEMENT_BEFORE_REWRITE[0], invalidConfig); + assertFalse(rewriteResult.isPresent()); + } +}