Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions sqlmesh/core/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sqlglot import Dialect, Generator, ParseError, Parser, Tokenizer, TokenType, exp
from sqlglot.dialects.dialect import DialectType
from sqlglot.dialects import DuckDB, Snowflake
import sqlglot.dialects.athena as athena
from sqlglot.helper import seq_get
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify_columns import quote_identifiers
Expand Down Expand Up @@ -1014,6 +1015,14 @@ def extend_sqlglot() -> None:
generators = {Generator}

for dialect in Dialect.classes.values():
# Athena picks a different Tokenizer / Parser / Generator depending on the query
# so this ensures that the extra ones it defines are also extended
if dialect == athena.Athena:
tokenizers.add(athena._TrinoTokenizer)
parsers.add(athena._TrinoParser)
generators.add(athena._TrinoGenerator)
generators.add(athena._HiveGenerator)

if hasattr(dialect, "Tokenizer"):
tokenizers.add(dialect.Tokenizer)
if hasattr(dialect, "Parser"):
Expand Down
17 changes: 17 additions & 0 deletions tests/core/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
select_from_values_for_batch_range,
text_diff,
)
import sqlmesh.core.dialect as d
from sqlmesh.core.model import SqlModel, load_sql_based_model
from sqlmesh.core.config.connection import DIALECT_TO_TYPE


def test_format_model_expressions():
Expand Down Expand Up @@ -700,3 +702,18 @@ def test_model_name_cannot_be_string():

def test_parse_snowflake_create_schema_ddl():
assert parse_one("CREATE SCHEMA d.s", dialect="snowflake").sql() == "CREATE SCHEMA d.s"


@pytest.mark.parametrize("dialect", sorted(set(DIALECT_TO_TYPE.values())))
def test_sqlglot_extended_correctly(dialect: str) -> None:
# MODEL is a SQLMesh extension and not part of SQLGlot
# If we can roundtrip an expression containing MODEL across every dialect, then the SQLMesh extensions have been registered correctly
ast = d.parse_one("MODEL (name foo)", dialect=dialect)
assert isinstance(ast, d.Model)
name_prop = ast.find(exp.Property)
assert isinstance(name_prop, exp.Property)
assert name_prop.this == "name"
value = name_prop.args["value"]
assert isinstance(value, exp.Table)
assert value.sql() == "foo"
assert ast.sql(dialect=dialect) == "MODEL (\nname foo\n)"