Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions sqlmesh/core/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,17 @@ def _macro_str_replace(text: str) -> str:
return f"self.template({text}, locals())"


class CaseInsensitiveMapping(t.Dict[str, t.Any]):
def __init__(self, data: t.Dict[str, t.Any]) -> None:
super().__init__(data)

def __getitem__(self, key: str) -> t.Any:
return super().__getitem__(key.lower())

def get(self, key: str, default: t.Any = None, /) -> t.Any:
return super().get(key.lower(), default)


class MacroDialect(Python):
class Generator(Python.Generator):
TRANSFORMS = {
Expand Down Expand Up @@ -256,14 +267,18 @@ def evaluate_macros(
changed = True
variables = self.variables

if node.name not in self.locals and node.name.lower() not in variables:
# This makes all variables case-insensitive, e.g. @X is the same as @x. We do this
# for consistency, since `variables` and `blueprint_variables` are normalized.
var_name = node.name.lower()

if var_name not in self.locals and var_name not in variables:
Comment on lines -259 to +274
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change has a subtle side-effect: today, despite being counter-intuitive, one can actually reference a variable defined in Python within a model:

# model.sql
MODEL (name test, kind full);
SELECT @x_plus_one() AS xp1, @X AS x

# macros.py
from sqlmesh import macro

X = 1

@macro()
def x_plus_one(evaluator):
    return X + 1

The above configuration is correct and the model renders into:

SELECT
  2 AS "xp1",
  1 AS "x"

After this PR, the reference @X will be invalid because we'll look up the lowercase "x" in locals, which won't match the Python uppercase key "X", resulting in:

Macro variable 'X' is undefined.

I think this is fine. I doubt anyone's relying on Python variables being exposed to SQL models today. I don't think we've even documented this anywhere, it's just a weird side-effect of populating locals with python_env items.

if not isinstance(node.parent, StagedFilePath):
raise SQLMeshError(f"Macro variable '{node.name}' is undefined.")

return node

# Precedence order is locals (e.g. @DEF) > blueprint variables > config variables
value = self.locals.get(node.name, variables.get(node.name.lower()))
value = self.locals.get(var_name, variables.get(var_name))
if isinstance(value, list):
return exp.convert(
tuple(
Expand Down Expand Up @@ -313,11 +328,11 @@ def template(self, text: t.Any, local_variables: t.Dict[str, t.Any]) -> str:
"""
# We try to convert all variables into sqlglot expressions because they're going to be converted
# into strings; in sql we don't convert strings because that would result in adding quotes
mapping = {
k: convert_sql(v, self.dialect)
base_mapping = {
k.lower(): convert_sql(v, self.dialect)
for k, v in chain(self.variables.items(), self.locals.items(), local_variables.items())
}
return MacroStrTemplate(str(text)).safe_substitute(mapping)
return MacroStrTemplate(str(text)).safe_substitute(CaseInsensitiveMapping(base_mapping))
Comment on lines -316 to +335
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea here is to make variable lookups in SQL case-insensitive; the template method is only used for substituting variables in SQL, as far as I can tell.


def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] | None:
if isinstance(node, MacroDef):
Expand All @@ -327,7 +342,9 @@ def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] |
args[0] if len(args) == 1 else exp.Tuple(expressions=list(args))
)
else:
self.locals[node.name] = self.transform(node.expression)
# Make variables defined through `@DEF` case-insensitive
self.locals[node.name.lower()] = self.transform(node.expression)
Comment on lines -330 to +346
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, this should just affect SQL models that have @DEF– now the defined variables are automatically stored as lowercase.


return node

if isinstance(node, (MacroSQL, MacroStrReplace)):
Expand Down Expand Up @@ -630,7 +647,7 @@ def substitute(
) -> exp.Expression | t.List[exp.Expression] | None:
if isinstance(node, (exp.Identifier, exp.Var)):
if not isinstance(node.parent, exp.Column):
name = node.name
name = node.name.lower()
if name in args:
return args[name].copy()
if name in evaluator.locals:
Expand Down Expand Up @@ -663,7 +680,7 @@ def substitute(
return expressions, lambda args: func.this.transform(
substitute,
{
expression.name: arg
expression.name.lower(): arg
for expression, arg in zip(
func.expressions, args.expressions if isinstance(args, exp.Tuple) else [args]
)
Expand Down
13 changes: 13 additions & 0 deletions tests/core/test_macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,16 @@ def test_ast_correctness(macro_evaluator):
"SELECT 'a' + a_z + 'c' + c_a, 'b' + b_z + 'c' + c_b",
{"y": "c"},
),
(
"""select @each(['a'], x -> @X)""",
"SELECT 'a'",
{},
),
(
"""select @each(['a'], X -> @x)""",
"SELECT 'a'",
{},
),
(
'"is_@{x}"',
'"is_b"',
Expand Down Expand Up @@ -1112,7 +1122,9 @@ def test_macro_with_spaces():

for sql, expected in (
("@x", '"a b"'),
("@X", '"a b"'),
("@{x}", '"a b"'),
("@{X}", '"a b"'),
("a_@x", '"a_a b"'),
("a.@x", 'a."a b"'),
("@y", "'a b'"),
Expand All @@ -1121,6 +1133,7 @@ def test_macro_with_spaces():
("a.@{y}", 'a."a b"'),
("@z", 'a."b c"'),
("d.@z", 'd.a."b c"'),
("@'test_@{X}_suffix'", "'test_a b_suffix'"),
):
assert evaluator.transform(parse_one(sql)).sql() == expected

Expand Down
43 changes: 40 additions & 3 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9377,9 +9377,9 @@ def test_model_blueprinting(tmp_path: Path) -> None:
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
)

blueprint_sql = tmp_path / "macros" / "identity_macro.py"
blueprint_sql.parent.mkdir(parents=True, exist_ok=True)
blueprint_sql.write_text(
identity_macro = tmp_path / "macros" / "identity_macro.py"
identity_macro.parent.mkdir(parents=True, exist_ok=True)
identity_macro.write_text(
"""from sqlmesh import macro

@macro()
Expand Down Expand Up @@ -11623,3 +11623,40 @@ def test_use_original_sql():
assert model.query_.sql == "SELECT 1 AS one, 2 AS two"
assert model.pre_statements_[0].sql == "CREATE TABLE pre (a INT)"
assert model.post_statements_[0].sql == "CREATE TABLE post (b INT)"


def test_case_sensitive_macro_locals(tmp_path: Path) -> None:
init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY)

db_path = str(tmp_path / "db.db")
db_connection = DuckDBConnectionConfig(database=db_path)

config = Config(
gateways={"gw": GatewayConfig(connection=db_connection)},
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
)

macro_file = tmp_path / "macros" / "some_macro_with_globals.py"
macro_file.parent.mkdir(parents=True, exist_ok=True)
macro_file.write_text(
"""from sqlmesh import macro

x = 1
X = 2

@macro()
def my_macro(evaluator):
assert evaluator.locals.get("x") == 1
assert evaluator.locals.get("X") == 2

return x + X
"""
)
test_model = tmp_path / "models" / "test_model.sql"
test_model.parent.mkdir(parents=True, exist_ok=True)
test_model.write_text("MODEL (name test_model, kind FULL); SELECT @my_macro() AS c")

context = Context(paths=tmp_path, config=config)
model = context.get_model("test_model", raise_if_missing=True)

assert model.render_query_or_raise().sql() == 'SELECT 3 AS "c"'