diff --git a/sqlmesh/core/model/common.py b/sqlmesh/core/model/common.py index 704f3e02fe..11ddc8234b 100644 --- a/sqlmesh/core/model/common.py +++ b/sqlmesh/core/model/common.py @@ -157,7 +157,7 @@ def _add_variables_to_python_env( if blueprint_variables: blueprint_variables = { - k: SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v + k.lower(): SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v for k, v in blueprint_variables.items() } python_env[c.SQLMESH_BLUEPRINT_VARS] = Executable.value( diff --git a/sqlmesh/migrations/v0087_normalize_blueprint_variables.py b/sqlmesh/migrations/v0087_normalize_blueprint_variables.py new file mode 100644 index 0000000000..8878bc8019 --- /dev/null +++ b/sqlmesh/migrations/v0087_normalize_blueprint_variables.py @@ -0,0 +1,138 @@ +""" +Normalizes blueprint variables, so Customer_Field is stored as customer_field in the `python_env`: + +MODEL ( + ... + blueprints ( + Customer_Field := 1 + ) +); + +SELECT + @customer_field AS col +""" + +import json +import logging +from dataclasses import dataclass + +from sqlglot import exp +from sqlmesh.core.console import get_console +from sqlmesh.utils.migration import index_text_type, blob_text_type + + +logger = logging.getLogger(__name__) + + +SQLMESH_BLUEPRINT_VARS = "__sqlmesh__blueprint__vars__" + + +# Make sure `SqlValue` is defined so it can be used by `eval` call in the migration +@dataclass +class SqlValue: + """A SQL string representing a generated SQLGlot AST.""" + + sql: str + + +def migrate(state_sync, **kwargs): # type: ignore + import pandas as pd + + engine_adapter = state_sync.engine_adapter + schema = state_sync.schema + snapshots_table = "_snapshots" + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + migration_needed = False + new_snapshots = [] + + for ( + name, + identifier, + version, + snapshot, + kind_name, + updated_ts, + unpaused_ts, + ttl_ms, + unrestorable, + ) in engine_adapter.fetchall( + exp.select( + "name", + "identifier", + "version", + "snapshot", + "kind_name", + "updated_ts", + "unpaused_ts", + "ttl_ms", + "unrestorable", + ).from_(snapshots_table), + quote_identifiers=True, + ): + parsed_snapshot = json.loads(snapshot) + node = parsed_snapshot["node"] + python_env = node.get("python_env") or {} + + migrate_snapshot = False + + if blueprint_vars_executable := python_env.get(SQLMESH_BLUEPRINT_VARS): + blueprint_vars = eval(blueprint_vars_executable["payload"]) + + for var, value in dict(blueprint_vars).items(): + lowercase_var = var.lower() + if var != lowercase_var: + if lowercase_var in blueprint_vars: + get_console().log_warning( + "SQLMesh is unable to fully migrate the state database, because the " + f"model '{node['name']}' contains two blueprint variables ('{var}' and " + f"'{lowercase_var}') that resolve to the same value ('{lowercase_var}'). " + "This may result in unexpected changes being reported by the next " + "`sqlmesh plan` command. If this happens, consider renaming either variable, " + "so that the lowercase version of their names are different." + ) + else: + del blueprint_vars[var] + blueprint_vars[lowercase_var] = value + migrate_snapshot = True + + if migrate_snapshot: + migration_needed = True + blueprint_vars_executable["payload"] = repr(blueprint_vars) + + new_snapshots.append( + { + "name": name, + "identifier": identifier, + "version": version, + "snapshot": json.dumps(parsed_snapshot), + "kind_name": kind_name, + "updated_ts": updated_ts, + "unpaused_ts": unpaused_ts, + "ttl_ms": ttl_ms, + "unrestorable": unrestorable, + } + ) + + if migration_needed and new_snapshots: + engine_adapter.delete_from(snapshots_table, "TRUE") + + index_type = index_text_type(engine_adapter.dialect) + blob_type = blob_text_type(engine_adapter.dialect) + + engine_adapter.insert_append( + snapshots_table, + pd.DataFrame(new_snapshots), + columns_to_types={ + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build(index_type), + "version": exp.DataType.build(index_type), + "snapshot": exp.DataType.build(blob_type), + "kind_name": exp.DataType.build("text"), + "updated_ts": exp.DataType.build("bigint"), + "unpaused_ts": exp.DataType.build("bigint"), + "ttl_ms": exp.DataType.build("bigint"), + "unrestorable": exp.DataType.build("boolean"), + }, + ) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 4c2f30e2f7..14c29165b7 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -9396,14 +9396,14 @@ def entrypoint(evaluator): MODEL ( name @{customer}.my_table, blueprints ( - (customer := customer1, customer_field := 'bar'), - (customer := customer2, customer_field := qux), + (customer := customer1, Customer_Field := 'bar'), + (customer := customer2, Customer_Field := qux), ), kind FULL ); SELECT - @customer_field AS foo, + @customer_FIELD AS foo, @{customer_field} AS foo2, @BLUEPRINT_VAR('customer_field') AS foo3, FROM @{customer}.my_source