diff --git a/sqlmesh/core/macros.py b/sqlmesh/core/macros.py index 554638aec7..b58817950d 100644 --- a/sqlmesh/core/macros.py +++ b/sqlmesh/core/macros.py @@ -128,6 +128,17 @@ def _macro_str_replace(text: str) -> str: return f"self.template({text}, locals())" +class CaseInsensitiveMapping(t.Dict[str, t.Any]): + def __init__(self, data: t.Dict[str, t.Any]) -> None: + super().__init__(data) + + def __getitem__(self, key: str) -> t.Any: + return super().__getitem__(key.lower()) + + def get(self, key: str, default: t.Any = None, /) -> t.Any: + return super().get(key.lower(), default) + + class MacroDialect(Python): class Generator(Python.Generator): TRANSFORMS = { @@ -256,14 +267,18 @@ def evaluate_macros( changed = True variables = self.variables - if node.name not in self.locals and node.name.lower() not in variables: + # This makes all variables case-insensitive, e.g. @X is the same as @x. We do this + # for consistency, since `variables` and `blueprint_variables` are normalized. + var_name = node.name.lower() + + if var_name not in self.locals and var_name not in variables: if not isinstance(node.parent, StagedFilePath): raise SQLMeshError(f"Macro variable '{node.name}' is undefined.") return node # Precedence order is locals (e.g. @DEF) > blueprint variables > config variables - value = self.locals.get(node.name, variables.get(node.name.lower())) + value = self.locals.get(var_name, variables.get(var_name)) if isinstance(value, list): return exp.convert( tuple( @@ -313,11 +328,11 @@ def template(self, text: t.Any, local_variables: t.Dict[str, t.Any]) -> str: """ # We try to convert all variables into sqlglot expressions because they're going to be converted # into strings; in sql we don't convert strings because that would result in adding quotes - mapping = { - k: convert_sql(v, self.dialect) + base_mapping = { + k.lower(): convert_sql(v, self.dialect) for k, v in chain(self.variables.items(), self.locals.items(), local_variables.items()) } - return MacroStrTemplate(str(text)).safe_substitute(mapping) + return MacroStrTemplate(str(text)).safe_substitute(CaseInsensitiveMapping(base_mapping)) def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] | None: if isinstance(node, MacroDef): @@ -327,7 +342,9 @@ def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] | args[0] if len(args) == 1 else exp.Tuple(expressions=list(args)) ) else: - self.locals[node.name] = self.transform(node.expression) + # Make variables defined through `@DEF` case-insensitive + self.locals[node.name.lower()] = self.transform(node.expression) + return node if isinstance(node, (MacroSQL, MacroStrReplace)): @@ -630,7 +647,7 @@ def substitute( ) -> exp.Expression | t.List[exp.Expression] | None: if isinstance(node, (exp.Identifier, exp.Var)): if not isinstance(node.parent, exp.Column): - name = node.name + name = node.name.lower() if name in args: return args[name].copy() if name in evaluator.locals: @@ -663,7 +680,7 @@ def substitute( return expressions, lambda args: func.this.transform( substitute, { - expression.name: arg + expression.name.lower(): arg for expression, arg in zip( func.expressions, args.expressions if isinstance(args, exp.Tuple) else [args] ) diff --git a/tests/core/test_macros.py b/tests/core/test_macros.py index 77d8fb84ae..fb10f64b27 100644 --- a/tests/core/test_macros.py +++ b/tests/core/test_macros.py @@ -292,6 +292,16 @@ def test_ast_correctness(macro_evaluator): "SELECT 'a' + a_z + 'c' + c_a, 'b' + b_z + 'c' + c_b", {"y": "c"}, ), + ( + """select @each(['a'], x -> @X)""", + "SELECT 'a'", + {}, + ), + ( + """select @each(['a'], X -> @x)""", + "SELECT 'a'", + {}, + ), ( '"is_@{x}"', '"is_b"', @@ -1112,7 +1122,9 @@ def test_macro_with_spaces(): for sql, expected in ( ("@x", '"a b"'), + ("@X", '"a b"'), ("@{x}", '"a b"'), + ("@{X}", '"a b"'), ("a_@x", '"a_a b"'), ("a.@x", 'a."a b"'), ("@y", "'a b'"), @@ -1121,6 +1133,7 @@ def test_macro_with_spaces(): ("a.@{y}", 'a."a b"'), ("@z", 'a."b c"'), ("d.@z", 'd.a."b c"'), + ("@'test_@{X}_suffix'", "'test_a b_suffix'"), ): assert evaluator.transform(parse_one(sql)).sql() == expected diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 1511e37c53..ebe4d11a20 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -9377,9 +9377,9 @@ def test_model_blueprinting(tmp_path: Path) -> None: model_defaults=ModelDefaultsConfig(dialect="duckdb"), ) - blueprint_sql = tmp_path / "macros" / "identity_macro.py" - blueprint_sql.parent.mkdir(parents=True, exist_ok=True) - blueprint_sql.write_text( + identity_macro = tmp_path / "macros" / "identity_macro.py" + identity_macro.parent.mkdir(parents=True, exist_ok=True) + identity_macro.write_text( """from sqlmesh import macro @macro() @@ -11623,3 +11623,40 @@ def test_use_original_sql(): assert model.query_.sql == "SELECT 1 AS one, 2 AS two" assert model.pre_statements_[0].sql == "CREATE TABLE pre (a INT)" assert model.post_statements_[0].sql == "CREATE TABLE post (b INT)" + + +def test_case_sensitive_macro_locals(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) + + db_path = str(tmp_path / "db.db") + db_connection = DuckDBConnectionConfig(database=db_path) + + config = Config( + gateways={"gw": GatewayConfig(connection=db_connection)}, + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ) + + macro_file = tmp_path / "macros" / "some_macro_with_globals.py" + macro_file.parent.mkdir(parents=True, exist_ok=True) + macro_file.write_text( + """from sqlmesh import macro + +x = 1 +X = 2 + +@macro() +def my_macro(evaluator): + assert evaluator.locals.get("x") == 1 + assert evaluator.locals.get("X") == 2 + + return x + X +""" + ) + test_model = tmp_path / "models" / "test_model.sql" + test_model.parent.mkdir(parents=True, exist_ok=True) + test_model.write_text("MODEL (name test_model, kind FULL); SELECT @my_macro() AS c") + + context = Context(paths=tmp_path, config=config) + model = context.get_model("test_model", raise_if_missing=True) + + assert model.render_query_or_raise().sql() == 'SELECT 3 AS "c"'