Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions sqlmesh/core/model/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,16 @@ def _add_variables_to_python_env(

variables = {k: v for k, v in (variables or {}).items() if k in used_variables}
if variables:
python_env[c.SQLMESH_VARS] = Executable.value(variables)
python_env[c.SQLMESH_VARS] = Executable.value(variables, sort_root_dict=True)

if blueprint_variables:
blueprint_variables = {
k: SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v
for k, v in blueprint_variables.items()
}
python_env[c.SQLMESH_BLUEPRINT_VARS] = Executable.value(blueprint_variables)
python_env[c.SQLMESH_BLUEPRINT_VARS] = Executable.value(
blueprint_variables, sort_root_dict=True
)

return python_env

Expand Down
35 changes: 16 additions & 19 deletions sqlmesh/migrations/v0085_deterministic_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import json
import logging
import typing as t
from dataclasses import dataclass

Expand All @@ -12,6 +13,12 @@
from sqlmesh.utils.migration import index_text_type, blob_text_type


logger = logging.getLogger(__name__)


KEYS_TO_MAKE_DETERMINISTIC = ["__sqlmesh__vars__", "__sqlmesh__blueprint__vars__"]


# Make sure `SqlValue` is defined so it can be used by `eval` call in the migration
@dataclass
class SqlValue:
Expand All @@ -20,25 +27,13 @@ class SqlValue:
sql: str


def _deterministic_repr(obj: t.Any) -> str:
"""
This is a copy of the function from utils.metaprogramming
"""

def _normalize_for_repr(o: t.Any) -> t.Any:
if isinstance(o, dict):
sorted_items = sorted(o.items(), key=lambda x: str(x[0]))
return {k: _normalize_for_repr(v) for k, v in sorted_items}
if isinstance(o, (list, tuple)):
# Recursively normalize nested structures
normalized = [_normalize_for_repr(item) for item in o]
return type(o)(normalized)
return o

def _dict_sort(obj: t.Any) -> str:
try:
return repr(_normalize_for_repr(obj))
if isinstance(obj, dict):
obj = dict(sorted(obj.items(), key=lambda x: str(x[0])))
except Exception:
return repr(obj)
logger.warning("Failed to sort non-recursive dict", exc_info=True)
return repr(obj)


def migrate(state_sync, **kwargs): # type: ignore
Expand Down Expand Up @@ -82,20 +77,22 @@ def migrate(state_sync, **kwargs): # type: ignore

if python_env:
for key, executable in python_env.items():
if key not in KEYS_TO_MAKE_DETERMINISTIC:
continue
if isinstance(executable, dict) and executable.get("kind") == "value":
old_payload = executable["payload"]
try:
# Try to parse the old payload and re-serialize it deterministically
parsed_value = eval(old_payload)
new_payload = _deterministic_repr(parsed_value)
new_payload = _dict_sort(parsed_value)

# Only update if the representation changed
if old_payload != new_payload:
executable["payload"] = new_payload
migration_needed = True
except Exception:
# If we still can't eval it, leave it as-is
pass
logger.warning("Exception trying to eval payload", exc_info=True)

new_snapshots.append(
{
Expand Down
82 changes: 82 additions & 0 deletions sqlmesh/migrations/v0086_check_deterministic_bug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import json
import logging

from sqlglot import exp

from sqlmesh.core.console import get_console


logger = logging.getLogger(__name__)
KEYS_TO_MAKE_DETERMINISTIC = ["__sqlmesh__vars__", "__sqlmesh__blueprint__vars__"]


def migrate(state_sync, **kwargs):  # type: ignore
    """Warn users whose state may have been altered by the buggy v0085 migration.

    The original v0085 migration recursively sorted *every* value-kind payload in
    ``python_env``; the fixed version only sorts the root dict of the two var keys
    in ``KEYS_TO_MAKE_DETERMINISTIC``. Any other value-kind payload that evals to a
    dict may therefore have had its repr (and hence snapshot fingerprint) changed
    by the buggy migration. This migration performs no writes — it only scans the
    snapshots table and emits a console warning if such a payload is found.
    """
    engine_adapter = state_sync.engine_adapter
    schema = state_sync.schema
    snapshots_table = "_snapshots"
    versions_table = "_versions"
    if schema:
        snapshots_table = f"{schema}.{snapshots_table}"
        versions_table = f"{schema}.{versions_table}"

    result = engine_adapter.fetchone(
        exp.select("schema_version").from_(versions_table), quote_identifiers=True
    )
    if not result:
        # This must be the first migration, so we can skip the check since the
        # project was never exposed to the buggy v0085 migration.
        return
    schema_version = result[0]
    if schema_version < 85:
        # The project was not exposed to the buggy v0085 migration, so we can skip the check.
        return

    warning = (
        "SQLMesh detected that it may not be able to fully migrate the state database. This should not impact "
        "the migration process, but may result in unexpected changes being reported by the next `sqlmesh plan` "
        "command. Please run `sqlmesh diff prod` after the migration has completed, before making any new "
        "changes. If any unexpected changes are reported, consider running a forward-only plan to apply these "
        "changes and avoid unnecessary backfills: sqlmesh plan prod --forward-only. "
        "See https://sqlmesh.readthedocs.io/en/stable/concepts/plans/#forward-only-plans for more details.\n"
    )

    # Only the serialized snapshot payload is needed here; fetching the other
    # snapshot-table columns would be wasted work since they are never read.
    for (snapshot,) in engine_adapter.fetchall(
        exp.select("snapshot").from_(snapshots_table), quote_identifiers=True
    ):
        parsed_snapshot = json.loads(snapshot)
        python_env = parsed_snapshot["node"].get("python_env")
        if not python_env:
            continue

        for key, executable in python_env.items():
            # Keys in KEYS_TO_MAKE_DETERMINISTIC are handled correctly by the fixed
            # v0085 migration; every *other* value-kind dict payload may have been
            # reordered by the buggy version, so its presence triggers the warning.
            if (
                key not in KEYS_TO_MAKE_DETERMINISTIC
                and isinstance(executable, dict)
                and executable.get("kind") == "value"
            ):
                try:
                    # NOTE(review): eval of serialized payloads mirrors how SQLMesh
                    # deserializes its own state; the data is produced by SQLMesh
                    # itself, but eval on stored text is still worth flagging.
                    parsed_value = eval(executable["payload"])
                    if isinstance(parsed_value, dict):
                        get_console().log_warning(warning)
                        # One affected snapshot is enough to warrant the warning.
                        return
                except Exception:
                    # If the payload can't be evaluated, skip it — the check is best-effort.
                    logger.warning("Exception trying to eval payload", exc_info=True)
46 changes: 14 additions & 32 deletions sqlmesh/utils/metaprogramming.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import importlib
import inspect
import linecache
import logging
import os
import re
import sys
Expand All @@ -23,6 +24,9 @@
from sqlmesh.utils.errors import SQLMeshError
from sqlmesh.utils.pydantic import PydanticModel

logger = logging.getLogger(__name__)


IGNORE_DECORATORS = {"macro", "model", "signal"}
SERIALIZABLE_CALLABLES = (type, types.FunctionType)
LITERALS = (Number, str, bytes, tuple, list, dict, set, bool)
Expand Down Expand Up @@ -424,10 +428,11 @@ def is_value(self) -> bool:
return self.kind == ExecutableKind.VALUE

@classmethod
def value(
    cls, v: t.Any, is_metadata: t.Optional[bool] = None, sort_root_dict: bool = False
) -> Executable:
    """Build a VALUE-kind executable whose payload is the repr of *v*.

    When ``sort_root_dict`` is true and *v* is a dict, its top-level keys are
    sorted (by stringified key) before taking the repr, producing a
    deterministic payload for fingerprinting.
    """
    if sort_root_dict:
        payload = _dict_sort(v)
    else:
        payload = repr(v)
    return Executable(payload=payload, kind=ExecutableKind.VALUE, is_metadata=is_metadata)


def serialize_env(env: t.Dict[str, t.Any], path: Path) -> t.Dict[str, Executable]:
Expand Down Expand Up @@ -635,36 +640,13 @@ def print_exception(
out.write(tb)


def _deterministic_repr(obj: t.Any) -> str:
"""Create a deterministic representation by ensuring consistent ordering before repr().

For dictionaries, ensures consistent key ordering to prevent non-deterministic
serialization that affects fingerprinting. Uses Python's native repr() logic
for all formatting to handle edge cases properly.

Note that this function assumes list/tuple order is significant and therefore does not sort them.

Args:
obj: The object to represent as a string.

Returns:
A deterministic string representation of the object.
"""

def _normalize_for_repr(o: t.Any) -> t.Any:
if isinstance(o, dict):
sorted_items = sorted(o.items(), key=lambda x: str(x[0]))
return {k: _normalize_for_repr(v) for k, v in sorted_items}
if isinstance(o, (list, tuple)):
# Recursively normalize nested structures
normalized = [_normalize_for_repr(item) for item in o]
return type(o)(normalized)
return o

def _dict_sort(obj: t.Any) -> str:
try:
return repr(_normalize_for_repr(obj))
if isinstance(obj, dict):
obj = dict(sorted(obj.items(), key=lambda x: str(x[0])))
except Exception:
return repr(obj)
logger.warning("Failed to sort non-recursive dict", exc_info=True)
return repr(obj)


def import_python_file(path: Path, relative_base: Path = Path()) -> types.ModuleType:
Expand Down
25 changes: 17 additions & 8 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6116,7 +6116,8 @@ def test_named_variable_macros() -> None:
)

assert model.python_env[c.SQLMESH_VARS] == Executable.value(
{c.GATEWAY: "in_memory", "test_var_a": "test_value", "overridden_var": "initial_value"}
{c.GATEWAY: "in_memory", "test_var_a": "test_value", "overridden_var": "initial_value"},
sort_root_dict=True,
)
assert (
model.render_query_or_raise().sql()
Expand All @@ -6142,7 +6143,8 @@ def test_variables_in_templates() -> None:
)

assert model.python_env[c.SQLMESH_VARS] == Executable.value(
{c.GATEWAY: "in_memory", "test_var_a": "test_value", "overridden_var": "initial_value"}
{c.GATEWAY: "in_memory", "test_var_a": "test_value", "overridden_var": "initial_value"},
sort_root_dict=True,
)
assert (
model.render_query_or_raise().sql()
Expand All @@ -6166,7 +6168,8 @@ def test_variables_in_templates() -> None:
)

assert model.python_env[c.SQLMESH_VARS] == Executable.value(
{c.GATEWAY: "in_memory", "test_var_a": "test_value", "overridden_var": "initial_value"}
{c.GATEWAY: "in_memory", "test_var_a": "test_value", "overridden_var": "initial_value"},
sort_root_dict=True,
)
assert (
model.render_query_or_raise().sql()
Expand Down Expand Up @@ -6305,7 +6308,8 @@ def test_variables_migrated_dbt_package_macro():
dialect="bigquery",
)
assert model.python_env[c.SQLMESH_VARS] == Executable.value(
{"test_var_a": "test_var_a_value", "__dbt_packages__.test.test_var_b": "test_var_b_value"}
{"test_var_a": "test_var_a_value", "__dbt_packages__.test.test_var_b": "test_var_b_value"},
sort_root_dict=True,
)
assert (
model.render_query().sql(dialect="bigquery")
Expand Down Expand Up @@ -6530,7 +6534,8 @@ def test_unrendered_macros_sql_model(mocker: MockerFixture) -> None:
"physical_var": "bla",
"virtual_var": "blb",
"session_var": "blc",
}
},
sort_root_dict=True,
)

assert "location1" in model.physical_properties
Expand Down Expand Up @@ -6617,7 +6622,8 @@ def model_with_macros(evaluator, **kwargs):
"physical_var": "bla",
"virtual_var": "blb",
"session_var": "blc",
}
},
sort_root_dict=True,
)
assert python_sql_model.enabled

Expand Down Expand Up @@ -10576,9 +10582,12 @@ def unimportant_testing_macro(evaluator, *projections):
)

assert m.python_env.get(c.SQLMESH_VARS) == Executable.value(
{"selector": "bla", "bla_variable": 1, "baz_variable": 2}
{"selector": "bla", "bla_variable": 1, "baz_variable": 2},
sort_root_dict=True,
)
assert m.python_env.get(c.SQLMESH_BLUEPRINT_VARS) == Executable.value(
{"selector": "baz"}, sort_root_dict=True
)
assert m.python_env.get(c.SQLMESH_BLUEPRINT_VARS) == Executable.value({"selector": "baz"})


def test_extract_schema_in_post_statement(tmp_path: Path) -> None:
Expand Down
Loading
Loading