diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py index 1a42480c13..f5464e12bc 100644 --- a/sqlmesh/core/dialect.py +++ b/sqlmesh/core/dialect.py @@ -13,6 +13,7 @@ from sqlglot import Dialect, Generator, ParseError, Parser, Tokenizer, TokenType, exp from sqlglot.dialects.dialect import DialectType from sqlglot.dialects import DuckDB, Snowflake +import sqlglot.dialects.athena as athena from sqlglot.helper import seq_get from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlglot.optimizer.qualify_columns import quote_identifiers @@ -1014,6 +1015,14 @@ def extend_sqlglot() -> None: generators = {Generator} for dialect in Dialect.classes.values(): + # Athena picks a different Tokenizer / Parser / Generator depending on the query + # so this ensures that the extra ones it defines are also extended + if dialect == athena.Athena: + tokenizers.add(athena._TrinoTokenizer) + parsers.add(athena._TrinoParser) + generators.add(athena._TrinoGenerator) + generators.add(athena._HiveGenerator) + if hasattr(dialect, "Tokenizer"): tokenizers.add(dialect.Tokenizer) if hasattr(dialect, "Parser"): diff --git a/tests/core/test_dialect.py b/tests/core/test_dialect.py index ebf90bebf7..11ffec3720 100644 --- a/tests/core/test_dialect.py +++ b/tests/core/test_dialect.py @@ -12,7 +12,9 @@ select_from_values_for_batch_range, text_diff, ) +import sqlmesh.core.dialect as d from sqlmesh.core.model import SqlModel, load_sql_based_model +from sqlmesh.core.config.connection import DIALECT_TO_TYPE def test_format_model_expressions(): @@ -700,3 +702,18 @@ def test_model_name_cannot_be_string(): def test_parse_snowflake_create_schema_ddl(): assert parse_one("CREATE SCHEMA d.s", dialect="snowflake").sql() == "CREATE SCHEMA d.s" + + +@pytest.mark.parametrize("dialect", sorted(set(DIALECT_TO_TYPE.values()))) +def test_sqlglot_extended_correctly(dialect: str) -> None: + # MODEL is a SQLMesh extension and not part of SQLGlot + # If we can roundtrip an expression containing MODEL across every dialect, then the SQLMesh extensions have been registered correctly + ast = d.parse_one("MODEL (name foo)", dialect=dialect) + assert isinstance(ast, d.Model) + name_prop = ast.find(exp.Property) + assert isinstance(name_prop, exp.Property) + assert name_prop.this == "name" + value = name_prop.args["value"] + assert isinstance(value, exp.Table) + assert value.sql() == "foo" + assert ast.sql(dialect=dialect) == "MODEL (\nname foo\n)"