Skip to content

Commit 3789555

Browse files
committed
Fix!: normalize blueprint variables
1 parent 3ee4e9f commit 3789555

File tree

3 files changed

+138
-4
lines changed

3 files changed

+138
-4
lines changed

sqlmesh/core/model/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def _add_variables_to_python_env(
157157

158158
if blueprint_variables:
159159
blueprint_variables = {
160-
k: SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v
160+
k.lower(): SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v
161161
for k, v in blueprint_variables.items()
162162
}
163163
python_env[c.SQLMESH_BLUEPRINT_VARS] = Executable.value(
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
"""
2+
Normalizes blueprint variables, so Customer_Field is stored as customer_field in the `python_env`:
3+
4+
MODEL (
5+
...
6+
blueprints (
7+
Customer_Field := 1
8+
)
9+
);
10+
11+
SELECT
12+
@customer_field AS col
13+
"""
14+
15+
import json
16+
import logging
17+
from dataclasses import dataclass
18+
19+
from sqlglot import exp
20+
from sqlmesh.utils.migration import index_text_type, blob_text_type
21+
22+
23+
logger = logging.getLogger(__name__)
24+
25+
26+
SQLMESH_BLUEPRINT_VARS = "__sqlmesh__blueprint__vars__"
27+
28+
29+
# Make sure `SqlValue` is defined so it can be used by `eval` call in the migration
30+
@dataclass
31+
class SqlValue:
32+
"""A SQL string representing a generated SQLGlot AST."""
33+
34+
sql: str
35+
36+
37+
def migrate(state_sync, **kwargs): # type: ignore
38+
import pandas as pd
39+
40+
engine_adapter = state_sync.engine_adapter
41+
schema = state_sync.schema
42+
snapshots_table = "_snapshots"
43+
if schema:
44+
snapshots_table = f"{schema}.{snapshots_table}"
45+
46+
migration_needed = False
47+
new_snapshots = []
48+
49+
for (
50+
name,
51+
identifier,
52+
version,
53+
snapshot,
54+
kind_name,
55+
updated_ts,
56+
unpaused_ts,
57+
ttl_ms,
58+
unrestorable,
59+
) in engine_adapter.fetchall(
60+
exp.select(
61+
"name",
62+
"identifier",
63+
"version",
64+
"snapshot",
65+
"kind_name",
66+
"updated_ts",
67+
"unpaused_ts",
68+
"ttl_ms",
69+
"unrestorable",
70+
).from_(snapshots_table),
71+
quote_identifiers=True,
72+
):
73+
parsed_snapshot = json.loads(snapshot)
74+
node = parsed_snapshot["node"]
75+
python_env = node.get("python_env") or {}
76+
77+
# Intentionally checking for falsey value here, since that accounts for empty dicts and None
78+
if blueprint_vars_executable := python_env.get(SQLMESH_BLUEPRINT_VARS):
79+
blueprint_vars = eval(blueprint_vars_executable["payload"])
80+
81+
for var, value in dict(blueprint_vars).items():
82+
lowercase_var = var.lower()
83+
if var != lowercase_var:
84+
# Ensures that we crash instead of overwriting snapshot payloads incorrectly
85+
assert lowercase_var not in blueprint_vars, (
86+
"SQLMesh could not migrate the state database successfully, because it detected "
87+
f"two different blueprint variable names ('{var}' and '{lowercase_var}') that resolve "
88+
f"to the same name ('{lowercase_var}') for model '{node['name']}'. Downgrade the local "
89+
"SQLMesh version to the previously-installed one, rename either of these variables, "
90+
"apply the corresponding plan and try again."
91+
)
92+
93+
del blueprint_vars[var]
94+
blueprint_vars[lowercase_var] = value
95+
migration_needed = True
96+
97+
if migration_needed:
98+
blueprint_vars_executable["payload"] = repr(blueprint_vars)
99+
100+
new_snapshots.append(
101+
{
102+
"name": name,
103+
"identifier": identifier,
104+
"version": version,
105+
"snapshot": json.dumps(parsed_snapshot),
106+
"kind_name": kind_name,
107+
"updated_ts": updated_ts,
108+
"unpaused_ts": unpaused_ts,
109+
"ttl_ms": ttl_ms,
110+
"unrestorable": unrestorable,
111+
}
112+
)
113+
114+
if migration_needed and new_snapshots:
115+
engine_adapter.delete_from(snapshots_table, "TRUE")
116+
117+
index_type = index_text_type(engine_adapter.dialect)
118+
blob_type = blob_text_type(engine_adapter.dialect)
119+
120+
engine_adapter.insert_append(
121+
snapshots_table,
122+
pd.DataFrame(new_snapshots),
123+
columns_to_types={
124+
"name": exp.DataType.build(index_type),
125+
"identifier": exp.DataType.build(index_type),
126+
"version": exp.DataType.build(index_type),
127+
"snapshot": exp.DataType.build(blob_type),
128+
"kind_name": exp.DataType.build("text"),
129+
"updated_ts": exp.DataType.build("bigint"),
130+
"unpaused_ts": exp.DataType.build("bigint"),
131+
"ttl_ms": exp.DataType.build("bigint"),
132+
"unrestorable": exp.DataType.build("boolean"),
133+
},
134+
)

tests/core/test_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9364,14 +9364,14 @@ def entrypoint(evaluator):
93649364
MODEL (
93659365
name @{customer}.my_table,
93669366
blueprints (
9367-
(customer := customer1, customer_field := 'bar'),
9368-
(customer := customer2, customer_field := qux),
9367+
(customer := customer1, Customer_Field := 'bar'),
9368+
(customer := customer2, Customer_Field := qux),
93699369
),
93709370
kind FULL
93719371
);
93729372
93739373
SELECT
9374-
@customer_field AS foo,
9374+
@customer_FIELD AS foo,
93759375
@{customer_field} AS foo2,
93769376
@BLUEPRINT_VAR('customer_field') AS foo3,
93779377
FROM @{customer}.my_source

0 commit comments

Comments
 (0)