Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 32 additions & 16 deletions sqlmesh/core/audit/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@
bool_validator,
default_catalog_validator,
depends_on_validator,
expression_validator,
sort_python_env,
sorted_python_env_payloads,
)
from sqlmesh.core.model.common import make_python_env, single_value_or_tuple
from sqlmesh.core.model.common import make_python_env, single_value_or_tuple, ParsableSql
from sqlmesh.core.node import _Node
from sqlmesh.core.renderer import QueryRenderer
from sqlmesh.utils.date import TimeLike
Expand Down Expand Up @@ -67,15 +66,26 @@ class AuditMixin(AuditCommonMetaMixin):
jinja_macros: A registry of jinja macros to use when rendering the audit query.
"""

query: t.Union[exp.Query, d.JinjaQuery]
query_: ParsableSql
defaults: t.Dict[str, exp.Expression]
expressions_: t.Optional[t.List[exp.Expression]]
expressions_: t.Optional[t.List[ParsableSql]]
jinja_macros: JinjaMacroRegistry
formatting: t.Optional[bool]

@property
def query(self) -> t.Union[exp.Query, d.JinjaQuery]:
return t.cast(t.Union[exp.Query, d.JinjaQuery], self.query_.parse(self.dialect))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this called multiple times? should we cache this somehow?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i see, you use a class?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, it's cached inside ParsableQuery


@property
def expressions(self) -> t.List[exp.Expression]:
return self.expressions_ or []
if not self.expressions_:
return []
result = []
for e in self.expressions_:
parsed = e.parse(self.dialect)
if not isinstance(parsed, exp.Semicolon):
result.append(parsed)
return result

@property
def macro_definitions(self) -> t.List[d.MacroDef]:
Expand Down Expand Up @@ -122,16 +132,16 @@ class ModelAudit(PydanticModel, AuditMixin, frozen=True):
skip: bool = False
blocking: bool = True
standalone: t.Literal[False] = False
query: t.Union[exp.Query, d.JinjaQuery]
query_: ParsableSql = Field(alias="query")
defaults: t.Dict[str, exp.Expression] = {}
expressions_: t.Optional[t.List[exp.Expression]] = Field(default=None, alias="expressions")
expressions_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="expressions")
jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry()
formatting: t.Optional[bool] = Field(default=None, exclude=True)

_path: t.Optional[Path] = None

# Validators
_query_validator = expression_validator
_query_validator = ParsableSql.validator()
_bool_validator = bool_validator
_string_validator = audit_string_validator
_map_validator = audit_map_validator
Expand All @@ -153,9 +163,9 @@ class StandaloneAudit(_Node, AuditMixin):
skip: bool = False
blocking: bool = False
standalone: t.Literal[True] = True
query: t.Union[exp.Query, d.JinjaQuery]
query_: ParsableSql = Field(alias="query")
defaults: t.Dict[str, exp.Expression] = {}
expressions_: t.Optional[t.List[exp.Expression]] = Field(default=None, alias="expressions")
expressions_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="expressions")
jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry()
default_catalog: t.Optional[str] = None
depends_on_: t.Optional[t.Set[str]] = Field(default=None, alias="depends_on")
Expand All @@ -165,7 +175,7 @@ class StandaloneAudit(_Node, AuditMixin):
source_type: t.Literal["audit"] = "audit"

# Validators
_query_validator = expression_validator
_query_validator = ParsableSql.validator()
_bool_validator = bool_validator
_string_validator = audit_string_validator
_map_validator = audit_map_validator
Expand Down Expand Up @@ -276,8 +286,8 @@ def metadata_hash(self) -> str:
self.cron_tz.key if self.cron_tz else None,
]

query = self.render_audit_query() or self.query
data.append(gen(query))
data.append(self.query_.sql)
data.extend([e.sql for e in self.expressions_ or []])
self._metadata_hash = hash_data(data)
return self._metadata_hash

Expand Down Expand Up @@ -461,11 +471,17 @@ def load_audit(
if project is not None:
extra_kwargs["project"] = project

dialect = meta_fields.pop("dialect", dialect)
dialect = meta_fields.pop("dialect", dialect) or ""

parsable_query = ParsableSql.from_parsed_expression(query, dialect, use_meta_sql=True)
parsable_statements = [
ParsableSql.from_parsed_expression(s, dialect, use_meta_sql=True) for s in statements
]

try:
audit = audit_class(
query=query,
expressions=statements,
query=parsable_query,
expressions=parsable_statements,
dialect=dialect,
**extra_kwargs,
**meta_fields,
Expand Down
9 changes: 3 additions & 6 deletions sqlmesh/core/context_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def directly_modified(self, name: str) -> bool:
return False

current, previous = self.modified_snapshots[name]
return current.fingerprint.data_hash != previous.fingerprint.data_hash
return current.is_directly_modified(previous)

def indirectly_modified(self, name: str) -> bool:
"""Returns whether or not a node was indirectly modified in this context.
Expand All @@ -451,10 +451,7 @@ def indirectly_modified(self, name: str) -> bool:
return False

current, previous = self.modified_snapshots[name]
return (
current.fingerprint.data_hash == previous.fingerprint.data_hash
and current.fingerprint.parent_data_hash != previous.fingerprint.parent_data_hash
)
return current.is_indirectly_modified(previous)

def metadata_updated(self, name: str) -> bool:
"""Returns whether or not the given node's metadata has been updated.
Expand All @@ -470,7 +467,7 @@ def metadata_updated(self, name: str) -> bool:
return False

current, previous = self.modified_snapshots[name]
return current.fingerprint.metadata_hash != previous.fingerprint.metadata_hash
return current.is_metadata_updated(previous)

def text_diff(self, name: str) -> str:
"""Finds the difference of a node between the current and remote environment.
Expand Down
69 changes: 63 additions & 6 deletions sqlmesh/core/model/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
prepare_env,
serialize_env,
)
from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator
from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator, get_dialect

if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
Expand Down Expand Up @@ -616,11 +616,6 @@ def parse_strings_with_macro_refs(value: t.Any, dialect: DialectType) -> t.Any:


expression_validator: t.Callable = field_validator(
"query",
"expressions_",
"pre_statements_",
"post_statements_",
"on_virtual_update_",
"unique_key",
mode="before",
check_fields=False,
Expand Down Expand Up @@ -663,3 +658,65 @@ def parse_strings_with_macro_refs(value: t.Any, dialect: DialectType) -> t.Any:
mode="before",
check_fields=False,
)(depends_on)


class ParsableSql(PydanticModel):
sql: str

_parsed: t.Optional[exp.Expression] = None
_parsed_dialect: t.Optional[str] = None

def parse(self, dialect: str) -> exp.Expression:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does the dialect ever change for a model outside of a test?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It shouldn't, why? I'd rather implement this correctly in case circumstances change.

if self._parsed is None or self._parsed_dialect != dialect:
self._parsed = d.parse_one(self.sql, dialect=dialect)
self._parsed_dialect = dialect
return self._parsed

@classmethod
def from_parsed_expression(
cls, parsed_expression: exp.Expression, dialect: str, use_meta_sql: bool = False
) -> ParsableSql:
sql = (
parsed_expression.meta.get("sql") or parsed_expression.sql(dialect=dialect)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what are the situation where we wouldn't want to use meta sql?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when I'm using a custom loader and do create_sql_model directly with a query. I don't think we can trust the correctness of the meta sql in that case.

if use_meta_sql
else parsed_expression.sql(dialect=dialect)
)
result = cls(sql=sql)
result._parsed = parsed_expression
result._parsed_dialect = dialect
return result

@classmethod
def validator(cls) -> classmethod:
def _validate_parsable_sql(
v: t.Any, info: ValidationInfo
) -> t.Optional[t.Union[ParsableSql, t.List[ParsableSql]]]:
if v is None:
return v
if isinstance(v, str):
return ParsableSql(sql=v)
if isinstance(v, exp.Expression):
return ParsableSql.from_parsed_expression(
v, get_dialect(info.data), use_meta_sql=False
)
if isinstance(v, list):
dialect = get_dialect(info.data)
return [
ParsableSql(sql=s)
if isinstance(s, str)
else ParsableSql.from_parsed_expression(s, dialect, use_meta_sql=False)
if isinstance(s, exp.Expression)
else ParsableSql.parse_obj(s)
for s in v
]
return ParsableSql.parse_obj(v)

return field_validator(
"query_",
"expressions_",
"pre_statements_",
"post_statements_",
"on_virtual_update_",
mode="before",
check_fields=False,
)(_validate_parsable_sql)
Loading