From f0798ba5003db29499ee9bd975763a489b62df5a Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Thu, 28 Aug 2025 10:12:38 -0700 Subject: [PATCH 1/8] Fix!: Avoid using rendered query when computing the data hash --- sqlmesh/core/context_diff.py | 9 ++-- sqlmesh/core/model/definition.py | 42 +++++++++++++++--- sqlmesh/core/node.py | 44 ++++++++++++++----- sqlmesh/core/snapshot/categorizer.py | 11 ++--- sqlmesh/core/snapshot/definition.py | 15 +++++++ ...093_use_unrendered_query_in_fingerprint.py | 5 +++ tests/core/test_integration.py | 16 +++---- tests/core/test_model.py | 2 +- tests/core/test_selector.py | 2 +- tests/core/test_snapshot.py | 4 +- 10 files changed, 110 insertions(+), 40 deletions(-) create mode 100644 sqlmesh/migrations/v0093_use_unrendered_query_in_fingerprint.py diff --git a/sqlmesh/core/context_diff.py b/sqlmesh/core/context_diff.py index 12da39f50f..07d13b1c2f 100644 --- a/sqlmesh/core/context_diff.py +++ b/sqlmesh/core/context_diff.py @@ -435,7 +435,7 @@ def directly_modified(self, name: str) -> bool: return False current, previous = self.modified_snapshots[name] - return current.fingerprint.data_hash != previous.fingerprint.data_hash + return current.is_directly_modified(previous) def indirectly_modified(self, name: str) -> bool: """Returns whether or not a node was indirectly modified in this context. @@ -451,10 +451,7 @@ def indirectly_modified(self, name: str) -> bool: return False current, previous = self.modified_snapshots[name] - return ( - current.fingerprint.data_hash == previous.fingerprint.data_hash - and current.fingerprint.parent_data_hash != previous.fingerprint.parent_data_hash - ) + return current.is_indirectly_modified(previous) def metadata_updated(self, name: str) -> bool: """Returns whether or not the given node's metadata has been updated. @@ -470,7 +467,7 @@ def metadata_updated(self, name: str) -> bool: return False current, previous = self.modified_snapshots[name] - return current.fingerprint.metadata_hash != previous.fingerprint.metadata_hash + return current.is_metadata_updated(previous) def text_diff(self, name: str) -> str: """Finds the difference of a node between the current and remote environment. diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index dba8eedc31..d8eb852c1e 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -62,6 +62,7 @@ if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType + from sqlmesh.core.node import _Node from sqlmesh.core._typing import Self, TableName, SessionProperties from sqlmesh.core.context import ExecutionContext from sqlmesh.core.engine_adapter import EngineAdapter @@ -1278,6 +1279,7 @@ class SqlModel(_Model): source_type: t.Literal["sql"] = "sql" _columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None + _is_metadata_only_change_cache: t.Dict[int, bool] = {} def __getstate__(self) -> t.Dict[t.Any, t.Any]: state = super().__getstate__() @@ -1500,6 +1502,27 @@ def is_breaking_change(self, previous: Model) -> t.Optional[bool]: return False + def is_metadata_only_change(self, previous: _Node) -> bool: + if self._is_metadata_only_change_cache.get(id(previous), None) is not None: + return self._is_metadata_only_change_cache[id(previous)] + + if ( + not isinstance(previous, SqlModel) + or self.metadata_hash == previous.metadata_hash + or self._data_hash_values_no_query != previous._data_hash_values_no_query + ): + is_metadata_change = False + else: + # If the rendered queries are the same, then this is a metadata only change + this_rendered_query = self.render_query() + previous_rendered_query = previous.render_query() + is_metadata_change = ( + this_rendered_query is not None and this_rendered_query == previous_rendered_query + ) + + self._is_metadata_only_change_cache[id(previous)] = is_metadata_change + return is_metadata_change + @cached_property def _query_renderer(self) -> QueryRenderer: no_quote_identifiers = self.kind.is_view and self.dialect in ("trino", "spark") @@ -1519,17 +1542,22 @@ def _query_renderer(self) -> QueryRenderer: ) @property - def _data_hash_values(self) -> t.List[str]: - data = super()._data_hash_values + def _data_hash_values_no_query(self) -> t.List[str]: + return [ + *super()._data_hash_values, + *self.jinja_macros.data_hash_values, + ] - query = self.render_query() or self.query - data.append(gen(query)) - data.extend(self.jinja_macros.data_hash_values) - return data + @property + def _data_hash_values(self) -> t.List[str]: + return [ + *self._data_hash_values_no_query, + gen(self.query, comments=False), + ] @property def _additional_metadata(self) -> t.List[str]: - return [*super()._additional_metadata, gen(self.query)] + return [*super()._additional_metadata, gen(self.query, comments=True)] @property def violated_rules_for_query(self) -> t.Dict[type[Rule], t.Any]: diff --git a/sqlmesh/core/node.py b/sqlmesh/core/node.py index 4f0a66dc2e..ea2264f7fa 100644 --- a/sqlmesh/core/node.py +++ b/sqlmesh/core/node.py @@ -307,16 +307,6 @@ def batch_concurrency(self) -> t.Optional[int]: """The maximal number of batches that can run concurrently for a backfill.""" return None - @property - def data_hash(self) -> str: - """ - Computes the data hash for the node. - - Returns: - The data hash for the node. - """ - raise NotImplementedError - @property def interval_unit(self) -> IntervalUnit: """Returns the interval unit using which data intervals are computed for this node.""" @@ -332,6 +322,16 @@ def depends_on(self) -> t.Set[str]: def fqn(self) -> str: return self.name + @property + def data_hash(self) -> str: + """ + Computes the data hash for the node. + + Returns: + The data hash for the node. + """ + raise NotImplementedError + @property def metadata_hash(self) -> str: """ @@ -342,6 +342,30 @@ def metadata_hash(self) -> str: """ raise NotImplementedError + def is_metadata_only_change(self, previous: _Node) -> bool: + """Determines if this node is a metadata only change in relation to the `previous` node. + + Args: + previous: The previous node to compare against. + + Returns: + True if this node is a metadata only change, False otherwise. + """ + return self.data_hash == previous.data_hash and self.metadata_hash != previous.metadata_hash + + def is_data_change(self, previous: _Node) -> bool: + """Determines if this node is a data change in relation to the `previous` node. + + Args: + previous: The previous node to compare against. + + Returns: + True if this node is a data change, False otherwise. + """ + return ( + self.data_hash != previous.data_hash or self.metadata_hash != previous.metadata_hash + ) and not self.is_metadata_only_change(previous) + def croniter(self, value: TimeLike) -> CroniterCache: if self._croniter is None: self._croniter = CroniterCache(self.cron, value, tz=self.cron_tz) diff --git a/sqlmesh/core/snapshot/categorizer.py b/sqlmesh/core/snapshot/categorizer.py index 88a1ef37ab..78ea7466ed 100644 --- a/sqlmesh/core/snapshot/categorizer.py +++ b/sqlmesh/core/snapshot/categorizer.py @@ -47,11 +47,12 @@ def categorize_change( if type(new_model) != type(old_model): return default_category - if new.fingerprint.data_hash == old.fingerprint.data_hash: - if new.fingerprint.metadata_hash == old.fingerprint.metadata_hash: - raise SQLMeshError( - f"{new} is unmodified or indirectly modified and should not be categorized" - ) + if new.fingerprint == old.fingerprint: + raise SQLMeshError( + f"{new} is unmodified or indirectly modified and should not be categorized" + ) + + if not new.is_directly_modified(old): if new.fingerprint.parent_data_hash == old.fingerprint.parent_data_hash: return SnapshotChangeCategory.NON_BREAKING return None diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index afc8e06458..dea4ef64e5 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -1230,6 +1230,21 @@ def apply_pending_restatement_intervals(self) -> None: ) self.intervals = remove_interval(self.intervals, *pending_restatement_interval) + def is_directly_modified(self, other: Snapshot) -> bool: + """Returns whether or not this snapshot is directly modified in relation to the other snapshot.""" + return self.node.is_data_change(other.node) + + def is_indirectly_modified(self, other: Snapshot) -> bool: + """Returns whether or not this snapshot is indirectly modified in relation to the other snapshot.""" + return ( + self.fingerprint.parent_data_hash != other.fingerprint.parent_data_hash + and not self.node.is_data_change(other.node) + ) + + def is_metadata_updated(self, other: Snapshot) -> bool: + """Returns whether or not this snapshot contains metadata changes in relation to the other snapshot.""" + return self.fingerprint.metadata_hash != other.fingerprint.metadata_hash + @property def physical_schema(self) -> str: if self.physical_schema_ is not None: diff --git a/sqlmesh/migrations/v0093_use_unrendered_query_in_fingerprint.py b/sqlmesh/migrations/v0093_use_unrendered_query_in_fingerprint.py new file mode 100644 index 0000000000..8698d32831 --- /dev/null +++ b/sqlmesh/migrations/v0093_use_unrendered_query_in_fingerprint.py @@ -0,0 +1,5 @@ +"""Use the unrendered query when computing the model fingerprint.""" + + +def migrate(state_sync, **kwargs): # type: ignore + pass diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index 0e779481fd..5100732f62 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -4681,12 +4681,12 @@ def test_plan_repairs_unrenderable_snapshot_state( f"name = '{target_snapshot.name}' AND identifier = '{target_snapshot.identifier}'", ) + context.clear_caches() + target_snapshot_in_state = context.state_sync.get_snapshots([target_snapshot.snapshot_id])[ + target_snapshot.snapshot_id + ] + with pytest.raises(Exception): - context_copy = context.copy() - context_copy.clear_caches() - target_snapshot_in_state = context_copy.state_sync.get_snapshots( - [target_snapshot.snapshot_id] - )[target_snapshot.snapshot_id] target_snapshot_in_state.model.render_query_or_raise() # Repair the snapshot by creating a new version of it @@ -4695,11 +4695,11 @@ def test_plan_repairs_unrenderable_snapshot_state( plan_builder = context.plan_builder("prod", forward_only=forward_only) plan = plan_builder.build() - assert plan.directly_modified == {target_snapshot.snapshot_id} if not forward_only: assert target_snapshot.snapshot_id in {i.snapshot_id for i in plan.missing_intervals} - plan_builder.set_choice(target_snapshot, SnapshotChangeCategory.NON_BREAKING) - plan = plan_builder.build() + assert plan.directly_modified == {target_snapshot.snapshot_id} + plan_builder.set_choice(target_snapshot, SnapshotChangeCategory.NON_BREAKING) + plan = plan_builder.build() context.apply(plan) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 3850e08164..804d702f3b 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -5732,7 +5732,7 @@ def test_default_catalog_sql(assert_exp_eq): The system is not designed to actually support having an engine that doesn't support default catalog to start supporting it or the reverse of that. If that did happen then bugs would occur. """ - HASH_WITH_CATALOG = "1269513823" + HASH_WITH_CATALOG = "3443912775" # Test setting default catalog doesn't change hash if it matches existing logic expressions = d.parse( diff --git a/tests/core/test_selector.py b/tests/core/test_selector.py index 9f3bc9f698..ccdb8bfd80 100644 --- a/tests/core/test_selector.py +++ b/tests/core/test_selector.py @@ -301,7 +301,7 @@ def test_select_change_schema(mocker: MockerFixture, make_snapshot): selector = Selector(state_reader_mock, local_models) selected = selector.select_models(["db.parent"], env_name) - assert selected[local_child.fqn].data_hash != child.data_hash + assert selected[local_child.fqn].render_query() != child.render_query() _assert_models_equal( selected, diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py index 86fb434e33..231ecb8935 100644 --- a/tests/core/test_snapshot.py +++ b/tests/core/test_snapshot.py @@ -913,7 +913,7 @@ def test_fingerprint(model: Model, parent_model: Model): fingerprint = fingerprint_from_node(model, nodes={}) original_fingerprint = SnapshotFingerprint( - data_hash="3301649319", + data_hash="1698409777", metadata_hash="3575333731", ) @@ -1013,7 +1013,7 @@ def test_fingerprint_jinja_macros(model: Model): } ) original_fingerprint = SnapshotFingerprint( - data_hash="2908339239", + data_hash="343517722", metadata_hash="3575333731", ) From 77f4197b0969b25269c1a2878e3fa33c2cb4d83b Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Fri, 29 Aug 2025 10:35:41 -0700 Subject: [PATCH 2/8] switch to storing raw sql --- sqlmesh/core/model/common.py | 50 +++++++++++++++++++++++++++++++- sqlmesh/core/model/definition.py | 21 ++++++++++++-- tests/core/test_snapshot.py | 10 ++++--- 3 files changed, 73 insertions(+), 8 deletions(-) diff --git a/sqlmesh/core/model/common.py b/sqlmesh/core/model/common.py index 9a68ec18c0..1fb9e4ad1a 100644 --- a/sqlmesh/core/model/common.py +++ b/sqlmesh/core/model/common.py @@ -21,7 +21,7 @@ prepare_env, serialize_env, ) -from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator +from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator, get_dialect if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType @@ -663,3 +663,51 @@ def parse_strings_with_macro_refs(value: t.Any, dialect: DialectType) -> t.Any: mode="before", check_fields=False, )(depends_on) + + +class ParsableSql(PydanticModel): + sql: str + + _parsed: t.Optional[exp.Expression] = None + _parsed_dialect: t.Optional[str] = None + + def parse(self, dialect: str) -> exp.Expression: + if self._parsed is None or self._parsed_dialect != dialect: + self._parsed = d.parse_one(self.sql, dialect=dialect) + self._parsed_dialect = dialect + return self._parsed + + @classmethod + def from_parsed_expression( + cls, parsed_expression: exp.Expression, dialect: str, use_meta_sql: bool = False + ) -> ParsableSql: + sql = ( + parsed_expression.meta.get("sql") or parsed_expression.sql(dialect=dialect) + if use_meta_sql + else parsed_expression.sql(dialect=dialect) + ) + result = cls(sql=sql) + result._parsed = parsed_expression + result._parsed_dialect = dialect + return result + + @classmethod + def validator(cls) -> classmethod: + def _validate_parsable_sql(v: t.Any, info: ValidationInfo) -> ParsableSql: + if isinstance(v, str): + return ParsableSql(sql=v) + if isinstance(v, exp.Expression): + return ParsableSql.from_parsed_expression( + v, get_dialect(info.data), use_meta_sql=False + ) + return ParsableSql.parse_obj(v) + + return field_validator( + "query_", + # "expressions_", + # "pre_statements_", + # "post_statements_", + # "on_virtual_update_", + mode="before", + check_fields=False, + )(_validate_parsable_sql) diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index d8eb852c1e..7b091c3781 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -24,6 +24,7 @@ from sqlmesh.core.node import IntervalUnit from sqlmesh.core.macros import MacroRegistry, macro from sqlmesh.core.model.common import ( + ParsableSql, expression_validator, make_python_env, parse_dependencies, @@ -1275,9 +1276,10 @@ class SqlModel(_Model): on_virtual_update: The list of SQL statements to be executed after the virtual update. """ - query: t.Union[exp.Query, d.JinjaQuery, d.MacroFunc] + query_: ParsableSql = Field(alias="query") source_type: t.Literal["sql"] = "sql" + _query_validator = ParsableSql.validator() _columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None _is_metadata_only_change_cache: t.Dict[int, bool] = {} @@ -1300,6 +1302,11 @@ def copy(self, **kwargs: t.Any) -> Self: model._full_depends_on = None return model + @property + def query(self) -> t.Union[exp.Query, d.JinjaQuery, d.MacroFunc]: + parsed_query = self.query_.parse(self.dialect) + return t.cast(t.Union[exp.Query, d.JinjaQuery, d.MacroFunc], parsed_query) + def render_query( self, *, @@ -2280,6 +2287,7 @@ def load_sql_based_model( def create_sql_model( name: TableName, query: t.Optional[exp.Expression], + dialect: t.Optional[str] = None, **kwargs: t.Any, ) -> Model: """Creates a SQL model. @@ -2296,7 +2304,14 @@ def create_sql_model( ) assert isinstance(query, (exp.Query, d.JinjaQuery, d.MacroFunc)) - return _create_model(SqlModel, name, query=query, **kwargs) + dialect = dialect or "" + return _create_model( + SqlModel, + name, + query=ParsableSql.from_parsed_expression(query, dialect, use_meta_sql=True), + dialect=dialect, + **kwargs, + ) def create_seed_model( @@ -2531,7 +2546,7 @@ def _create_model( if "pre_statements" in kwargs: statements.extend(kwargs["pre_statements"]) if "query" in kwargs: - statements.append(kwargs["query"]) + statements.append(kwargs["query"].parse(dialect)) if "post_statements" in kwargs: statements.extend(kwargs["post_statements"]) diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py index 231ecb8935..81742f9a67 100644 --- a/tests/core/test_snapshot.py +++ b/tests/core/test_snapshot.py @@ -79,7 +79,7 @@ def parent_model(): name="parent.tbl", kind=dict(time_column="ds", name=ModelKindName.INCREMENTAL_BY_TIME_RANGE), dialect="spark", - query=parse_one("SELECT 1, ds"), + query="SELECT 1, ds", ) @@ -92,7 +92,7 @@ def model(): dialect="spark", cron="1 0 * * *", start="2020-01-01", - query=parse_one("SELECT @EACH([1, 2], x -> x), ds FROM parent.tbl"), + query="SELECT @EACH([1, 2], x -> x), ds FROM parent.tbl", ) @@ -148,7 +148,9 @@ def test_json(snapshot: Snapshot): "project": "", "python_env": {}, "owner": "owner", - "query": "SELECT @EACH([1, 2], x -> x), ds FROM parent.tbl", + "query": { + "sql": "SELECT @EACH([1, 2], x -> x), ds FROM parent.tbl", + }, "jinja_macros": { "create_builtins_module": "sqlmesh.utils.jinja", "global_objs": {}, @@ -186,7 +188,7 @@ def test_json_custom_materialization(make_snapshot: t.Callable): dialect="spark", cron="1 0 * * *", start="2020-01-01", - query=parse_one("SELECT @EACH([1, 2], x -> x), ds FROM parent.tbl"), + query="SELECT @EACH([1, 2], x -> x), ds FROM parent.tbl", ) snapshot = make_snapshot( From 353dbdd8af3e040240b9561e2558856833ba08d8 Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Fri, 29 Aug 2025 11:41:16 -0700 Subject: [PATCH 3/8] fix tests --- tests/core/test_context.py | 6 +++- tests/core/test_integration.py | 60 ++++++++++++++++++++++++++++------ tests/core/test_model.py | 5 ++- tests/core/test_selector.py | 9 ++++- tests/core/test_test.py | 11 +++++-- 5 files changed, 76 insertions(+), 15 deletions(-) diff --git a/tests/core/test_context.py b/tests/core/test_context.py index 3b7c5bd51d..196889a87c 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -39,6 +39,7 @@ from sqlmesh.core.plan.definition import Plan from sqlmesh.core.macros import MacroEvaluator, RuntimeStage from sqlmesh.core.model import load_sql_based_model, model, SqlModel, Model +from sqlmesh.core.model.common import ParsableSql from sqlmesh.core.model.cache import OptimizedQueryCache from sqlmesh.core.renderer import render_statements from sqlmesh.core.model.kind import ModelKindName @@ -2303,7 +2304,10 @@ def test_prompt_if_uncategorized_snapshot(mocker: MockerFixture, tmp_path: Path) incremental_model = context.get_model("sqlmesh_example.incremental_model") incremental_model_query = incremental_model.render_query() new_incremental_model_query = t.cast(exp.Select, incremental_model_query).select("1 AS z") - context.upsert_model("sqlmesh_example.incremental_model", query=new_incremental_model_query) + context.upsert_model( + "sqlmesh_example.incremental_model", + query_=ParsableSql(sql=new_incremental_model_query.sql(dialect=incremental_model.dialect)), + ) mock_console = mocker.Mock() spy_plan = mocker.spy(mock_console, "plan") diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index 5100732f62..a0b77131e5 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -16,6 +16,7 @@ from pathlib import Path from sqlmesh.core.console import set_console, get_console, TerminalConsole from sqlmesh.core.config.naming import NameInferenceConfig +from sqlmesh.core.model.common import ParsableSql from sqlmesh.utils.concurrency import NodeExecutionFailedError import time_machine from pytest_mock.plugin import MockerFixture @@ -2023,7 +2024,7 @@ def test_dbt_select_star_is_directly_modified(sushi_test_dbt_context: Context): model = context.get_model("sushi.simple_model_a") context.upsert_model( model, - query=d.parse_one("SELECT 1 AS a, 2 AS b"), + query_=ParsableSql(sql="SELECT 1 AS a, 2 AS b"), ) snapshot_a_id = context.get_snapshot("sushi.simple_model_a").snapshot_id # type: ignore @@ -2605,8 +2606,8 @@ def test_unaligned_start_snapshot_with_non_deployable_downstream(init_and_plan_c context.upsert_model(SqlModel.parse_obj(kwargs)) context.upsert_model( downstream_model_name, - query=d.parse_one( - "SELECT customer_id, MAX(revenue) AS max_revenue FROM memory.sushi.customer_revenue_lifetime_new GROUP BY 1" + query_=ParsableSql( + sql="SELECT customer_id, MAX(revenue) AS max_revenue FROM memory.sushi.customer_revenue_lifetime_new GROUP BY 1" ), ) @@ -2637,7 +2638,13 @@ def test_virtual_environment_mode_dev_only(init_and_plan_context: t.Callable): # Make a change in dev original_model = context.get_model("sushi.waiter_revenue_by_day") original_fingerprint = context.get_snapshot(original_model.name).fingerprint - model = original_model.copy(update={"query": original_model.query.order_by("waiter_id")}) + model = original_model.copy( + update={ + "query_": ParsableSql( + sql=original_model.query.order_by("waiter_id").sql(dialect=original_model.dialect) + ) + } + ) model = add_projection_to_model(t.cast(SqlModel, model)) context.upsert_model(model) @@ -5383,7 +5390,10 @@ def test_auto_categorization(sushi_context: Context): ).fingerprint model = t.cast(SqlModel, sushi_context.get_model("sushi.customers", raise_if_missing=True)) - sushi_context.upsert_model("sushi.customers", query=model.query.select("'foo' AS foo")) # type: ignore + sushi_context.upsert_model( + "sushi.customers", + query_=ParsableSql(sql=model.query.select("'foo' AS foo").sql(dialect=model.dialect)), # type: ignore + ) apply_to_environment(sushi_context, environment) assert ( @@ -5447,7 +5457,13 @@ def test_multi(mocker): model = context.get_model("bronze.a") assert model.project == "repo_1" - context.upsert_model(model.copy(update={"query": model.query.select("'c' AS c")})) + context.upsert_model( + model.copy( + update={ + "query_": ParsableSql(sql=model.query.select("'c' AS c").sql(dialect=model.dialect)) + } + ) + ) plan = context.plan_builder().build() assert set(snapshot.name for snapshot in plan.directly_modified) == { @@ -5615,7 +5631,15 @@ def test_multi_virtual_layer(copy_to_temp_path): model = context.get_model("db_1.first_schema.model_one") - context.upsert_model(model.copy(update={"query": model.query.select("'c' AS extra")})) + context.upsert_model( + model.copy( + update={ + "query_": ParsableSql( + sql=model.query.select("'c' AS extra").sql(dialect=model.dialect) + ) + } + ) + ) plan = context.plan_builder().build() context.apply(plan) @@ -5641,9 +5665,25 @@ def test_multi_virtual_layer(copy_to_temp_path): # Create dev environment with changed models model = context.get_model("db_2.second_schema.model_one") - context.upsert_model(model.copy(update={"query": model.query.select("'d' AS extra")})) + context.upsert_model( + model.copy( + update={ + "query_": ParsableSql( + sql=model.query.select("'d' AS extra").sql(dialect=model.dialect) + ) + } + ) + ) model = context.get_model("first_schema.model_two") - context.upsert_model(model.copy(update={"query": model.query.select("'d2' AS col")})) + context.upsert_model( + model.copy( + update={ + "query_": ParsableSql( + sql=model.query.select("'d2' AS col").sql(dialect=model.dialect) + ) + } + ) + ) plan = context.plan_builder("dev").build() context.apply(plan) @@ -6634,7 +6674,7 @@ def change_data_type( for data_type in data_types: if data_type.this == old_type: data_type.set("this", new_type) - context.upsert_model(model_name, query=model.query) + context.upsert_model(model_name, query_=model.query_) elif model.columns_to_types_ is not None: for k, v in model.columns_to_types_.items(): if v.this == old_type: diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 804d702f3b..8db1ca593a 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -22,6 +22,7 @@ from sqlmesh.core import dialect as d from sqlmesh.core.console import get_console from sqlmesh.core.audit import ModelAudit, load_audit +from sqlmesh.core.model.common import ParsableSql from sqlmesh.core.config import ( Config, DuckDBConnectionConfig, @@ -8874,7 +8875,9 @@ def test_column_description_metadata_change(): context.upsert_model(model) context.plan(no_prompts=True, auto_apply=True) - context.upsert_model("db.test_model", query=parse_one("SELECT 1 AS id /* description 2 */")) + context.upsert_model( + "db.test_model", query_=ParsableSql(sql="SELECT 1 AS id /* description 2 */") + ) plan = context.plan(no_prompts=True, auto_apply=True) snapshots = list(plan.snapshots.values()) diff --git a/tests/core/test_selector.py b/tests/core/test_selector.py index ccdb8bfd80..80b9ef691e 100644 --- a/tests/core/test_selector.py +++ b/tests/core/test_selector.py @@ -11,6 +11,7 @@ from sqlmesh.core.audit import StandaloneAudit from sqlmesh.core.environment import Environment from sqlmesh.core.model import Model, SqlModel +from sqlmesh.core.model.common import ParsableSql from sqlmesh.core.selector import Selector from sqlmesh.core.snapshot import SnapshotChangeCategory from sqlmesh.utils import UniqueKeyDict @@ -293,7 +294,13 @@ def test_select_change_schema(mocker: MockerFixture, make_snapshot): } local_models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") - local_parent = parent.copy(update={"query": parent.query.select("2 as b", append=False)}) # type: ignore + local_parent = parent.copy( + update={ + "query_": ParsableSql( + sql=parent.query.select("2 as b", append=False).sql(dialect=parent.dialect) # type: ignore + ) + } + ) local_models[local_parent.fqn] = local_parent local_child = child.copy(update={"mapping_schema": {'"db"': {'"parent"': {"b": "INT"}}}}) local_models[local_child.fqn] = local_child diff --git a/tests/core/test_test.py b/tests/core/test_test.py index 1b5425068f..d889c7bb33 100644 --- a/tests/core/test_test.py +++ b/tests/core/test_test.py @@ -29,6 +29,7 @@ from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.macros import MacroEvaluator, macro from sqlmesh.core.model import Model, SqlModel, load_sql_based_model, model +from sqlmesh.core.model.common import ParsableSql from sqlmesh.core.test.definition import ModelTest, PythonModelTest, SqlModelTest from sqlmesh.core.test.result import ModelTextTestResult from sqlmesh.core.test.context import TestExecutionContext @@ -1985,12 +1986,18 @@ def test_test_generation(tmp_path: Path) -> None: ) context = Context(paths=tmp_path, config=config) - query = context.get_model("sqlmesh_example.full_model").render_query() + model = context.get_model("sqlmesh_example.full_model") + query = model.render_query() assert isinstance(query, exp.Query) context.upsert_model( "sqlmesh_example.full_model", - query=exp.select(*query.named_selects).from_("cte").with_("cte", as_=query), + query_=ParsableSql( + sql=exp.select(*query.named_selects) + .from_("cte") + .with_("cte", as_=query) + .sql(dialect=model.dialect) + ), ) context.plan(auto_apply=True) From 07b0ec514416ac06351edf17d605fdfb374c1360 Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Fri, 29 Aug 2025 11:53:16 -0700 Subject: [PATCH 4/8] switch to storing raw sql in audits --- sqlmesh/core/audit/definition.py | 39 +++++++++++++-------- sqlmesh/core/model/common.py | 22 +++++++++--- tests/core/state_sync/test_export_import.py | 4 +-- tests/core/test_integration.py | 30 ++++++++++++++-- 4 files changed, 72 insertions(+), 23 deletions(-) diff --git a/sqlmesh/core/audit/definition.py b/sqlmesh/core/audit/definition.py index 210ae9da1b..ae15efd574 100644 --- a/sqlmesh/core/audit/definition.py +++ b/sqlmesh/core/audit/definition.py @@ -15,11 +15,10 @@ bool_validator, default_catalog_validator, depends_on_validator, - expression_validator, sort_python_env, sorted_python_env_payloads, ) -from sqlmesh.core.model.common import make_python_env, single_value_or_tuple +from sqlmesh.core.model.common import make_python_env, single_value_or_tuple, ParsableSql from sqlmesh.core.node import _Node from sqlmesh.core.renderer import QueryRenderer from sqlmesh.utils.date import TimeLike @@ -67,15 +66,21 @@ class AuditMixin(AuditCommonMetaMixin): jinja_macros: A registry of jinja macros to use when rendering the audit query. """ - query: t.Union[exp.Query, d.JinjaQuery] + query_: ParsableSql defaults: t.Dict[str, exp.Expression] - expressions_: t.Optional[t.List[exp.Expression]] + expressions_: t.Optional[t.List[ParsableSql]] jinja_macros: JinjaMacroRegistry formatting: t.Optional[bool] + @property + def query(self) -> t.Union[exp.Query, d.JinjaQuery]: + return t.cast(t.Union[exp.Query, d.JinjaQuery], self.query_.parse(self.dialect)) + @property def expressions(self) -> t.List[exp.Expression]: - return self.expressions_ or [] + if self.expressions_: + return [e.parse(self.dialect) for e in self.expressions_] + return [] @property def macro_definitions(self) -> t.List[d.MacroDef]: @@ -122,16 +127,16 @@ class ModelAudit(PydanticModel, AuditMixin, frozen=True): skip: bool = False blocking: bool = True standalone: t.Literal[False] = False - query: t.Union[exp.Query, d.JinjaQuery] + query_: ParsableSql = Field(alias="query") defaults: t.Dict[str, exp.Expression] = {} - expressions_: t.Optional[t.List[exp.Expression]] = Field(default=None, alias="expressions") + expressions_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="expressions") jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry() formatting: t.Optional[bool] = Field(default=None, exclude=True) _path: t.Optional[Path] = None # Validators - _query_validator = expression_validator + _query_validator = ParsableSql.validator() _bool_validator = bool_validator _string_validator = audit_string_validator _map_validator = audit_map_validator @@ -153,9 +158,9 @@ class StandaloneAudit(_Node, AuditMixin): skip: bool = False blocking: bool = False standalone: t.Literal[True] = True - query: t.Union[exp.Query, d.JinjaQuery] + query_: ParsableSql = Field(alias="query") defaults: t.Dict[str, exp.Expression] = {} - expressions_: t.Optional[t.List[exp.Expression]] = Field(default=None, alias="expressions") + expressions_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="expressions") jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry() default_catalog: t.Optional[str] = None depends_on_: t.Optional[t.Set[str]] = Field(default=None, alias="depends_on") @@ -165,7 +170,7 @@ class StandaloneAudit(_Node, AuditMixin): source_type: t.Literal["audit"] = "audit" # Validators - _query_validator = expression_validator + _query_validator = ParsableSql.validator() _bool_validator = bool_validator _string_validator = audit_string_validator _map_validator = audit_map_validator @@ -461,11 +466,17 @@ def load_audit( if project is not None: extra_kwargs["project"] = project - dialect = meta_fields.pop("dialect", dialect) + dialect = meta_fields.pop("dialect", dialect) or "" + + parsable_query = ParsableSql.from_parsed_expression(query, dialect, use_meta_sql=True) + parsable_statements = [ + ParsableSql.from_parsed_expression(s, dialect, use_meta_sql=True) for s in statements + ] + try: audit = audit_class( - query=query, - expressions=statements, + query=parsable_query, + expressions=parsable_statements, dialect=dialect, **extra_kwargs, **meta_fields, diff --git a/sqlmesh/core/model/common.py b/sqlmesh/core/model/common.py index 1fb9e4ad1a..19e513c3b2 100644 --- a/sqlmesh/core/model/common.py +++ b/sqlmesh/core/model/common.py @@ -616,8 +616,8 @@ def parse_strings_with_macro_refs(value: t.Any, dialect: DialectType) -> t.Any: expression_validator: t.Callable = field_validator( - "query", - "expressions_", + # "query", + # "expressions_", "pre_statements_", "post_statements_", "on_virtual_update_", @@ -693,18 +693,32 @@ def from_parsed_expression( @classmethod def validator(cls) -> classmethod: - def _validate_parsable_sql(v: t.Any, info: ValidationInfo) -> ParsableSql: + def _validate_parsable_sql( + v: t.Any, info: ValidationInfo + ) -> t.Optional[t.Union[ParsableSql, t.List[ParsableSql]]]: + if v is None: + return v if isinstance(v, str): return ParsableSql(sql=v) if isinstance(v, exp.Expression): return ParsableSql.from_parsed_expression( v, get_dialect(info.data), use_meta_sql=False ) + if isinstance(v, list): + dialect = get_dialect(info.data) + return [ + ParsableSql(sql=s) + if isinstance(s, str) + else ParsableSql.from_parsed_expression(s, dialect, use_meta_sql=False) + if isinstance(s, exp.Expression) + else ParsableSql.parse_obj(s) + for s in v + ] return ParsableSql.parse_obj(v) return field_validator( "query_", - # "expressions_", + "expressions_", # "pre_statements_", # "post_statements_", # "on_virtual_update_", diff --git a/tests/core/state_sync/test_export_import.py b/tests/core/state_sync/test_export_import.py index 2d20199d33..c303a63e59 100644 --- a/tests/core/state_sync/test_export_import.py +++ b/tests/core/state_sync/test_export_import.py @@ -289,8 +289,8 @@ def test_export_local_state( full_model = next(s for s in snapshots if "full_model" in s["name"]) new_model = next(s for s in snapshots if "new_model" in s["name"]) - assert "'1' as modified" in full_model["node"]["query"] - assert "SELECT 1 as id" in new_model["node"]["query"] + assert "'1' as modified" in full_model["node"]["query"]["sql"] + assert "SELECT 1 as id" in new_model["node"]["query"]["sql"] def test_import_invalid_file(tmp_path: Path, state_sync: StateSync) -> None: diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index a0b77131e5..cd36049a0d 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -6961,7 +6961,15 @@ def test_destroy(copy_to_temp_path): model = context.get_model("db_1.first_schema.model_one") - context.upsert_model(model.copy(update={"query": model.query.select("'c' AS extra")})) + context.upsert_model( + model.copy( + update={ + "query_": ParsableSql( + sql=model.query.select("'c' AS extra").sql(dialect=model.dialect) + ) + } + ) + ) plan = context.plan_builder().build() context.apply(plan) @@ -6972,9 +6980,25 @@ def test_destroy(copy_to_temp_path): # Create dev environment with changed models model = context.get_model("db_2.second_schema.model_one") - context.upsert_model(model.copy(update={"query": model.query.select("'d' AS extra")})) + context.upsert_model( + model.copy( + update={ + "query_": ParsableSql( + sql=model.query.select("'d' AS extra").sql(dialect=model.dialect) + ) + } + ) + ) model = context.get_model("first_schema.model_two") - context.upsert_model(model.copy(update={"query": model.query.select("'d2' AS col")})) + context.upsert_model( + model.copy( + update={ + "query_": ParsableSql( + sql=model.query.select("'d2' AS col").sql(dialect=model.dialect) + ) + } + ) + ) plan = context.plan_builder("dev").build() context.apply(plan) From 6062c94b96395a6c2ccd8400f5d59f047b681107 Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Fri, 29 Aug 2025 12:27:47 -0700 Subject: [PATCH 5/8] use raw sql for pre- / post- statements --- sqlmesh/core/audit/definition.py | 15 +- sqlmesh/core/model/common.py | 11 +- sqlmesh/core/model/definition.py | 214 ++++++++++-------- .../v0093_use_raw_sql_in_fingerprint.py | 5 + ...093_use_unrendered_query_in_fingerprint.py | 5 - tests/core/test_integration.py | 5 +- tests/core/test_model.py | 21 +- tests/core/test_snapshot.py | 14 +- tests/core/test_snapshot_evaluator.py | 3 + tests/core/test_table_diff.py | 11 +- .../github/cicd/test_integration.py | 57 +++-- 11 files changed, 204 insertions(+), 157 deletions(-) create mode 100644 sqlmesh/migrations/v0093_use_raw_sql_in_fingerprint.py delete mode 100644 sqlmesh/migrations/v0093_use_unrendered_query_in_fingerprint.py diff --git a/sqlmesh/core/audit/definition.py b/sqlmesh/core/audit/definition.py index ae15efd574..561ee539f6 100644 --- a/sqlmesh/core/audit/definition.py +++ b/sqlmesh/core/audit/definition.py @@ -78,9 +78,14 @@ def query(self) -> t.Union[exp.Query, d.JinjaQuery]: @property def expressions(self) -> t.List[exp.Expression]: - if self.expressions_: - return [e.parse(self.dialect) for e in self.expressions_] - return [] + if not self.expressions_: + return [] + result = [] + for e in self.expressions_: + parsed = e.parse(self.dialect) + if not isinstance(parsed, exp.Semicolon): + result.append(parsed) + return result @property def macro_definitions(self) -> t.List[d.MacroDef]: @@ -281,8 +286,8 @@ def metadata_hash(self) -> str: self.cron_tz.key if self.cron_tz else None, ] - query = self.render_audit_query() or self.query - data.append(gen(query)) + data.append(self.query_.sql) + data.extend([e.sql for e in self.expressions_ or []]) self._metadata_hash = hash_data(data) return self._metadata_hash diff --git a/sqlmesh/core/model/common.py b/sqlmesh/core/model/common.py index 19e513c3b2..0a55f80cee 100644 --- a/sqlmesh/core/model/common.py +++ b/sqlmesh/core/model/common.py @@ -616,11 +616,6 @@ def parse_strings_with_macro_refs(value: t.Any, dialect: DialectType) -> t.Any: expression_validator: t.Callable = field_validator( - # "query", - # "expressions_", - "pre_statements_", - "post_statements_", - "on_virtual_update_", "unique_key", mode="before", check_fields=False, @@ -719,9 +714,9 @@ def _validate_parsable_sql( return field_validator( "query_", "expressions_", - # "pre_statements_", - # "post_statements_", - # "on_virtual_update_", + "pre_statements_", + "post_statements_", + "on_virtual_update_", mode="before", check_fields=False, )(_validate_parsable_sql) diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 7b091c3781..73c2e8d472 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -25,7 +25,6 @@ from sqlmesh.core.macros import MacroRegistry, macro from sqlmesh.core.model.common import ( ParsableSql, - expression_validator, make_python_env, parse_dependencies, parse_strings_with_macro_refs, @@ -152,21 +151,17 @@ class _Model(ModelMeta, frozen=True): audit_definitions: t.Dict[str, ModelAudit] = {} mapping_schema: t.Dict[str, t.Any] = {} extract_dependencies_from_query: bool = True + pre_statements_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="pre_statements") + post_statements_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="post_statements") + on_virtual_update_: t.Optional[t.List[ParsableSql]] = Field( + default=None, alias="on_virtual_update" + ) _full_depends_on: t.Optional[t.Set[str]] = None _statement_renderer_cache: t.Dict[int, ExpressionRenderer] = {} + _is_metadata_only_change_cache: t.Dict[int, bool] = {} - pre_statements_: t.Optional[t.List[exp.Expression]] = Field( - default=None, alias="pre_statements" - ) - post_statements_: t.Optional[t.List[exp.Expression]] = Field( - default=None, alias="post_statements" - ) - on_virtual_update_: t.Optional[t.List[exp.Expression]] = Field( - default=None, alias="on_virtual_update" - ) - - _expressions_validator = expression_validator + _expressions_validator = ParsableSql.validator() def __getstate__(self) -> t.Dict[t.Any, t.Any]: state = super().__getstate__() @@ -545,15 +540,15 @@ def render_audit_query( @property def pre_statements(self) -> t.List[exp.Expression]: - return self.pre_statements_ or [] + return self._get_statements("pre_statements_") @property def post_statements(self) -> t.List[exp.Expression]: - return self.post_statements_ or [] + return self._get_statements("post_statements_") @property def on_virtual_update(self) -> t.List[exp.Expression]: - return self.on_virtual_update_ or [] + return self._get_statements("on_virtual_update_") @property def macro_definitions(self) -> t.List[d.MacroDef]: @@ -564,6 +559,17 @@ def macro_definitions(self) -> t.List[d.MacroDef]: if isinstance(s, d.MacroDef) ] + def _get_statements(self, attr_name: str) -> t.List[exp.Expression]: + value = getattr(self, attr_name) + if not value: + return [] + result = [] + for v in value: + parsed = v.parse(self.dialect) + if not isinstance(parsed, exp.Semicolon): + result.append(parsed) + return result + def _render_statements( self, statements: t.Iterable[exp.Expression], @@ -1027,6 +1033,45 @@ def is_breaking_change(self, previous: Model) -> t.Optional[bool]: """ raise NotImplementedError + def is_metadata_only_change(self, other: _Node) -> bool: + if self._is_metadata_only_change_cache.get(id(other), None) is not None: + return self._is_metadata_only_change_cache[id(other)] + + is_metadata_change = True + if ( + not isinstance(other, _Model) + or self.metadata_hash == other.metadata_hash + or self._data_hash_values_no_sql != other._data_hash_values_no_sql + ): + is_metadata_change = False + else: + this_statements = [ + s + for s in [*self.pre_statements, *self.post_statements] + if not self._is_metadata_statement(s) + ] + other_statements = [ + s + for s in [*other.pre_statements, *other.post_statements] + if not other._is_metadata_statement(s) + ] + if len(this_statements) != len(other_statements): + is_metadata_change = False + else: + for this_statement, other_statement in zip(this_statements, other_statements): + this_rendered = ( + self._statement_renderer(this_statement).render() or this_statement.sql + ) + other_rendered = ( + other._statement_renderer(other_statement).render() or other_statement.sql + ) + if this_rendered != other_rendered: + is_metadata_change = False + break + + self._is_metadata_only_change_cache[id(other)] = is_metadata_change + return is_metadata_change + @property def data_hash(self) -> str: """ @@ -1041,6 +1086,19 @@ def data_hash(self) -> str: @property def _data_hash_values(self) -> t.List[str]: + return self._data_hash_values_no_sql + self._data_hash_values_sql + + @property + def _data_hash_values_sql(self) -> t.List[str]: + data = [] + + for statement in [*(self.pre_statements_ or []), *(self.post_statements_ or [])]: + data.append(statement.sql) + + return data + + @property + def _data_hash_values_no_sql(self) -> t.List[str]: data = [ str( # Exclude metadata only macro funcs [(k, v) for k, v in self.sorted_python_env if not v.is_metadata] @@ -1068,18 +1126,6 @@ def _data_hash_values(self) -> t.List[str]: data.append(key) data.append(gen(value)) - for statement in (*self.pre_statements, *self.post_statements): - statement_exprs: t.List[exp.Expression] = [] - if not isinstance(statement, d.MacroDef): - rendered = self._statement_renderer(statement).render() - if self._is_metadata_statement(statement): - continue - if rendered: - statement_exprs = rendered - else: - statement_exprs = [statement] - data.extend(gen(e) for e in statement_exprs) - return data # type: ignore def _audit_metadata_hash_values(self) -> t.List[str]: @@ -1095,13 +1141,9 @@ def _audit_metadata_hash_values(self) -> t.List[str]: metadata.append(gen(arg_value)) else: audit = self.audit_definitions[audit_name] - query = ( - self.render_audit_query(audit, **t.cast(t.Dict[str, t.Any], audit_args)) - or audit.query - ) metadata.extend( [ - gen(query), + audit.query_.sql, audit.dialect, str(audit.skip), str(audit.blocking), @@ -1172,12 +1214,9 @@ def _additional_metadata(self) -> t.List[str]: if metadata_only_macros: additional_metadata.append(str(metadata_only_macros)) - for statement in (*self.pre_statements, *self.post_statements): - if self._is_metadata_statement(statement): - additional_metadata.append(gen(statement)) - - for statement in self.on_virtual_update: - additional_metadata.append(gen(statement)) + for statements in [self.pre_statements_, self.post_statements_, self.on_virtual_update_]: + for statement in statements or []: + additional_metadata.append(statement.sql) return additional_metadata @@ -1279,9 +1318,7 @@ class SqlModel(_Model): query_: ParsableSql = Field(alias="query") source_type: t.Literal["sql"] = "sql" - _query_validator = ParsableSql.validator() _columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None - _is_metadata_only_change_cache: t.Dict[int, bool] = {} def __getstate__(self) -> t.Dict[t.Any, t.Any]: state = super().__getstate__() @@ -1513,19 +1550,16 @@ def is_metadata_only_change(self, previous: _Node) -> bool: if self._is_metadata_only_change_cache.get(id(previous), None) is not None: return self._is_metadata_only_change_cache[id(previous)] - if ( - not isinstance(previous, SqlModel) - or self.metadata_hash == previous.metadata_hash - or self._data_hash_values_no_query != previous._data_hash_values_no_query - ): - is_metadata_change = False - else: - # If the rendered queries are the same, then this is a metadata only change - this_rendered_query = self.render_query() - previous_rendered_query = previous.render_query() - is_metadata_change = ( - this_rendered_query is not None and this_rendered_query == previous_rendered_query - ) + if not super().is_metadata_only_change(previous): + return False + + if not isinstance(previous, SqlModel): + self._is_metadata_only_change_cache[id(previous)] = False + return False + + this_rendered_query = self.render_query() or self.query + previous_rendered_query = previous.render_query() or previous.query + is_metadata_change = this_rendered_query == previous_rendered_query self._is_metadata_only_change_cache[id(previous)] = is_metadata_change return is_metadata_change @@ -1549,22 +1583,22 @@ def _query_renderer(self) -> QueryRenderer: ) @property - def _data_hash_values_no_query(self) -> t.List[str]: + def _data_hash_values_no_sql(self) -> t.List[str]: return [ - *super()._data_hash_values, + *super()._data_hash_values_no_sql, *self.jinja_macros.data_hash_values, ] @property - def _data_hash_values(self) -> t.List[str]: + def _data_hash_values_sql(self) -> t.List[str]: return [ - *self._data_hash_values_no_query, - gen(self.query, comments=False), + *super()._data_hash_values_sql, + self.query_.sql, ] @property def _additional_metadata(self) -> t.List[str]: - return [*super()._additional_metadata, gen(self.query, comments=True)] + return [*super()._additional_metadata, self.query_.sql] @property def violated_rules_for_query(self) -> t.Dict[type[Rule], t.Any]: @@ -1788,8 +1822,8 @@ def _reader(self) -> CsvSeedReader: return self.seed.reader(dialect=self.dialect, settings=self.kind.csv_settings) @property - def _data_hash_values(self) -> t.List[str]: - data = super()._data_hash_values + def _data_hash_values_no_sql(self) -> t.List[str]: + data = super()._data_hash_values_no_sql for column_name, column_hash in self.column_hashes.items(): data.append(column_name) data.append(column_hash) @@ -1882,8 +1916,8 @@ def is_breaking_change(self, previous: Model) -> t.Optional[bool]: return None @property - def _data_hash_values(self) -> t.List[str]: - data = super()._data_hash_values + def _data_hash_values_no_sql(self) -> t.List[str]: + data = super()._data_hash_values_no_sql data.append(self.entrypoint) return data @@ -2287,7 +2321,6 @@ def load_sql_based_model( def create_sql_model( name: TableName, query: t.Optional[exp.Expression], - dialect: t.Optional[str] = None, **kwargs: t.Any, ) -> Model: """Creates a SQL model. @@ -2304,14 +2337,7 @@ def create_sql_model( ) assert isinstance(query, (exp.Query, d.JinjaQuery, d.MacroFunc)) - dialect = dialect or "" - return _create_model( - SqlModel, - name, - query=ParsableSql.from_parsed_expression(query, dialect, use_meta_sql=True), - dialect=dialect, - **kwargs, - ) + return _create_model(SqlModel, name, query=query, **kwargs) def create_seed_model( @@ -2525,34 +2551,26 @@ def _create_model( statements: t.List[t.Union[exp.Expression, t.Tuple[exp.Expression, bool]]] = [] - # Merge default pre_statements with model-specific pre_statements - if "pre_statements" in defaults: - kwargs["pre_statements"] = [ - exp.maybe_parse(stmt, dialect=dialect) for stmt in defaults["pre_statements"] - ] + kwargs.get("pre_statements", []) - - # Merge default post_statements with model-specific post_statements - if "post_statements" in defaults: - kwargs["post_statements"] = [ - exp.maybe_parse(stmt, dialect=dialect) for stmt in defaults["post_statements"] - ] + kwargs.get("post_statements", []) - - # Merge default on_virtual_update with model-specific on_virtual_update - if "on_virtual_update" in defaults: - kwargs["on_virtual_update"] = [ - exp.maybe_parse(stmt, dialect=dialect) for stmt in defaults["on_virtual_update"] - ] + kwargs.get("on_virtual_update", []) - - if "pre_statements" in kwargs: - statements.extend(kwargs["pre_statements"]) if "query" in kwargs: - statements.append(kwargs["query"].parse(dialect)) - if "post_statements" in kwargs: - statements.extend(kwargs["post_statements"]) + statements.append(kwargs["query"]) + kwargs["query"] = ParsableSql.from_parsed_expression( + kwargs["query"], dialect, use_meta_sql=True + ) - # Macros extracted from these statements need to be treated as metadata only - if "on_virtual_update" in kwargs: - statements.extend((stmt, True) for stmt in kwargs["on_virtual_update"]) + # Merge default statements with model-specific statements + for statement_field in ["pre_statements", "post_statements", "on_virtual_update"]: + if statement_field in defaults: + kwargs[statement_field] = [ + exp.maybe_parse(stmt, dialect=dialect) for stmt in defaults[statement_field] + ] + kwargs.get(statement_field, []) + if statement_field in kwargs: + # Macros extracted from these statements need to be treated as metadata only + is_metadata = statement_field == "on_virtual_update" + statements.extend((stmt, is_metadata) for stmt in kwargs[statement_field]) + kwargs[statement_field] = [ + ParsableSql.from_parsed_expression(stmt, dialect, use_meta_sql=True) + for stmt in kwargs[statement_field] + ] # This is done to allow variables like @gateway to be used in these properties # since rendering shifted from load time to run time. diff --git a/sqlmesh/migrations/v0093_use_raw_sql_in_fingerprint.py b/sqlmesh/migrations/v0093_use_raw_sql_in_fingerprint.py new file mode 100644 index 0000000000..53d4cb1727 --- /dev/null +++ b/sqlmesh/migrations/v0093_use_raw_sql_in_fingerprint.py @@ -0,0 +1,5 @@ +"""Use the raw SQL when computing the model fingerprint.""" + + +def migrate(state_sync, **kwargs): # type: ignore + pass diff --git a/sqlmesh/migrations/v0093_use_unrendered_query_in_fingerprint.py b/sqlmesh/migrations/v0093_use_unrendered_query_in_fingerprint.py deleted file mode 100644 index 8698d32831..0000000000 --- a/sqlmesh/migrations/v0093_use_unrendered_query_in_fingerprint.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Use the unrendered query when computing the model fingerprint.""" - - -def migrate(state_sync, **kwargs): # type: ignore - pass diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index cd36049a0d..c22e904374 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -6670,11 +6670,12 @@ def change_data_type( assert model is not None if isinstance(model, SqlModel): - data_types = model.query.find_all(DataType) + query = model.query.copy() + data_types = query.find_all(DataType) for data_type in data_types: if data_type.this == old_type: data_type.set("this", new_type) - context.upsert_model(model_name, query_=model.query_) + context.upsert_model(model_name, query_=ParsableSql(sql=query.sql(dialect=model.dialect))) elif model.columns_to_types_ is not None: for k, v in model.columns_to_types_.items(): if v.this == old_type: diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 8db1ca593a..72de783076 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -5733,7 +5733,7 @@ def test_default_catalog_sql(assert_exp_eq): The system is not designed to actually support having an engine that doesn't support default catalog to start supporting it or the reverse of that. If that did happen then bugs would occur. """ - HASH_WITH_CATALOG = "3443912775" + HASH_WITH_CATALOG = "2768215345" # Test setting default catalog doesn't change hash if it matches existing logic expressions = d.parse( @@ -8309,15 +8309,9 @@ def noop(evaluator) -> None: new_model = load_sql_based_model( expressions, path=Path("./examples/sushi/models/test_model.sql") ) - if metadata_only: - assert "noop" not in new_model._data_hash_values[0] - assert "noop" in new_model._additional_metadata[0] - assert model.data_hash == new_model.data_hash - assert model.metadata_hash != new_model.metadata_hash - else: - assert "noop" in new_model._data_hash_values[0] - assert model.data_hash != new_model.data_hash - assert model.metadata_hash == new_model.metadata_hash + assert model.metadata_hash != new_model.metadata_hash + assert model.data_hash != new_model.data_hash + assert new_model.is_metadata_only_change(model) == metadata_only @macro(metadata_only=metadata_only) # type: ignore def noop(evaluator) -> None: @@ -8337,6 +8331,7 @@ def noop(evaluator) -> None: assert "print" in updated_model._data_hash_values[0] assert new_model.data_hash != updated_model.data_hash assert new_model.metadata_hash == updated_model.metadata_hash + assert updated_model.is_metadata_only_change(new_model) == metadata_only def test_managed_kind_sql(): @@ -10732,7 +10727,7 @@ def f(): Context(paths=tmp_path, config=config) -def test_semicolon_is_not_included_in_model_state(tmp_path, assert_exp_eq): +def test_semicolon_is_metadata_only_change(tmp_path, assert_exp_eq): init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) db_connection = DuckDBConnectionConfig(database=str(tmp_path / "db.db")) @@ -10821,7 +10816,9 @@ def test_semicolon_is_not_included_in_model_state(tmp_path, assert_exp_eq): ctx.load() plan = ctx.plan(no_prompts=True, auto_apply=True) - assert not plan.context_diff.modified_snapshots + assert len(plan.context_diff.modified_snapshots) == 1 + assert len(plan.new_snapshots) == 1 + assert plan.new_snapshots[0].is_metadata def test_invalid_audit_reference(): diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py index 81742f9a67..d63b642f60 100644 --- a/tests/core/test_snapshot.py +++ b/tests/core/test_snapshot.py @@ -915,8 +915,8 @@ def test_fingerprint(model: Model, parent_model: Model): fingerprint = fingerprint_from_node(model, nodes={}) original_fingerprint = SnapshotFingerprint( - data_hash="1698409777", - metadata_hash="3575333731", + data_hash="2406542604", + metadata_hash="3341445192", ) assert fingerprint == original_fingerprint @@ -943,7 +943,7 @@ def test_fingerprint(model: Model, parent_model: Model): model = SqlModel(**{**model.dict(), "query": parse_one("select 1, ds -- annotation")}) fingerprint = fingerprint_from_node(model, nodes={}) assert new_fingerprint != fingerprint - assert new_fingerprint.data_hash == fingerprint.data_hash + assert new_fingerprint.data_hash != fingerprint.data_hash assert new_fingerprint.metadata_hash != fingerprint.metadata_hash model = SqlModel( @@ -953,14 +953,14 @@ def test_fingerprint(model: Model, parent_model: Model): assert new_fingerprint != fingerprint assert new_fingerprint.data_hash != fingerprint.data_hash assert new_fingerprint.metadata_hash != fingerprint.metadata_hash - assert fingerprint.metadata_hash == original_fingerprint.metadata_hash + assert fingerprint.metadata_hash != original_fingerprint.metadata_hash model = SqlModel(**{**original_model.dict(), "post_statements": [parse_one("DROP TABLE test")]}) fingerprint = fingerprint_from_node(model, nodes={}) assert new_fingerprint != fingerprint assert new_fingerprint.data_hash != fingerprint.data_hash assert new_fingerprint.metadata_hash != fingerprint.metadata_hash - assert fingerprint.metadata_hash == original_fingerprint.metadata_hash + assert fingerprint.metadata_hash != original_fingerprint.metadata_hash def test_fingerprint_seed_model(): @@ -1015,8 +1015,8 @@ def test_fingerprint_jinja_macros(model: Model): } ) original_fingerprint = SnapshotFingerprint( - data_hash="343517722", - metadata_hash="3575333731", + data_hash="93332825", + metadata_hash="3341445192", ) fingerprint = fingerprint_from_node(model, nodes={}) diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index 60908ed7c4..9b1e81c0f4 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -252,6 +252,9 @@ def increment_stage_counter(evaluator) -> None: snapshot = make_snapshot(model) snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + snapshot.model.render_pre_statements() + assert f"RuntimeStage value: {RuntimeStage.LOADING.value}" in capsys.readouterr().out evaluator.create([snapshot], {}) diff --git a/tests/core/test_table_diff.py b/tests/core/test_table_diff.py index 73fd37a2f7..839cbb415e 100644 --- a/tests/core/test_table_diff.py +++ b/tests/core/test_table_diff.py @@ -10,6 +10,7 @@ from sqlmesh.core.context import Context from sqlmesh.core.config import AutoCategorizationMode, CategorizerConfig, DuckDBConnectionConfig from sqlmesh.core.model import SqlModel, load_sql_based_model +from sqlmesh.core.model.common import ParsableSql from sqlmesh.core.table_diff import TableDiff, SchemaDiff import numpy as np # noqa: TID253 from sqlmesh.utils.errors import SQLMeshError @@ -48,8 +49,14 @@ def capture_console_output(method_name: str, **kwargs) -> str: def test_data_diff(sushi_context_fixed_date, capsys, caplog): model = sushi_context_fixed_date.models['"memory"."sushi"."customer_revenue_by_day"'] - model.query.select(exp.cast("'1'", "VARCHAR").as_("modified_col"), "1 AS y", copy=False) - sushi_context_fixed_date.upsert_model(model) + sushi_context_fixed_date.upsert_model( + model, + query_=ParsableSql( + sql=model.query.select(exp.cast("'1'", "VARCHAR").as_("modified_col"), "1 AS y").sql( + model.dialect + ) + ), + ) sushi_context_fixed_date.plan( "source_dev", diff --git a/tests/integrations/github/cicd/test_integration.py b/tests/integrations/github/cicd/test_integration.py index f78419889d..ce357f6d36 100644 --- a/tests/integrations/github/cicd/test_integration.py +++ b/tests/integrations/github/cicd/test_integration.py @@ -16,6 +16,7 @@ from sqlmesh.core.config import CategorizerConfig, Config, ModelDefaultsConfig, LinterConfig from sqlmesh.core.engine_adapter.shared import DataObject from sqlmesh.core.user import User, UserRole +from sqlmesh.core.model.common import ParsableSql from sqlmesh.integrations.github.cicd import command from sqlmesh.integrations.github.cicd.config import GithubCICDBotConfig, MergeMethod from sqlmesh.integrations.github.cicd.controller import ( @@ -249,8 +250,10 @@ def test_merge_pr_has_non_breaking_change( ] # Make a non-breaking change model = controller._context.get_model("sushi.waiter_revenue_by_day").copy() - model.query.select(exp.alias_("1", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + query_=ParsableSql(sql=model.query.select(exp.alias_("1", "new_col")).sql(model.dialect)), + ) github_output_file = tmp_path / "github_output.txt" @@ -458,8 +461,10 @@ def test_merge_pr_has_non_breaking_change_diff_start( ] # Make a non-breaking change model = controller._context.get_model("sushi.waiter_revenue_by_day").copy() - model.query.select(exp.alias_("1", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + query_=ParsableSql(sql=model.query.select(exp.alias_("1", "new_col")).sql(model.dialect)), + ) github_output_file = tmp_path / "github_output.txt" @@ -666,8 +671,10 @@ def test_merge_pr_has_non_breaking_change_no_categorization( ] # Make a non-breaking change model = controller._context.get_model("sushi.waiter_revenue_by_day").copy() - model.query.select(exp.alias_("1", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + query_=ParsableSql(sql=model.query.select(exp.alias_("1", "new_col")).sql(model.dialect)), + ) github_output_file = tmp_path / "github_output.txt" @@ -983,8 +990,10 @@ def test_no_merge_since_no_deploy_signal( ] # Make a non-breaking change model = controller._context.get_model("sushi.waiter_revenue_by_day").copy() - model.query.select(exp.alias_("1", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + query_=ParsableSql(sql=model.query.select(exp.alias_("1", "new_col")).sql(model.dialect)), + ) github_output_file = tmp_path / "github_output.txt" @@ -1183,8 +1192,10 @@ def test_no_merge_since_no_deploy_signal_no_approvers_defined( controller._context.users = [User(username="test", github_username="test_github", roles=[])] # Make a non-breaking change model = controller._context.get_model("sushi.waiter_revenue_by_day").copy() - model.query.select(exp.alias_("1", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + query_=ParsableSql(sql=model.query.select(exp.alias_("1", "new_col")).sql(model.dialect)), + ) github_output_file = tmp_path / "github_output.txt" @@ -1357,8 +1368,10 @@ def test_deploy_comment_pre_categorized( controller._context.users = [User(username="test", github_username="test_github", roles=[])] # Make a non-breaking change model = controller._context.get_model("sushi.waiter_revenue_by_day").copy() - model.query.select(exp.alias_("1", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + query_=ParsableSql(sql=model.query.select(exp.alias_("1", "new_col")).sql(model.dialect)), + ) # Manually categorize the change as non-breaking and don't backfill anything controller._context.plan( @@ -1557,8 +1570,12 @@ def test_error_msg_when_applying_plan_with_bug( ] # Make an error by adding a column that doesn't exist model = controller._context.get_model("sushi.waiter_revenue_by_day").copy() - model.query.select(exp.alias_("non_existing_col", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + query_=ParsableSql( + sql=model.query.select(exp.alias_("non_existing_col", "new_col")).sql(model.dialect) + ), + ) github_output_file = tmp_path / "github_output.txt" @@ -1716,8 +1733,10 @@ def test_overlapping_changes_models( # These changes have shared children and this ensures we don't repeat the children in the output # Make a non-breaking change model = controller._context.get_model("sushi.customers").copy() - model.query.select(exp.alias_("1", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + query_=ParsableSql(sql=model.query.select(exp.alias_("1", "new_col")).sql(model.dialect)), + ) # Make a breaking change model = controller._context.get_model("sushi.waiter_names").copy() @@ -2283,8 +2302,10 @@ def test_has_required_approval_but_not_base_branch( ] # Make a non-breaking change model = controller._context.get_model("sushi.waiter_revenue_by_day").copy() - model.query.select(exp.alias_("1", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + query_=ParsableSql(sql=model.query.select(exp.alias_("1", "new_col")).sql(model.dialect)), + ) github_output_file = tmp_path / "github_output.txt" From 7156542c643314d784d52ce677743176ac64e94d Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Fri, 29 Aug 2025 14:18:07 -0700 Subject: [PATCH 6/8] cosmetic --- sqlmesh/core/model/definition.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 73c2e8d472..fedbe4a9f9 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -540,15 +540,15 @@ def render_audit_query( @property def pre_statements(self) -> t.List[exp.Expression]: - return self._get_statements("pre_statements_") + return self._get_parsed_statements("pre_statements_") @property def post_statements(self) -> t.List[exp.Expression]: - return self._get_statements("post_statements_") + return self._get_parsed_statements("post_statements_") @property def on_virtual_update(self) -> t.List[exp.Expression]: - return self._get_statements("on_virtual_update_") + return self._get_parsed_statements("on_virtual_update_") @property def macro_definitions(self) -> t.List[d.MacroDef]: @@ -559,7 +559,7 @@ def macro_definitions(self) -> t.List[d.MacroDef]: if isinstance(s, d.MacroDef) ] - def _get_statements(self, attr_name: str) -> t.List[exp.Expression]: + def _get_parsed_statements(self, attr_name: str) -> t.List[exp.Expression]: value = getattr(self, attr_name) if not value: return [] @@ -1060,10 +1060,10 @@ def is_metadata_only_change(self, other: _Node) -> bool: else: for this_statement, other_statement in zip(this_statements, other_statements): this_rendered = ( - self._statement_renderer(this_statement).render() or this_statement.sql + self._statement_renderer(this_statement).render() or this_statement ) other_rendered = ( - other._statement_renderer(other_statement).render() or other_statement.sql + other._statement_renderer(other_statement).render() or other_statement ) if this_rendered != other_rendered: is_metadata_change = False @@ -1092,8 +1092,9 @@ def _data_hash_values(self) -> t.List[str]: def _data_hash_values_sql(self) -> t.List[str]: data = [] - for statement in [*(self.pre_statements_ or []), *(self.post_statements_ or [])]: - data.append(statement.sql) + for statements in [self.pre_statements_, self.post_statements_]: + for statement in statements or []: + data.append(statement.sql) return data From 988379c5fb39b2e841898cfa8d09585bad1e0739 Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Fri, 29 Aug 2025 14:53:10 -0700 Subject: [PATCH 7/8] test original sql --- sqlmesh/core/model/definition.py | 6 +++-- tests/core/test_model.py | 43 ++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index fedbe4a9f9..f3ffcde05a 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -2297,6 +2297,7 @@ def load_sql_based_model( variables=variables, inline_audits=inline_audits, blueprint_variables=blueprint_variables, + use_original_sql=True, **meta_fields, ) @@ -2519,6 +2520,7 @@ def _create_model( signal_definitions: t.Optional[SignalRegistry] = None, variables: t.Optional[t.Dict[str, t.Any]] = None, blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, + use_original_sql: bool = False, **kwargs: t.Any, ) -> Model: validate_extra_and_required_fields( @@ -2555,7 +2557,7 @@ def _create_model( if "query" in kwargs: statements.append(kwargs["query"]) kwargs["query"] = ParsableSql.from_parsed_expression( - kwargs["query"], dialect, use_meta_sql=True + kwargs["query"], dialect, use_meta_sql=use_original_sql ) # Merge default statements with model-specific statements @@ -2569,7 +2571,7 @@ def _create_model( is_metadata = statement_field == "on_virtual_update" statements.extend((stmt, is_metadata) for stmt in kwargs[statement_field]) kwargs[statement_field] = [ - ParsableSql.from_parsed_expression(stmt, dialect, use_meta_sql=True) + ParsableSql.from_parsed_expression(stmt, dialect, use_meta_sql=use_original_sql) for stmt in kwargs[statement_field] ] diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 72de783076..be1df5f2d6 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -11471,3 +11471,46 @@ def test_raw_jinja_raw_tag(): model = load_sql_based_model(expressions) assert model.render_query().sql() == "SELECT '{{ foo }}' AS \"col\"" + + +def test_use_original_sql(): + expressions = d.parse( + """ + MODEL (name test); + + CREATE TABLE pre ( + a INT + ); + + SELECT + 1, + 2; + + CREATE TABLE post ( + b INT + ); + """ + ) + + model = load_sql_based_model(expressions) + assert model.query_.sql == "SELECT\n 1,\n 2" + assert model.pre_statements_[0].sql == "CREATE TABLE pre (\n a INT\n )" + assert model.post_statements_[0].sql == "CREATE TABLE post (\n b INT\n );" + + # Now manually create the model and make sure that the original SQL is not used + model_query = d.parse_one("SELECT 1 AS one") + assert model_query.meta["sql"] == "SELECT 1 AS one" + model_query = model_query.select("2 AS two") + + pre_statements = [d.parse_one("CREATE TABLE pre (\n a INT\n )")] + post_statements = [d.parse_one("CREATE TABLE post (\n b INT\n );")] + + model = create_sql_model( + "test", + model_query, + pre_statements=pre_statements, + post_statements=post_statements, + ) + assert model.query_.sql == "SELECT 1 AS one, 2 AS two" + assert model.pre_statements_[0].sql == "CREATE TABLE pre (a INT)" + assert model.post_statements_[0].sql == "CREATE TABLE post (b INT)" From 26226d92a9ede89999ea78e43a22bd575da72947 Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Tue, 2 Sep 2025 09:18:12 -0700 Subject: [PATCH 8/8] extend the audit load test --- tests/core/test_audit.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/core/test_audit.py b/tests/core/test_audit.py index 81335e5f1a..ed67975e9e 100644 --- a/tests/core/test_audit.py +++ b/tests/core/test_audit.py @@ -80,6 +80,8 @@ def test_load(assert_exp_eq): col IS NULL """, ) + assert audit.query_._parsed is not None + assert audit.query_._parsed_dialect == "spark" def test_load_standalone(assert_exp_eq): @@ -121,6 +123,8 @@ def test_load_standalone(assert_exp_eq): col IS NULL """, ) + assert audit.query_._parsed is not None + assert audit.query_._parsed_dialect == "spark" def test_load_standalone_default_catalog(assert_exp_eq):