Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,9 @@ def var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.

def blueprint_var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
"""Returns a blueprint variable value."""
return self._blueprint_variables.get(var_name.lower(), default)
return self._blueprint_variables.get(var_name) or self._blueprint_variables.get(
var_name.lower(), default
)

def with_variables(
self,
Expand Down
14 changes: 11 additions & 3 deletions sqlmesh/core/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,14 +254,20 @@ def evaluate_macros(
changed = True
variables = self.variables

if node.name not in self.locals and node.name.lower() not in variables:
if (
node.name not in self.locals
and node.name.lower() not in variables
and node.name not in variables
):
if not isinstance(node.parent, StagedFilePath):
raise SQLMeshError(f"Macro variable '{node.name}' is undefined.")

return node

# Precedence order is locals (e.g. @DEF) > blueprint variables > config variables
value = self.locals.get(node.name, variables.get(node.name.lower()))
value = self.locals.get(
node.name, variables.get(node.name, variables.get(node.name.lower()))
)
if isinstance(value, list):
return exp.convert(
tuple(
Expand Down Expand Up @@ -532,7 +538,9 @@ def var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.

def blueprint_var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
"""Returns the value of the specified blueprint variable, or the default value if it doesn't exist."""
return (self.locals.get(c.SQLMESH_BLUEPRINT_VARS) or {}).get(var_name.lower(), default)
return (self.locals.get(c.SQLMESH_BLUEPRINT_VARS) or {}).get(var_name) or (
self.locals.get(c.SQLMESH_BLUEPRINT_VARS) or {}
).get(var_name.lower(), default)

@property
def variables(self) -> t.Dict[str, t.Any]:
Expand Down
4 changes: 3 additions & 1 deletion sqlmesh/core/model/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,9 @@ def var(var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:

@staticmethod
def blueprint_var(var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
return (blueprint_variables or {}).get(var_name.lower(), default)
return (blueprint_variables or {}).get(var_name) or (blueprint_variables or {}).get(
var_name.lower(), default
)

env = prepare_env(python_env)
local_env = dict.fromkeys(("context", "evaluator"), VariableResolutionContext)
Expand Down
52 changes: 52 additions & 0 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9463,6 +9463,58 @@ def test_blueprinting_with_quotes(tmp_path: Path) -> None:
assert t.cast(exp.Query, m2.render_query()).sql() == '''SELECT 'c d' AS "c1", "c d" AS "c2"'''


def test_blueprinting_with_uppercase_blueprint_names(tmp_path: Path) -> None:
init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY)

template_with_uppercase_vars = tmp_path / "models/template_with_uppercase_vars.sql"
template_with_uppercase_vars.parent.mkdir(parents=True, exist_ok=True)
template_with_uppercase_vars.write_text(
"""
MODEL (
name @{Customer_Name}.my_table,
blueprints (
(Customer_Name := customer1, Field_A := 'value1', Field_B := 100),
(Customer_Name := customer2, Field_A := 'value2', Field_B := 200),
),
);

SELECT
@Customer_Name AS customer_name,
@Field_A AS field_a_macro,
@{Field_B} AS field_b_identifier,
@BLUEPRINT_VAR('Field_A') AS field_a_func,
@BLUEPRINT_VAR('Field_B') AS field_b_func_lower
"""
)

ctx = Context(
config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")), paths=tmp_path
)
assert len(ctx.models) == 2

m1 = ctx.get_model('"memory"."customer1"."my_table"', raise_if_missing=True)
m2 = ctx.get_model('"memory"."customer2"."my_table"', raise_if_missing=True)

# Verify that uppercase references in the query work correctly
query1 = t.cast(exp.Query, m1.render_query()).sql()
query2 = t.cast(exp.Query, m2.render_query()).sql()

assert '"customer1"' in query1
assert "'value1'" in query1
assert "100" in query1

assert '"customer2"' in query2
assert "'value2'" in query2
assert "200" in query2

# Verify exact query structure
expected_query1 = '''SELECT "customer1" AS "customer_name", 'value1' AS "field_a_macro", "100" AS "field_b_identifier", 'value1' AS "field_a_func", 100 AS "field_b_func_lower"'''
expected_query2 = '''SELECT "customer2" AS "customer_name", 'value2' AS "field_a_macro", "200" AS "field_b_identifier", 'value2' AS "field_a_func", 200 AS "field_b_func_lower"'''

assert query1 == expected_query1
assert query2 == expected_query2


def test_blueprint_variable_precedence_sql(tmp_path: Path, assert_exp_eq: t.Callable) -> None:
init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY)

Expand Down