diff --git a/src/embed.ts b/src/embed.ts index 3f99bad..35218be 100644 --- a/src/embed.ts +++ b/src/embed.ts @@ -2,7 +2,8 @@ import { Printer } from "prettier"; import { Node } from "sql-parser-cst"; import { embedJs } from "./embedJs"; import { embedJson } from "./embedJson"; +import { embedSql } from "./embedSql"; export const embed: NonNullable["embed"]> = (...args) => { - return embedJson(...args) || embedJs(...args); + return embedJson(...args) || embedJs(...args) || embedSql(...args); }; diff --git a/src/embedSql.ts b/src/embedSql.ts new file mode 100644 index 0000000..5fe67d2 --- /dev/null +++ b/src/embedSql.ts @@ -0,0 +1,67 @@ +import { Printer } from "prettier"; +import { + CreateFunctionStmt, + CreateProcedureStmt, + Node, + StringLiteral +} from "sql-parser-cst"; +import { + isAsClause, + isCreateFunctionStmt, + isCreateProcedureStmt, + isLanguageClause, + isStringLiteral, +} from "./node_utils"; +import { hardline, indent, stripTrailingHardline } from "./print_utils"; + +export const embedSql: NonNullable["embed"]> = (path, options) => { + const node = path.node; + const parent = path.getParentNode(0); + const grandParent = path.getParentNode(1); + + if ( + isStringLiteral(node) && + isAsClause(parent) && + (isCreateFunctionStmt(grandParent) || isCreateProcedureStmt(grandParent)) && + grandParent.clauses.some(isSqlLanguageClause) + ) { + return async (textToDoc) => { + let quote = detectQuote(node); + + if (!quote) { + return; + } + + if (quote === "'") { + // Convert `'` quotes to `$$` to simplify handling of strings inside the + // function. But bail out if the function contains dollar-quoted strings. + if (node.value.includes("$$")) { + return; + } + quote = "$$"; + } + + const sql = await textToDoc(node.value, options); + + return [ + quote, + indent([hardline, stripTrailingHardline(sql)]), + hardline, + quote, + ]; + }; + } + + return null; +}; + +const isSqlLanguageClause = ( + clause: CreateFunctionStmt["clauses"][0] | CreateProcedureStmt['clauses'][0], +): boolean => isLanguageClause(clause) && clause.name.name.toLowerCase() === "sql"; + +const detectQuote = ( + node: StringLiteral, +): string | undefined => { + const match = node.text.match(/^('|\$[^$]*\$)/); + return match?.[1]; +}; diff --git a/test/ddl/function.test.ts b/test/ddl/function.test.ts index b5d8abc..e44d6f6 100644 --- a/test/ddl/function.test.ts +++ b/test/ddl/function.test.ts @@ -202,6 +202,101 @@ describe("function", () => { AS " return /'''|\\"\\"\\"/.test(x) " `); }); + + it(`formats dollar-quoted SQL function`, async () => { + await testPostgresql(dedent` + CREATE FUNCTION my_func() + RETURNS INT64 + LANGUAGE sql + AS $$ + SELECT 1; + $$ + `); + }); + + it(`reformats SQL in dollar-quoted SQL function`, async () => { + expect( + await pretty( + dedent` + CREATE FUNCTION my_func() + RETURNS INT64 + LANGUAGE sql + AS $body$SELECT 1; + select 2$body$ + `, + { dialect: "postgresql" }, + ), + ).toBe(dedent` + CREATE FUNCTION my_func() + RETURNS INT64 + LANGUAGE sql + AS $body$ + SELECT 1; + SELECT 2; + $body$ + `); + }); + + it(`converts single-quoted SQL functions to dollar-quoted SQL functions`, async () => { + expect( + await pretty( + dedent` + CREATE FUNCTION my_func() + RETURNS TEXT + LANGUAGE sql + AS 'SELECT ''foo''' + `, + { dialect: "postgresql" }, + ), + ).toBe(dedent` + CREATE FUNCTION my_func() + RETURNS TEXT + LANGUAGE sql + AS $$ + SELECT 'foo'; + $$ + `); + }); + + it(`does not convert single-quoted SQL functions to dollar-quoted SQL functions when they contain dollar-quoted strings`, async () => { + expect( + await pretty( + dedent` + CREATE FUNCTION my_func() + RETURNS TEXT + LANGUAGE sql + AS 'SELECT $$foo$$' + `, + { dialect: "postgresql" }, + ), + ).toBe(dedent` + CREATE FUNCTION my_func() + RETURNS TEXT + LANGUAGE sql + AS 'SELECT $$foo$$' + `); + }); + + it(`handles SQL language identifier case-insensitively`, async () => { + expect( + await pretty( + dedent` + CREATE FUNCTION my_func() + RETURNS INT64 + LANGUAGE Sql + AS 'SELECT 1' + `, + { dialect: "postgresql" }, + ), + ).toBe(dedent` + CREATE FUNCTION my_func() + RETURNS INT64 + LANGUAGE Sql + AS $$ + SELECT 1; + $$ + `); + }); }); describe("drop function", () => { diff --git a/test/ddl/procedure.test.ts b/test/ddl/procedure.test.ts index aa60597..4b99c96 100644 --- a/test/ddl/procedure.test.ts +++ b/test/ddl/procedure.test.ts @@ -1,5 +1,5 @@ import dedent from "dedent-js"; -import { testBigquery, testPostgresql } from "../test_utils"; +import { pretty, testBigquery, testPostgresql } from "../test_utils"; describe("procedure", () => { describe("create procedure", () => { @@ -90,6 +90,92 @@ describe("procedure", () => { `, ); }); + + it(`formats dollar-quoted SQL procedure`, async () => { + await testPostgresql(dedent` + CREATE PROCEDURE my_proc() + LANGUAGE sql + AS $$ + SELECT 1; + $$ + `); + }); + + it(`reformats SQL in dollar-quoted SQL procedure`, async () => { + expect( + await pretty( + dedent` + CREATE PROCEDURE my_proc() + LANGUAGE sql + AS $body$SELECT 1; + select 2$body$ + `, + { dialect: "postgresql" }, + ), + ).toBe(dedent` + CREATE PROCEDURE my_proc() + LANGUAGE sql + AS $body$ + SELECT 1; + SELECT 2; + $body$ + `); + }); + + it(`converts single-quoted SQL procedures to dollar-quoted SQL procedures`, async () => { + expect( + await pretty( + dedent` + CREATE PROCEDURE my_proc() + LANGUAGE sql + AS 'SELECT ''foo''' + `, + { dialect: "postgresql" }, + ), + ).toBe(dedent` + CREATE PROCEDURE my_proc() + LANGUAGE sql + AS $$ + SELECT 'foo'; + $$ + `); + }); + + it(`does not convert single-quoted SQL procedures to dollar-quoted SQL procedures when they contain dollar-quoted strings`, async () => { + expect( + await pretty( + dedent` + CREATE PROCEDURE my_proc() + LANGUAGE sql + AS 'SELECT $$foo$$' + `, + { dialect: "postgresql" }, + ), + ).toBe(dedent` + CREATE PROCEDURE my_proc() + LANGUAGE sql + AS 'SELECT $$foo$$' + `); + }); + + it(`handles SQL language identifier case-insensitively`, async () => { + expect( + await pretty( + dedent` + CREATE PROCEDURE my_proc() + LANGUAGE Sql + AS 'SELECT 1' + `, + { dialect: "postgresql" }, + ), + ).toBe(dedent` + CREATE PROCEDURE my_proc() + LANGUAGE Sql + AS $$ + SELECT 1; + $$ + `); + }); }); describe("drop procedure", () => {