diff --git a/examples/sushi/models/customer_revenue_by_day.sql b/examples/sushi/models/customer_revenue_by_day.sql index 3b7f3724cb..248af2db8d 100644 --- a/examples/sushi/models/customer_revenue_by_day.sql +++ b/examples/sushi/models/customer_revenue_by_day.sql @@ -21,7 +21,7 @@ WITH order_total AS ( LEFT JOIN sushi.items AS i ON oi.item_id = i.id AND oi.event_date = i.event_date WHERE - oi.event_date BETWEEN CAST('{{ start_ds }}' as DATE) AND CAST('{{ end_ds }}' as DATE) + oi.event_date BETWEEN @start_date AND @end_date GROUP BY oi.order_id, oi.event_date @@ -35,7 +35,7 @@ FROM sushi.orders AS o LEFT JOIN order_total AS ot ON o.id = ot.order_id AND o.event_date = ot.event_date WHERE - o.event_date BETWEEN CAST('{{ start_ds }}' as DATE) AND CAST('{{ end_ds }}' as DATE) + o.event_date BETWEEN @start_date AND @end_date GROUP BY o.customer_id, o.event_date diff --git a/examples/sushi/models/waiter_as_customer_by_day.sql b/examples/sushi/models/waiter_as_customer_by_day.sql index 7dc12db873..dd9f79b5a3 100644 --- a/examples/sushi/models/waiter_as_customer_by_day.sql +++ b/examples/sushi/models/waiter_as_customer_by_day.sql @@ -27,6 +27,6 @@ SELECT FROM sushi.waiters AS w JOIN sushi.customers as c ON w.waiter_id = c.customer_id JOIN sushi.waiter_names as wn ON w.waiter_id = wn.id -WHERE w.event_date BETWEEN @start_date AND @end_date; +WHERE w.event_date BETWEEN CAST('{{ start_ds }}' as DATE) AND @end_date; JINJA_END; diff --git a/sqlmesh/core/macros.py b/sqlmesh/core/macros.py index 42a4a8b8dc..9e7df5d111 100644 --- a/sqlmesh/core/macros.py +++ b/sqlmesh/core/macros.py @@ -12,7 +12,6 @@ from datetime import datetime, date import sqlglot -from jinja2 import Environment from sqlglot import Generator, exp, parse_one from sqlglot.executor.env import ENV from sqlglot.executor.python import Python @@ -40,7 +39,6 @@ ) from sqlmesh.utils.date import DatetimeRanges, to_datetime, to_date from sqlmesh.utils.errors import MacroEvalError, SQLMeshError -from sqlmesh.utils.jinja import JinjaMacroRegistry, has_jinja from sqlmesh.utils.metaprogramming import ( Executable, SqlValue, @@ -193,7 +191,6 @@ def __init__( self.columns_to_types_called = False self.default_catalog = default_catalog - self._jinja_env: t.Optional[Environment] = None self._schema = schema self._resolve_table = resolve_table self._resolve_tables = resolve_tables @@ -282,12 +279,6 @@ def evaluate_macros( if node.this != text: changed = True return exp.to_identifier(text, quoted=node.quoted or None) - if node.is_string: - text = node.this - if has_jinja(text): - changed = True - node.set("this", self.jinja_env.from_string(node.this).render()) - return node if isinstance(node, MacroFunc): changed = True return self.evaluate(node) @@ -436,14 +427,6 @@ def parse_one( """ return sqlglot.maybe_parse(sql, dialect=self.dialect, into=into, **opts) - @property - def jinja_env(self) -> Environment: - if not self._jinja_env: - jinja_env_methods = {**self.locals, **self.env} - del jinja_env_methods["self"] - self._jinja_env = JinjaMacroRegistry().build_environment(**jinja_env_methods) - return self._jinja_env - def columns_to_types(self, model_name: TableName | exp.Column) -> t.Dict[str, exp.DataType]: """Returns the columns-to-types mapping corresponding to the specified model.""" diff --git a/tests/core/test_model.py b/tests/core/test_model.py index eecc3977e7..9266a56c10 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -2566,11 +2566,15 @@ def test_parse(assert_exp_eq): dialect '', ); + JINJA_QUERY_BEGIN; + SELECT id::INT AS id, ds FROM x - WHERE ds BETWEEN '{{ start_ds }}' AND @end_ds + WHERE ds BETWEEN '{{ start_ds }}' AND @end_ds; + + JINJA_END; """ ) model = load_sql_based_model(expressions, dialect="hive") @@ -2580,8 +2584,8 @@ def test_parse(assert_exp_eq): } assert not model.annotated assert model.dialect == "" - assert isinstance(model.query, exp.Select) - assert isinstance(SqlModel.parse_raw(model.json()).query, exp.Select) + assert isinstance(model.query, d.JinjaQuery) + assert isinstance(SqlModel.parse_raw(model.json()).query, d.JinjaQuery) assert_exp_eq( model.render_query(), """ @@ -11543,3 +11547,18 @@ def test_text_diff_optimize_query(): diff = model1.text_diff(model2) assert diff, "Expected diff to show optimize_query change" assert "+ optimize_query" in diff.lower() + + +def test_raw_jinja_raw_tag(): + expressions = d.parse( + """ + MODEL (name test); + + JINJA_QUERY_BEGIN; + SELECT {% raw %} '{{ foo }}' {% endraw %} AS col; + JINJA_END; + """ + ) + + model = load_sql_based_model(expressions) + assert model.render_query().sql() == "SELECT '{{ foo }}' AS \"col\""