diff --git a/sqlmesh/core/audit/definition.py b/sqlmesh/core/audit/definition.py index 210ae9da1b..561ee539f6 100644 --- a/sqlmesh/core/audit/definition.py +++ b/sqlmesh/core/audit/definition.py @@ -15,11 +15,10 @@ bool_validator, default_catalog_validator, depends_on_validator, - expression_validator, sort_python_env, sorted_python_env_payloads, ) -from sqlmesh.core.model.common import make_python_env, single_value_or_tuple +from sqlmesh.core.model.common import make_python_env, single_value_or_tuple, ParsableSql from sqlmesh.core.node import _Node from sqlmesh.core.renderer import QueryRenderer from sqlmesh.utils.date import TimeLike @@ -67,15 +66,26 @@ class AuditMixin(AuditCommonMetaMixin): jinja_macros: A registry of jinja macros to use when rendering the audit query. """ - query: t.Union[exp.Query, d.JinjaQuery] + query_: ParsableSql defaults: t.Dict[str, exp.Expression] - expressions_: t.Optional[t.List[exp.Expression]] + expressions_: t.Optional[t.List[ParsableSql]] jinja_macros: JinjaMacroRegistry formatting: t.Optional[bool] + @property + def query(self) -> t.Union[exp.Query, d.JinjaQuery]: + return t.cast(t.Union[exp.Query, d.JinjaQuery], self.query_.parse(self.dialect)) + @property def expressions(self) -> t.List[exp.Expression]: - return self.expressions_ or [] + if not self.expressions_: + return [] + result = [] + for e in self.expressions_: + parsed = e.parse(self.dialect) + if not isinstance(parsed, exp.Semicolon): + result.append(parsed) + return result @property def macro_definitions(self) -> t.List[d.MacroDef]: @@ -122,16 +132,16 @@ class ModelAudit(PydanticModel, AuditMixin, frozen=True): skip: bool = False blocking: bool = True standalone: t.Literal[False] = False - query: t.Union[exp.Query, d.JinjaQuery] + query_: ParsableSql = Field(alias="query") defaults: t.Dict[str, exp.Expression] = {} - expressions_: t.Optional[t.List[exp.Expression]] = Field(default=None, alias="expressions") + expressions_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="expressions") jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry() formatting: t.Optional[bool] = Field(default=None, exclude=True) _path: t.Optional[Path] = None # Validators - _query_validator = expression_validator + _query_validator = ParsableSql.validator() _bool_validator = bool_validator _string_validator = audit_string_validator _map_validator = audit_map_validator @@ -153,9 +163,9 @@ class StandaloneAudit(_Node, AuditMixin): skip: bool = False blocking: bool = False standalone: t.Literal[True] = True - query: t.Union[exp.Query, d.JinjaQuery] + query_: ParsableSql = Field(alias="query") defaults: t.Dict[str, exp.Expression] = {} - expressions_: t.Optional[t.List[exp.Expression]] = Field(default=None, alias="expressions") + expressions_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="expressions") jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry() default_catalog: t.Optional[str] = None depends_on_: t.Optional[t.Set[str]] = Field(default=None, alias="depends_on") @@ -165,7 +175,7 @@ class StandaloneAudit(_Node, AuditMixin): source_type: t.Literal["audit"] = "audit" # Validators - _query_validator = expression_validator + _query_validator = ParsableSql.validator() _bool_validator = bool_validator _string_validator = audit_string_validator _map_validator = audit_map_validator @@ -276,8 +286,8 @@ def metadata_hash(self) -> str: self.cron_tz.key if self.cron_tz else None, ] - query = self.render_audit_query() or self.query - data.append(gen(query)) + data.append(self.query_.sql) + data.extend([e.sql for e in self.expressions_ or []]) self._metadata_hash = hash_data(data) return self._metadata_hash @@ -461,11 +471,17 @@ def load_audit( if project is not None: extra_kwargs["project"] = project - dialect = meta_fields.pop("dialect", dialect) + dialect = meta_fields.pop("dialect", dialect) or "" + + parsable_query = ParsableSql.from_parsed_expression(query, dialect, use_meta_sql=True) + parsable_statements = [ + ParsableSql.from_parsed_expression(s, dialect, use_meta_sql=True) for s in statements + ] + try: audit = audit_class( - query=query, - expressions=statements, + query=parsable_query, + expressions=parsable_statements, dialect=dialect, **extra_kwargs, **meta_fields, diff --git a/sqlmesh/core/context_diff.py b/sqlmesh/core/context_diff.py index 12da39f50f..07d13b1c2f 100644 --- a/sqlmesh/core/context_diff.py +++ b/sqlmesh/core/context_diff.py @@ -435,7 +435,7 @@ def directly_modified(self, name: str) -> bool: return False current, previous = self.modified_snapshots[name] - return current.fingerprint.data_hash != previous.fingerprint.data_hash + return current.is_directly_modified(previous) def indirectly_modified(self, name: str) -> bool: """Returns whether or not a node was indirectly modified in this context. @@ -451,10 +451,7 @@ def indirectly_modified(self, name: str) -> bool: return False current, previous = self.modified_snapshots[name] - return ( - current.fingerprint.data_hash == previous.fingerprint.data_hash - and current.fingerprint.parent_data_hash != previous.fingerprint.parent_data_hash - ) + return current.is_indirectly_modified(previous) def metadata_updated(self, name: str) -> bool: """Returns whether or not the given node's metadata has been updated. @@ -470,7 +467,7 @@ def metadata_updated(self, name: str) -> bool: return False current, previous = self.modified_snapshots[name] - return current.fingerprint.metadata_hash != previous.fingerprint.metadata_hash + return current.is_metadata_updated(previous) def text_diff(self, name: str) -> str: """Finds the difference of a node between the current and remote environment. diff --git a/sqlmesh/core/model/common.py b/sqlmesh/core/model/common.py index 9a68ec18c0..0a55f80cee 100644 --- a/sqlmesh/core/model/common.py +++ b/sqlmesh/core/model/common.py @@ -21,7 +21,7 @@ prepare_env, serialize_env, ) -from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator +from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator, get_dialect if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType @@ -616,11 +616,6 @@ def parse_strings_with_macro_refs(value: t.Any, dialect: DialectType) -> t.Any: expression_validator: t.Callable = field_validator( - "query", - "expressions_", - "pre_statements_", - "post_statements_", - "on_virtual_update_", "unique_key", mode="before", check_fields=False, @@ -663,3 +658,65 @@ def parse_strings_with_macro_refs(value: t.Any, dialect: DialectType) -> t.Any: mode="before", check_fields=False, )(depends_on) + + +class ParsableSql(PydanticModel): + sql: str + + _parsed: t.Optional[exp.Expression] = None + _parsed_dialect: t.Optional[str] = None + + def parse(self, dialect: str) -> exp.Expression: + if self._parsed is None or self._parsed_dialect != dialect: + self._parsed = d.parse_one(self.sql, dialect=dialect) + self._parsed_dialect = dialect + return self._parsed + + @classmethod + def from_parsed_expression( + cls, parsed_expression: exp.Expression, dialect: str, use_meta_sql: bool = False + ) -> ParsableSql: + sql = ( + parsed_expression.meta.get("sql") or parsed_expression.sql(dialect=dialect) + if use_meta_sql + else parsed_expression.sql(dialect=dialect) + ) + result = cls(sql=sql) + result._parsed = parsed_expression + result._parsed_dialect = dialect + return result + + @classmethod + def validator(cls) -> classmethod: + def _validate_parsable_sql( + v: t.Any, info: ValidationInfo + ) -> t.Optional[t.Union[ParsableSql, t.List[ParsableSql]]]: + if v is None: + return v + if isinstance(v, str): + return ParsableSql(sql=v) + if isinstance(v, exp.Expression): + return ParsableSql.from_parsed_expression( + v, get_dialect(info.data), use_meta_sql=False + ) + if isinstance(v, list): + dialect = get_dialect(info.data) + return [ + ParsableSql(sql=s) + if isinstance(s, str) + else ParsableSql.from_parsed_expression(s, dialect, use_meta_sql=False) + if isinstance(s, exp.Expression) + else ParsableSql.parse_obj(s) + for s in v + ] + return ParsableSql.parse_obj(v) + + return field_validator( + "query_", + "expressions_", + "pre_statements_", + "post_statements_", + "on_virtual_update_", + mode="before", + check_fields=False, + )(_validate_parsable_sql) diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index dba8eedc31..f3ffcde05a 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -24,7 +24,7 @@ from sqlmesh.core.node import IntervalUnit from sqlmesh.core.macros import MacroRegistry, macro from sqlmesh.core.model.common import ( - expression_validator, + ParsableSql, make_python_env, parse_dependencies, parse_strings_with_macro_refs, @@ -62,6 +62,7 @@ if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType + from sqlmesh.core.node import _Node from sqlmesh.core._typing import Self, TableName, SessionProperties from sqlmesh.core.context import ExecutionContext from sqlmesh.core.engine_adapter import EngineAdapter @@ -150,21 +151,17 @@ class _Model(ModelMeta, frozen=True): audit_definitions: t.Dict[str, ModelAudit] = {} mapping_schema: t.Dict[str, t.Any] = {} extract_dependencies_from_query: bool = True + pre_statements_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="pre_statements") + post_statements_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="post_statements") + on_virtual_update_: t.Optional[t.List[ParsableSql]] = Field( + default=None, alias="on_virtual_update" + ) _full_depends_on: t.Optional[t.Set[str]] = None _statement_renderer_cache: t.Dict[int, ExpressionRenderer] = {} + _is_metadata_only_change_cache: t.Dict[int, bool] = {} - pre_statements_: t.Optional[t.List[exp.Expression]] = Field( - default=None, alias="pre_statements" - ) - post_statements_: t.Optional[t.List[exp.Expression]] = Field( - default=None, alias="post_statements" - ) - on_virtual_update_: t.Optional[t.List[exp.Expression]] = Field( - default=None, alias="on_virtual_update" - ) - - _expressions_validator = expression_validator + _expressions_validator = ParsableSql.validator() def __getstate__(self) -> t.Dict[t.Any, t.Any]: state = super().__getstate__() @@ -543,15 +540,15 @@ def render_audit_query( @property def pre_statements(self) -> t.List[exp.Expression]: - return self.pre_statements_ or [] + return self._get_parsed_statements("pre_statements_") @property def post_statements(self) -> t.List[exp.Expression]: - return self.post_statements_ or [] + return self._get_parsed_statements("post_statements_") @property def on_virtual_update(self) -> t.List[exp.Expression]: - return self.on_virtual_update_ or [] + return self._get_parsed_statements("on_virtual_update_") @property def macro_definitions(self) -> t.List[d.MacroDef]: @@ -562,6 +559,17 @@ def macro_definitions(self) -> t.List[d.MacroDef]: if isinstance(s, d.MacroDef) ] + def _get_parsed_statements(self, attr_name: str) -> t.List[exp.Expression]: + value = getattr(self, attr_name) + if not value: + return [] + result = [] + for v in value: + parsed = v.parse(self.dialect) + if not isinstance(parsed, exp.Semicolon): + result.append(parsed) + return result + def _render_statements( self, statements: t.Iterable[exp.Expression], @@ -1025,6 +1033,45 @@ def is_breaking_change(self, previous: Model) -> t.Optional[bool]: """ raise NotImplementedError + def is_metadata_only_change(self, other: _Node) -> bool: + if self._is_metadata_only_change_cache.get(id(other), None) is not None: + return self._is_metadata_only_change_cache[id(other)] + + is_metadata_change = True + if ( + not isinstance(other, _Model) + or self.metadata_hash == other.metadata_hash + or self._data_hash_values_no_sql != other._data_hash_values_no_sql + ): + is_metadata_change = False + else: + this_statements = [ + s + for s in [*self.pre_statements, *self.post_statements] + if not self._is_metadata_statement(s) + ] + other_statements = [ + s + for s in [*other.pre_statements, *other.post_statements] + if not other._is_metadata_statement(s) + ] + if len(this_statements) != len(other_statements): + is_metadata_change = False + else: + for this_statement, other_statement in zip(this_statements, other_statements): + this_rendered = ( + self._statement_renderer(this_statement).render() or this_statement + ) + other_rendered = ( + other._statement_renderer(other_statement).render() or other_statement + ) + if this_rendered != other_rendered: + is_metadata_change = False + break + + self._is_metadata_only_change_cache[id(other)] = is_metadata_change + return is_metadata_change + @property def data_hash(self) -> str: """ @@ -1039,6 +1086,20 @@ def data_hash(self) -> str: @property def _data_hash_values(self) -> t.List[str]: + return self._data_hash_values_no_sql + self._data_hash_values_sql + + @property + def _data_hash_values_sql(self) -> t.List[str]: + data = [] + + for statements in [self.pre_statements_, self.post_statements_]: + for statement in statements or []: + data.append(statement.sql) + + return data + + @property + def _data_hash_values_no_sql(self) -> t.List[str]: data = [ str( # Exclude metadata only macro funcs [(k, v) for k, v in self.sorted_python_env if not v.is_metadata] @@ -1066,18 +1127,6 @@ def _data_hash_values(self) -> t.List[str]: data.append(key) data.append(gen(value)) - for statement in (*self.pre_statements, *self.post_statements): - statement_exprs: t.List[exp.Expression] = [] - if not isinstance(statement, d.MacroDef): - rendered = self._statement_renderer(statement).render() - if self._is_metadata_statement(statement): - continue - if rendered: - statement_exprs = rendered - else: - statement_exprs = [statement] - data.extend(gen(e) for e in statement_exprs) - return data # type: ignore def _audit_metadata_hash_values(self) -> t.List[str]: @@ -1093,13 +1142,9 @@ def _audit_metadata_hash_values(self) -> t.List[str]: metadata.append(gen(arg_value)) else: audit = self.audit_definitions[audit_name] - query = ( - self.render_audit_query(audit, **t.cast(t.Dict[str, t.Any], audit_args)) - or audit.query - ) metadata.extend( [ - gen(query), + audit.query_.sql, audit.dialect, str(audit.skip), str(audit.blocking), @@ -1170,12 +1215,9 @@ def _additional_metadata(self) -> t.List[str]: if metadata_only_macros: additional_metadata.append(str(metadata_only_macros)) - for statement in (*self.pre_statements, *self.post_statements): - if self._is_metadata_statement(statement): - additional_metadata.append(gen(statement)) - - for statement in self.on_virtual_update: - additional_metadata.append(gen(statement)) + for statements in [self.pre_statements_, self.post_statements_, self.on_virtual_update_]: + for statement in statements or []: + additional_metadata.append(statement.sql) return additional_metadata @@ -1274,7 +1316,7 @@ class SqlModel(_Model): on_virtual_update: The list of SQL statements to be executed after the virtual update. """ - query: t.Union[exp.Query, d.JinjaQuery, d.MacroFunc] + query_: ParsableSql = Field(alias="query") source_type: t.Literal["sql"] = "sql" _columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None @@ -1298,6 +1340,11 @@ def copy(self, **kwargs: t.Any) -> Self: model._full_depends_on = None return model + @property + def query(self) -> t.Union[exp.Query, d.JinjaQuery, d.MacroFunc]: + parsed_query = self.query_.parse(self.dialect) + return t.cast(t.Union[exp.Query, d.JinjaQuery, d.MacroFunc], parsed_query) + def render_query( self, *, @@ -1500,6 +1547,24 @@ def is_breaking_change(self, previous: Model) -> t.Optional[bool]: return False + def is_metadata_only_change(self, previous: _Node) -> bool: + if self._is_metadata_only_change_cache.get(id(previous), None) is not None: + return self._is_metadata_only_change_cache[id(previous)] + + if not super().is_metadata_only_change(previous): + return False + + if not isinstance(previous, SqlModel): + self._is_metadata_only_change_cache[id(previous)] = False + return False + + this_rendered_query = self.render_query() or self.query + previous_rendered_query = previous.render_query() or previous.query + is_metadata_change = this_rendered_query == previous_rendered_query + + self._is_metadata_only_change_cache[id(previous)] = is_metadata_change + return is_metadata_change + @cached_property def _query_renderer(self) -> QueryRenderer: no_quote_identifiers = self.kind.is_view and self.dialect in ("trino", "spark") @@ -1519,17 +1584,22 @@ def _query_renderer(self) -> QueryRenderer: ) @property - def _data_hash_values(self) -> t.List[str]: - data = super()._data_hash_values + def _data_hash_values_no_sql(self) -> t.List[str]: + return [ + *super()._data_hash_values_no_sql, + *self.jinja_macros.data_hash_values, + ] - query = self.render_query() or self.query - data.append(gen(query)) - data.extend(self.jinja_macros.data_hash_values) - return data + @property + def _data_hash_values_sql(self) -> t.List[str]: + return [ + *super()._data_hash_values_sql, + self.query_.sql, + ] @property def _additional_metadata(self) -> t.List[str]: - return [*super()._additional_metadata, gen(self.query)] + return [*super()._additional_metadata, self.query_.sql] @property def violated_rules_for_query(self) -> t.Dict[type[Rule], t.Any]: @@ -1753,8 +1823,8 @@ def _reader(self) -> CsvSeedReader: return self.seed.reader(dialect=self.dialect, settings=self.kind.csv_settings) @property - def _data_hash_values(self) -> t.List[str]: - data = super()._data_hash_values + def _data_hash_values_no_sql(self) -> t.List[str]: + data = super()._data_hash_values_no_sql for column_name, column_hash in self.column_hashes.items(): data.append(column_name) data.append(column_hash) @@ -1847,8 +1917,8 @@ def is_breaking_change(self, previous: Model) -> t.Optional[bool]: return None @property - def _data_hash_values(self) -> t.List[str]: - data = super()._data_hash_values + def _data_hash_values_no_sql(self) -> t.List[str]: + data = super()._data_hash_values_no_sql data.append(self.entrypoint) return data @@ -2227,6 +2297,7 @@ def load_sql_based_model( variables=variables, inline_audits=inline_audits, blueprint_variables=blueprint_variables, + use_original_sql=True, **meta_fields, ) @@ -2449,6 +2520,7 @@ def _create_model( signal_definitions: t.Optional[SignalRegistry] = None, variables: t.Optional[t.Dict[str, t.Any]] = None, blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, + use_original_sql: bool = False, **kwargs: t.Any, ) -> Model: validate_extra_and_required_fields( @@ -2482,34 +2554,26 @@ def _create_model( statements: t.List[t.Union[exp.Expression, t.Tuple[exp.Expression, bool]]] = [] - # Merge default pre_statements with model-specific pre_statements - if "pre_statements" in defaults: - kwargs["pre_statements"] = [ - exp.maybe_parse(stmt, dialect=dialect) for stmt in defaults["pre_statements"] - ] + kwargs.get("pre_statements", []) - - # Merge default post_statements with model-specific post_statements - if "post_statements" in defaults: - kwargs["post_statements"] = [ - exp.maybe_parse(stmt, dialect=dialect) for stmt in defaults["post_statements"] - ] + kwargs.get("post_statements", []) - - # Merge default on_virtual_update with model-specific on_virtual_update - if "on_virtual_update" in defaults: - kwargs["on_virtual_update"] = [ - exp.maybe_parse(stmt, dialect=dialect) for stmt in defaults["on_virtual_update"] - ] + kwargs.get("on_virtual_update", []) - - if "pre_statements" in kwargs: - statements.extend(kwargs["pre_statements"]) if "query" in kwargs: statements.append(kwargs["query"]) - if "post_statements" in kwargs: - statements.extend(kwargs["post_statements"]) + kwargs["query"] = ParsableSql.from_parsed_expression( + kwargs["query"], dialect, use_meta_sql=use_original_sql + ) - # Macros extracted from these statements need to be treated as metadata only - if "on_virtual_update" in kwargs: - statements.extend((stmt, True) for stmt in kwargs["on_virtual_update"]) + # Merge default statements with model-specific statements + for statement_field in ["pre_statements", "post_statements", "on_virtual_update"]: + if statement_field in defaults: + kwargs[statement_field] = [ + exp.maybe_parse(stmt, dialect=dialect) for stmt in defaults[statement_field] + ] + kwargs.get(statement_field, []) + if statement_field in kwargs: + # Macros extracted from these statements need to be treated as metadata only + is_metadata = statement_field == "on_virtual_update" + statements.extend((stmt, is_metadata) for stmt in kwargs[statement_field]) + kwargs[statement_field] = [ + ParsableSql.from_parsed_expression(stmt, dialect, use_meta_sql=use_original_sql) + for stmt in kwargs[statement_field] + ] # This is done to allow variables like @gateway to be used in these properties # since rendering shifted from load time to run time. diff --git a/sqlmesh/core/node.py b/sqlmesh/core/node.py index 4f0a66dc2e..ea2264f7fa 100644 --- a/sqlmesh/core/node.py +++ b/sqlmesh/core/node.py @@ -307,16 +307,6 @@ def batch_concurrency(self) -> t.Optional[int]: """The maximal number of batches that can run concurrently for a backfill.""" return None - @property - def data_hash(self) -> str: - """ - Computes the data hash for the node. - - Returns: - The data hash for the node. - """ - raise NotImplementedError - @property def interval_unit(self) -> IntervalUnit: """Returns the interval unit using which data intervals are computed for this node.""" @@ -332,6 +322,16 @@ def depends_on(self) -> t.Set[str]: def fqn(self) -> str: return self.name + @property + def data_hash(self) -> str: + """ + Computes the data hash for the node. + + Returns: + The data hash for the node. + """ + raise NotImplementedError + @property def metadata_hash(self) -> str: """ @@ -342,6 +342,30 @@ def metadata_hash(self) -> str: """ raise NotImplementedError + def is_metadata_only_change(self, previous: _Node) -> bool: + """Determines if this node is a metadata only change in relation to the `previous` node. + + Args: + previous: The previous node to compare against. + + Returns: + True if this node is a metadata only change, False otherwise. + """ + return self.data_hash == previous.data_hash and self.metadata_hash != previous.metadata_hash + + def is_data_change(self, previous: _Node) -> bool: + """Determines if this node is a data change in relation to the `previous` node. + + Args: + previous: The previous node to compare against. + + Returns: + True if this node is a data change, False otherwise. + """ + return ( + self.data_hash != previous.data_hash or self.metadata_hash != previous.metadata_hash + ) and not self.is_metadata_only_change(previous) + def croniter(self, value: TimeLike) -> CroniterCache: if self._croniter is None: self._croniter = CroniterCache(self.cron, value, tz=self.cron_tz) diff --git a/sqlmesh/core/snapshot/categorizer.py b/sqlmesh/core/snapshot/categorizer.py index 88a1ef37ab..78ea7466ed 100644 --- a/sqlmesh/core/snapshot/categorizer.py +++ b/sqlmesh/core/snapshot/categorizer.py @@ -47,11 +47,12 @@ def categorize_change( if type(new_model) != type(old_model): return default_category - if new.fingerprint.data_hash == old.fingerprint.data_hash: - if new.fingerprint.metadata_hash == old.fingerprint.metadata_hash: - raise SQLMeshError( - f"{new} is unmodified or indirectly modified and should not be categorized" - ) + if new.fingerprint == old.fingerprint: + raise SQLMeshError( + f"{new} is unmodified or indirectly modified and should not be categorized" + ) + + if not new.is_directly_modified(old): if new.fingerprint.parent_data_hash == old.fingerprint.parent_data_hash: return SnapshotChangeCategory.NON_BREAKING return None diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index afc8e06458..dea4ef64e5 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -1230,6 +1230,21 @@ def apply_pending_restatement_intervals(self) -> None: ) self.intervals = remove_interval(self.intervals, *pending_restatement_interval) + def is_directly_modified(self, other: Snapshot) -> bool: + """Returns whether or not this snapshot is directly modified in relation to the other snapshot.""" + return self.node.is_data_change(other.node) + + def is_indirectly_modified(self, other: Snapshot) -> bool: + """Returns whether or not this snapshot is indirectly modified in relation to the other snapshot.""" + return ( + self.fingerprint.parent_data_hash != other.fingerprint.parent_data_hash + and not self.node.is_data_change(other.node) + ) + + def is_metadata_updated(self, other: Snapshot) -> bool: + """Returns whether or not this snapshot contains metadata changes in relation to the other snapshot.""" + return self.fingerprint.metadata_hash != other.fingerprint.metadata_hash + @property def physical_schema(self) -> str: if self.physical_schema_ is not None: diff --git a/sqlmesh/migrations/v0093_use_raw_sql_in_fingerprint.py b/sqlmesh/migrations/v0093_use_raw_sql_in_fingerprint.py new file mode 100644 index 0000000000..53d4cb1727 --- /dev/null +++ b/sqlmesh/migrations/v0093_use_raw_sql_in_fingerprint.py @@ -0,0 +1,5 @@ +"""Use the raw SQL when computing the model fingerprint.""" + + +def migrate(state_sync, **kwargs): # type: ignore + pass diff --git a/tests/core/state_sync/test_export_import.py b/tests/core/state_sync/test_export_import.py index 2d20199d33..c303a63e59 100644 --- a/tests/core/state_sync/test_export_import.py +++ b/tests/core/state_sync/test_export_import.py @@ -289,8 +289,8 @@ def test_export_local_state( full_model = next(s for s in snapshots if "full_model" in s["name"]) new_model = next(s for s in snapshots if "new_model" in s["name"]) - assert "'1' as modified" in full_model["node"]["query"] - assert "SELECT 1 as id" in new_model["node"]["query"] + assert "'1' as modified" in full_model["node"]["query"]["sql"] + assert "SELECT 1 as id" in new_model["node"]["query"]["sql"] def test_import_invalid_file(tmp_path: Path, state_sync: StateSync) -> None: diff --git a/tests/core/test_audit.py b/tests/core/test_audit.py index 81335e5f1a..ed67975e9e 100644 --- a/tests/core/test_audit.py +++ b/tests/core/test_audit.py @@ -80,6 +80,8 @@ def test_load(assert_exp_eq): col IS NULL """, ) + assert audit.query_._parsed is not None + assert audit.query_._parsed_dialect == "spark" def test_load_standalone(assert_exp_eq): @@ -121,6 +123,8 @@ def test_load_standalone(assert_exp_eq): col IS NULL """, ) + assert audit.query_._parsed is not None + assert audit.query_._parsed_dialect == "spark" def test_load_standalone_default_catalog(assert_exp_eq): diff --git a/tests/core/test_context.py b/tests/core/test_context.py index 3b7c5bd51d..196889a87c 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -39,6 +39,7 @@ from sqlmesh.core.plan.definition import Plan from sqlmesh.core.macros import MacroEvaluator, RuntimeStage from sqlmesh.core.model import load_sql_based_model, model, SqlModel, Model +from sqlmesh.core.model.common import ParsableSql from sqlmesh.core.model.cache import OptimizedQueryCache from sqlmesh.core.renderer import render_statements from sqlmesh.core.model.kind import ModelKindName @@ -2303,7 +2304,10 @@ def test_prompt_if_uncategorized_snapshot(mocker: MockerFixture, tmp_path: Path) incremental_model = context.get_model("sqlmesh_example.incremental_model") incremental_model_query = incremental_model.render_query() new_incremental_model_query = t.cast(exp.Select, incremental_model_query).select("1 AS z") - context.upsert_model("sqlmesh_example.incremental_model", query=new_incremental_model_query) + context.upsert_model( + "sqlmesh_example.incremental_model", + query_=ParsableSql(sql=new_incremental_model_query.sql(dialect=incremental_model.dialect)), + ) mock_console = mocker.Mock() spy_plan = mocker.spy(mock_console, "plan") diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index 0e779481fd..c22e904374 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -16,6 +16,7 @@ from pathlib import Path from sqlmesh.core.console import set_console, get_console, TerminalConsole from sqlmesh.core.config.naming import NameInferenceConfig +from sqlmesh.core.model.common import ParsableSql from sqlmesh.utils.concurrency import NodeExecutionFailedError import time_machine from pytest_mock.plugin import MockerFixture @@ -2023,7 +2024,7 @@ def test_dbt_select_star_is_directly_modified(sushi_test_dbt_context: Context): model = context.get_model("sushi.simple_model_a") context.upsert_model( model, - query=d.parse_one("SELECT 1 AS a, 2 AS b"), + query_=ParsableSql(sql="SELECT 1 AS a, 2 AS b"), ) snapshot_a_id = context.get_snapshot("sushi.simple_model_a").snapshot_id # type: ignore @@ -2605,8 +2606,8 @@ def test_unaligned_start_snapshot_with_non_deployable_downstream(init_and_plan_c context.upsert_model(SqlModel.parse_obj(kwargs)) context.upsert_model( downstream_model_name, - query=d.parse_one( - "SELECT customer_id, MAX(revenue) AS max_revenue FROM memory.sushi.customer_revenue_lifetime_new GROUP BY 1" + query_=ParsableSql( + sql="SELECT customer_id, MAX(revenue) AS max_revenue FROM memory.sushi.customer_revenue_lifetime_new GROUP BY 1" ), ) @@ -2637,7 +2638,13 @@ def test_virtual_environment_mode_dev_only(init_and_plan_context: t.Callable): # Make a change in dev original_model = context.get_model("sushi.waiter_revenue_by_day") original_fingerprint = context.get_snapshot(original_model.name).fingerprint - model = original_model.copy(update={"query": original_model.query.order_by("waiter_id")}) + model = original_model.copy( + update={ + "query_": ParsableSql( + sql=original_model.query.order_by("waiter_id").sql(dialect=original_model.dialect) + ) + } + ) model = add_projection_to_model(t.cast(SqlModel, model)) context.upsert_model(model) @@ -4681,12 +4688,12 @@ def test_plan_repairs_unrenderable_snapshot_state( f"name = '{target_snapshot.name}' AND identifier = '{target_snapshot.identifier}'", ) + context.clear_caches() + target_snapshot_in_state = context.state_sync.get_snapshots([target_snapshot.snapshot_id])[ + target_snapshot.snapshot_id + ] + with pytest.raises(Exception): - context_copy = context.copy() - context_copy.clear_caches() - target_snapshot_in_state = context_copy.state_sync.get_snapshots( - [target_snapshot.snapshot_id] - )[target_snapshot.snapshot_id] target_snapshot_in_state.model.render_query_or_raise() # Repair the snapshot by creating a new version of it @@ -4695,11 +4702,11 @@ def test_plan_repairs_unrenderable_snapshot_state( plan_builder = context.plan_builder("prod", forward_only=forward_only) plan = plan_builder.build() - assert plan.directly_modified == {target_snapshot.snapshot_id} if not forward_only: assert target_snapshot.snapshot_id in {i.snapshot_id for i in plan.missing_intervals} - plan_builder.set_choice(target_snapshot, SnapshotChangeCategory.NON_BREAKING) - plan = plan_builder.build() + assert plan.directly_modified == {target_snapshot.snapshot_id} + plan_builder.set_choice(target_snapshot, SnapshotChangeCategory.NON_BREAKING) + plan = plan_builder.build() context.apply(plan) @@ -5383,7 +5390,10 @@ def test_auto_categorization(sushi_context: Context): ).fingerprint model = t.cast(SqlModel, sushi_context.get_model("sushi.customers", raise_if_missing=True)) - sushi_context.upsert_model("sushi.customers", query=model.query.select("'foo' AS foo")) # type: ignore + sushi_context.upsert_model( + "sushi.customers", + query_=ParsableSql(sql=model.query.select("'foo' AS foo").sql(dialect=model.dialect)), # type: ignore + ) apply_to_environment(sushi_context, environment) assert ( @@ -5447,7 +5457,13 @@ def test_multi(mocker): model = context.get_model("bronze.a") assert model.project == "repo_1" - context.upsert_model(model.copy(update={"query": model.query.select("'c' AS c")})) + context.upsert_model( + model.copy( + update={ + "query_": ParsableSql(sql=model.query.select("'c' AS c").sql(dialect=model.dialect)) + } + ) + ) plan = context.plan_builder().build() assert set(snapshot.name for snapshot in plan.directly_modified) == { @@ -5615,7 +5631,15 @@ def test_multi_virtual_layer(copy_to_temp_path): model = context.get_model("db_1.first_schema.model_one") - context.upsert_model(model.copy(update={"query": model.query.select("'c' AS extra")})) + context.upsert_model( + model.copy( + update={ + "query_": ParsableSql( + sql=model.query.select("'c' AS extra").sql(dialect=model.dialect) + ) + } + ) + ) plan = context.plan_builder().build() context.apply(plan) @@ -5641,9 +5665,25 @@ def test_multi_virtual_layer(copy_to_temp_path): # Create dev environment with changed models model = context.get_model("db_2.second_schema.model_one") - context.upsert_model(model.copy(update={"query": model.query.select("'d' AS extra")})) + context.upsert_model( + model.copy( + update={ + "query_": ParsableSql( + sql=model.query.select("'d' AS extra").sql(dialect=model.dialect) + ) + } + ) + ) model = context.get_model("first_schema.model_two") - context.upsert_model(model.copy(update={"query": model.query.select("'d2' AS col")})) + context.upsert_model( + model.copy( + update={ + "query_": ParsableSql( + sql=model.query.select("'d2' AS col").sql(dialect=model.dialect) + ) + } + ) + ) plan = context.plan_builder("dev").build() context.apply(plan) @@ -6630,11 +6670,12 @@ def change_data_type( assert model is not None if isinstance(model, SqlModel): - data_types = model.query.find_all(DataType) + query = model.query.copy() + data_types = query.find_all(DataType) for data_type in data_types: if data_type.this == old_type: data_type.set("this", new_type) - context.upsert_model(model_name, query=model.query) + context.upsert_model(model_name, query_=ParsableSql(sql=query.sql(dialect=model.dialect))) elif model.columns_to_types_ is not None: for k, v in model.columns_to_types_.items(): if v.this == old_type: @@ -6921,7 +6962,15 @@ def test_destroy(copy_to_temp_path): model = context.get_model("db_1.first_schema.model_one") - context.upsert_model(model.copy(update={"query": model.query.select("'c' AS extra")})) + context.upsert_model( + model.copy( + update={ + "query_": ParsableSql( + sql=model.query.select("'c' AS extra").sql(dialect=model.dialect) + ) + } + ) + ) plan = context.plan_builder().build() context.apply(plan) @@ -6932,9 +6981,25 @@ def test_destroy(copy_to_temp_path): # Create dev environment with changed models model = context.get_model("db_2.second_schema.model_one") - context.upsert_model(model.copy(update={"query": model.query.select("'d' AS extra")})) + context.upsert_model( + model.copy( + update={ + "query_": ParsableSql( + sql=model.query.select("'d' AS extra").sql(dialect=model.dialect) + ) + } + ) + ) model = context.get_model("first_schema.model_two") - context.upsert_model(model.copy(update={"query": model.query.select("'d2' AS col")})) + context.upsert_model( + model.copy( + update={ + "query_": ParsableSql( + sql=model.query.select("'d2' AS col").sql(dialect=model.dialect) + ) + } + ) + ) plan = context.plan_builder("dev").build() context.apply(plan) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 3850e08164..be1df5f2d6 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -22,6 +22,7 @@ from sqlmesh.core import dialect as d from sqlmesh.core.console import get_console from sqlmesh.core.audit import ModelAudit, load_audit +from sqlmesh.core.model.common import ParsableSql from sqlmesh.core.config import ( Config, DuckDBConnectionConfig, @@ -5732,7 +5733,7 @@ def test_default_catalog_sql(assert_exp_eq): The system is not designed to actually support having an engine that doesn't support default catalog to start supporting it or the reverse of that. If that did happen then bugs would occur. """ - HASH_WITH_CATALOG = "1269513823" + HASH_WITH_CATALOG = "2768215345" # Test setting default catalog doesn't change hash if it matches existing logic expressions = d.parse( @@ -8308,15 +8309,9 @@ def noop(evaluator) -> None: new_model = load_sql_based_model( expressions, path=Path("./examples/sushi/models/test_model.sql") ) - if metadata_only: - assert "noop" not in new_model._data_hash_values[0] - assert "noop" in new_model._additional_metadata[0] - assert model.data_hash == new_model.data_hash - assert model.metadata_hash != new_model.metadata_hash - else: - assert "noop" in new_model._data_hash_values[0] - assert model.data_hash != new_model.data_hash - assert model.metadata_hash == new_model.metadata_hash + assert model.metadata_hash != new_model.metadata_hash + assert model.data_hash != new_model.data_hash + assert new_model.is_metadata_only_change(model) == metadata_only @macro(metadata_only=metadata_only) # type: ignore def noop(evaluator) -> None: @@ -8336,6 +8331,7 @@ def noop(evaluator) -> None: assert "print" in updated_model._data_hash_values[0] assert new_model.data_hash != updated_model.data_hash assert new_model.metadata_hash == updated_model.metadata_hash + assert updated_model.is_metadata_only_change(new_model) == metadata_only def test_managed_kind_sql(): @@ -8874,7 +8870,9 @@ def test_column_description_metadata_change(): context.upsert_model(model) context.plan(no_prompts=True, auto_apply=True) - context.upsert_model("db.test_model", query=parse_one("SELECT 1 AS id /* description 2 */")) + context.upsert_model( + "db.test_model", query_=ParsableSql(sql="SELECT 1 AS id /* description 2 */") + ) plan = context.plan(no_prompts=True, auto_apply=True) snapshots = list(plan.snapshots.values()) @@ -10729,7 +10727,7 @@ def f(): Context(paths=tmp_path, config=config) -def test_semicolon_is_not_included_in_model_state(tmp_path, assert_exp_eq): +def test_semicolon_is_metadata_only_change(tmp_path, assert_exp_eq): init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) db_connection = DuckDBConnectionConfig(database=str(tmp_path / "db.db")) @@ -10818,7 +10816,9 @@ def test_semicolon_is_not_included_in_model_state(tmp_path, assert_exp_eq): ctx.load() plan = ctx.plan(no_prompts=True, auto_apply=True) - assert not plan.context_diff.modified_snapshots + assert len(plan.context_diff.modified_snapshots) == 1 + assert len(plan.new_snapshots) == 1 + assert plan.new_snapshots[0].is_metadata def test_invalid_audit_reference(): @@ -11471,3 +11471,46 @@ def test_raw_jinja_raw_tag(): model = load_sql_based_model(expressions) assert model.render_query().sql() == "SELECT '{{ foo }}' AS \"col\"" + + +def test_use_original_sql(): + expressions = d.parse( + """ + MODEL (name test); + + CREATE TABLE pre ( + a INT + ); + + SELECT + 1, + 2; + + CREATE TABLE post ( + b INT + ); + """ + ) + + model = load_sql_based_model(expressions) + assert model.query_.sql == "SELECT\n 1,\n 2" + assert model.pre_statements_[0].sql == "CREATE TABLE pre (\n a INT\n )" + assert model.post_statements_[0].sql == "CREATE TABLE post (\n b INT\n );" + + # Now manually create the model and make sure that the original SQL is not used + model_query = d.parse_one("SELECT 1 AS one") + assert model_query.meta["sql"] == "SELECT 1 AS one" + model_query = model_query.select("2 AS two") + + pre_statements = [d.parse_one("CREATE TABLE pre (\n a INT\n )")] + post_statements = [d.parse_one("CREATE TABLE post (\n b INT\n );")] + + model = create_sql_model( + "test", + model_query, + pre_statements=pre_statements, + post_statements=post_statements, + ) + assert model.query_.sql == "SELECT 1 AS one, 2 AS two" + assert model.pre_statements_[0].sql == "CREATE TABLE pre (a INT)" + assert model.post_statements_[0].sql == "CREATE TABLE post (b INT)" diff --git a/tests/core/test_selector.py b/tests/core/test_selector.py index 9f3bc9f698..80b9ef691e 100644 --- a/tests/core/test_selector.py +++ b/tests/core/test_selector.py @@ -11,6 +11,7 @@ from sqlmesh.core.audit import StandaloneAudit from sqlmesh.core.environment import Environment from sqlmesh.core.model import Model, SqlModel +from sqlmesh.core.model.common import ParsableSql from sqlmesh.core.selector import Selector from sqlmesh.core.snapshot import SnapshotChangeCategory from sqlmesh.utils import UniqueKeyDict @@ -293,7 +294,13 @@ def test_select_change_schema(mocker: MockerFixture, make_snapshot): } local_models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") - local_parent = parent.copy(update={"query": parent.query.select("2 as b", append=False)}) # type: ignore + local_parent = parent.copy( + update={ + "query_": ParsableSql( + sql=parent.query.select("2 as b", append=False).sql(dialect=parent.dialect) # type: ignore + ) + } + ) local_models[local_parent.fqn] = local_parent local_child = child.copy(update={"mapping_schema": {'"db"': {'"parent"': {"b": "INT"}}}}) local_models[local_child.fqn] = local_child @@ -301,7 +308,7 @@ def test_select_change_schema(mocker: MockerFixture, make_snapshot): selector = Selector(state_reader_mock, local_models) selected = selector.select_models(["db.parent"], env_name) - assert selected[local_child.fqn].data_hash != child.data_hash + assert selected[local_child.fqn].render_query() != child.render_query() _assert_models_equal( selected, diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py index 86fb434e33..d63b642f60 100644 --- a/tests/core/test_snapshot.py +++ b/tests/core/test_snapshot.py @@ -79,7 +79,7 @@ def parent_model(): name="parent.tbl", kind=dict(time_column="ds", name=ModelKindName.INCREMENTAL_BY_TIME_RANGE), dialect="spark", - query=parse_one("SELECT 1, ds"), + query="SELECT 1, ds", ) @@ -92,7 +92,7 @@ def model(): dialect="spark", cron="1 0 * * *", start="2020-01-01", - query=parse_one("SELECT @EACH([1, 2], x -> x), ds FROM parent.tbl"), + query="SELECT @EACH([1, 2], x -> x), ds FROM parent.tbl", ) @@ -148,7 +148,9 @@ def test_json(snapshot: Snapshot): "project": "", "python_env": {}, "owner": "owner", - "query": "SELECT @EACH([1, 2], x -> x), ds FROM parent.tbl", + "query": { + "sql": "SELECT @EACH([1, 2], x -> x), ds FROM parent.tbl", + }, "jinja_macros": { "create_builtins_module": "sqlmesh.utils.jinja", "global_objs": {}, @@ -186,7 +188,7 @@ def test_json_custom_materialization(make_snapshot: t.Callable): dialect="spark", cron="1 0 * * *", start="2020-01-01", - query=parse_one("SELECT @EACH([1, 2], x -> x), ds FROM parent.tbl"), + query="SELECT @EACH([1, 2], x -> x), ds FROM parent.tbl", ) snapshot = make_snapshot( @@ -913,8 +915,8 @@ def test_fingerprint(model: Model, parent_model: Model): fingerprint = fingerprint_from_node(model, nodes={}) original_fingerprint = SnapshotFingerprint( - data_hash="3301649319", - metadata_hash="3575333731", + data_hash="2406542604", + metadata_hash="3341445192", ) assert fingerprint == original_fingerprint @@ -941,7 +943,7 @@ def test_fingerprint(model: Model, parent_model: Model): model = SqlModel(**{**model.dict(), "query": parse_one("select 1, ds -- annotation")}) fingerprint = fingerprint_from_node(model, nodes={}) assert new_fingerprint != fingerprint - assert new_fingerprint.data_hash == fingerprint.data_hash + assert new_fingerprint.data_hash != fingerprint.data_hash assert new_fingerprint.metadata_hash != fingerprint.metadata_hash model = SqlModel( @@ -951,14 +953,14 @@ def test_fingerprint(model: Model, parent_model: Model): assert new_fingerprint != fingerprint assert new_fingerprint.data_hash != fingerprint.data_hash assert new_fingerprint.metadata_hash != fingerprint.metadata_hash - assert fingerprint.metadata_hash == original_fingerprint.metadata_hash + assert fingerprint.metadata_hash != original_fingerprint.metadata_hash model = SqlModel(**{**original_model.dict(), "post_statements": [parse_one("DROP TABLE test")]}) fingerprint = fingerprint_from_node(model, nodes={}) assert new_fingerprint != fingerprint assert new_fingerprint.data_hash != fingerprint.data_hash assert new_fingerprint.metadata_hash != fingerprint.metadata_hash - assert fingerprint.metadata_hash == original_fingerprint.metadata_hash + assert fingerprint.metadata_hash != original_fingerprint.metadata_hash def test_fingerprint_seed_model(): @@ -1013,8 +1015,8 @@ def test_fingerprint_jinja_macros(model: Model): } ) original_fingerprint = SnapshotFingerprint( - data_hash="2908339239", - metadata_hash="3575333731", + data_hash="93332825", + metadata_hash="3341445192", ) fingerprint = fingerprint_from_node(model, nodes={}) diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index 60908ed7c4..9b1e81c0f4 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -252,6 +252,9 @@ def increment_stage_counter(evaluator) -> None: snapshot = make_snapshot(model) snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + snapshot.model.render_pre_statements() + assert f"RuntimeStage value: {RuntimeStage.LOADING.value}" in capsys.readouterr().out evaluator.create([snapshot], {}) diff --git a/tests/core/test_table_diff.py b/tests/core/test_table_diff.py index 73fd37a2f7..839cbb415e 100644 --- a/tests/core/test_table_diff.py +++ b/tests/core/test_table_diff.py @@ -10,6 +10,7 @@ from sqlmesh.core.context import Context from sqlmesh.core.config import AutoCategorizationMode, CategorizerConfig, DuckDBConnectionConfig from sqlmesh.core.model import SqlModel, load_sql_based_model +from sqlmesh.core.model.common import ParsableSql from sqlmesh.core.table_diff import TableDiff, SchemaDiff import numpy as np # noqa: TID253 from sqlmesh.utils.errors import SQLMeshError @@ -48,8 +49,14 @@ def capture_console_output(method_name: str, **kwargs) -> str: def test_data_diff(sushi_context_fixed_date, capsys, caplog): model = sushi_context_fixed_date.models['"memory"."sushi"."customer_revenue_by_day"'] - model.query.select(exp.cast("'1'", "VARCHAR").as_("modified_col"), "1 AS y", copy=False) - sushi_context_fixed_date.upsert_model(model) + sushi_context_fixed_date.upsert_model( + model, + query_=ParsableSql( + sql=model.query.select(exp.cast("'1'", "VARCHAR").as_("modified_col"), "1 AS y").sql( + model.dialect + ) + ), + ) sushi_context_fixed_date.plan( "source_dev", diff --git a/tests/core/test_test.py b/tests/core/test_test.py index 1b5425068f..d889c7bb33 100644 --- a/tests/core/test_test.py +++ b/tests/core/test_test.py @@ -29,6 +29,7 @@ from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.macros import MacroEvaluator, macro from sqlmesh.core.model import Model, SqlModel, load_sql_based_model, model +from sqlmesh.core.model.common import ParsableSql from sqlmesh.core.test.definition import ModelTest, PythonModelTest, SqlModelTest from sqlmesh.core.test.result import ModelTextTestResult from sqlmesh.core.test.context import TestExecutionContext @@ -1985,12 +1986,18 @@ def test_test_generation(tmp_path: Path) -> None: ) context = Context(paths=tmp_path, config=config) - query = context.get_model("sqlmesh_example.full_model").render_query() + model = context.get_model("sqlmesh_example.full_model") + query = model.render_query() assert isinstance(query, exp.Query) context.upsert_model( "sqlmesh_example.full_model", - query=exp.select(*query.named_selects).from_("cte").with_("cte", as_=query), + query_=ParsableSql( + sql=exp.select(*query.named_selects) + .from_("cte") + .with_("cte", as_=query) + .sql(dialect=model.dialect) + ), ) context.plan(auto_apply=True) diff --git a/tests/integrations/github/cicd/test_integration.py b/tests/integrations/github/cicd/test_integration.py index f78419889d..ce357f6d36 100644 --- a/tests/integrations/github/cicd/test_integration.py +++ b/tests/integrations/github/cicd/test_integration.py @@ -16,6 +16,7 @@ from sqlmesh.core.config import CategorizerConfig, Config, ModelDefaultsConfig, LinterConfig from sqlmesh.core.engine_adapter.shared import DataObject from sqlmesh.core.user import User, UserRole +from sqlmesh.core.model.common import ParsableSql from sqlmesh.integrations.github.cicd import command from sqlmesh.integrations.github.cicd.config import GithubCICDBotConfig, MergeMethod from sqlmesh.integrations.github.cicd.controller import ( @@ -249,8 +250,10 @@ def test_merge_pr_has_non_breaking_change( ] # Make a non-breaking change model = controller._context.get_model("sushi.waiter_revenue_by_day").copy() - model.query.select(exp.alias_("1", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + query_=ParsableSql(sql=model.query.select(exp.alias_("1", "new_col")).sql(model.dialect)), + ) github_output_file = tmp_path / "github_output.txt" @@ -458,8 +461,10 @@ def test_merge_pr_has_non_breaking_change_diff_start( ] # Make a non-breaking change model = controller._context.get_model("sushi.waiter_revenue_by_day").copy() - model.query.select(exp.alias_("1", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + query_=ParsableSql(sql=model.query.select(exp.alias_("1", "new_col")).sql(model.dialect)), + ) github_output_file = tmp_path / "github_output.txt" @@ -666,8 +671,10 @@ def test_merge_pr_has_non_breaking_change_no_categorization( ] # Make a non-breaking change model = controller._context.get_model("sushi.waiter_revenue_by_day").copy() - model.query.select(exp.alias_("1", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + query_=ParsableSql(sql=model.query.select(exp.alias_("1", "new_col")).sql(model.dialect)), + ) github_output_file = tmp_path / "github_output.txt" @@ -983,8 +990,10 @@ def test_no_merge_since_no_deploy_signal( ] # Make a non-breaking change model = controller._context.get_model("sushi.waiter_revenue_by_day").copy() - model.query.select(exp.alias_("1", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + query_=ParsableSql(sql=model.query.select(exp.alias_("1", "new_col")).sql(model.dialect)), + ) github_output_file = tmp_path / "github_output.txt" @@ -1183,8 +1192,10 @@ def test_no_merge_since_no_deploy_signal_no_approvers_defined( controller._context.users = [User(username="test", github_username="test_github", roles=[])] # Make a non-breaking change model = controller._context.get_model("sushi.waiter_revenue_by_day").copy() - model.query.select(exp.alias_("1", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + query_=ParsableSql(sql=model.query.select(exp.alias_("1", "new_col")).sql(model.dialect)), + ) github_output_file = tmp_path / "github_output.txt" @@ -1357,8 +1368,10 @@ def test_deploy_comment_pre_categorized( controller._context.users = [User(username="test", github_username="test_github", roles=[])] # Make a non-breaking change model = controller._context.get_model("sushi.waiter_revenue_by_day").copy() - model.query.select(exp.alias_("1", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + query_=ParsableSql(sql=model.query.select(exp.alias_("1", "new_col")).sql(model.dialect)), + ) # Manually categorize the change as non-breaking and don't backfill anything controller._context.plan( @@ -1557,8 +1570,12 @@ def test_error_msg_when_applying_plan_with_bug( ] # Make an error by adding a column that doesn't exist model = controller._context.get_model("sushi.waiter_revenue_by_day").copy() - model.query.select(exp.alias_("non_existing_col", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + query_=ParsableSql( + sql=model.query.select(exp.alias_("non_existing_col", "new_col")).sql(model.dialect) + ), + ) github_output_file = tmp_path / "github_output.txt" @@ -1716,8 +1733,10 @@ def test_overlapping_changes_models( # These changes have shared children and this ensures we don't repeat the children in the output # Make a non-breaking change model = controller._context.get_model("sushi.customers").copy() - model.query.select(exp.alias_("1", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + query_=ParsableSql(sql=model.query.select(exp.alias_("1", "new_col")).sql(model.dialect)), + ) # Make a breaking change model = controller._context.get_model("sushi.waiter_names").copy() @@ -2283,8 +2302,10 @@ def test_has_required_approval_but_not_base_branch( ] # Make a non-breaking change model = controller._context.get_model("sushi.waiter_revenue_by_day").copy() - model.query.select(exp.alias_("1", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + query_=ParsableSql(sql=model.query.select(exp.alias_("1", "new_col")).sql(model.dialect)), + ) github_output_file = tmp_path / "github_output.txt"