diff --git a/sqlmesh/core/macros.py b/sqlmesh/core/macros.py index ec5b2567f4..a72bf4605a 100644 --- a/sqlmesh/core/macros.py +++ b/sqlmesh/core/macros.py @@ -361,8 +361,37 @@ def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] | return None if isinstance(result, (tuple, list)): - return [self.parse_one(item) for item in result if item is not None] - return self.parse_one(result) + result = [self.parse_one(item) for item in result if item is not None] + + if ( + len(result) == 1 + and isinstance(result[0], (exp.Array, exp.Tuple)) + and node.find_ancestor(MacroFunc) + ): + """ + if: + - the output of evaluating this node is being passed as an argument to another macro function + - and that output is something that _norm_var_arg_lambda() will unpack into varargs + > (a list containing a single item of type exp.Tuple/exp.Array) + then we will get inconsistent behaviour depending on if this node emits a list with a single item vs multiple items. + + In the first case, emitting a list containing a single array item will cause that array to get unpacked and its *members* passed to the calling macro + In the second case, emitting a list containing multiple array items will cause each item to get passed as-is to the calling macro + + To prevent this inconsistency, we wrap this node output in an exp.Array so that _norm_var_arg_lambda() can "unpack" that into the + actual argument we want to pass to the parent macro function + + Note we only do this for evaluation results that get passed as an argument to another macro, because when the final + result is given to something like SELECT, we still want that to be unpacked into a list of items like: + - SELECT ARRAY(1), ARRAY(2) + rather than a single item like: + - SELECT ARRAY(ARRAY(1), ARRAY(2)) + """ + result = [exp.Array(expressions=result)] + else: + result = self.parse_one(result) + + return result def eval_expression(self, node: t.Any) -> t.Any: """Converts a SQLGlot expression into executable Python code and evals it. diff --git a/tests/core/test_macros.py b/tests/core/test_macros.py index 0e3615d6c0..77d8fb84ae 100644 --- a/tests/core/test_macros.py +++ b/tests/core/test_macros.py @@ -363,11 +363,24 @@ def test_ast_correctness(macro_evaluator): "SELECT column LIKE a OR column LIKE b OR column LIKE c", {}, ), + ("SELECT @REDUCE([1], (x, y) -> x + y)", "SELECT 1", {}), + ("SELECT @REDUCE([1, 2], (x, y) -> x + y)", "SELECT 1 + 2", {}), + ("SELECT @REDUCE([[1]], (x, y) -> x + y)", "SELECT ARRAY(1)", {}), + ("SELECT @REDUCE([[1, 2]], (x, y) -> x + y)", "SELECT ARRAY(1, 2)", {}), ( """select @EACH([a, b, c], x -> column like x AS @SQL('@{x}_y', 'Identifier')), @x""", "SELECT column LIKE a AS a_y, column LIKE b AS b_y, column LIKE c AS c_y, '3'", {"x": "3"}, ), + ("SELECT @EACH([1], a -> [@a])", "SELECT ARRAY(1)", {}), + ("SELECT @EACH([1, 2], a -> [@a])", "SELECT ARRAY(1), ARRAY(2)", {}), + ("SELECT @REDUCE(@EACH([1], a -> [@a]), (x, y) -> x + y)", "SELECT ARRAY(1)", {}), + ( + "SELECT @REDUCE(@EACH([1, 2], a -> [@a]), (x, y) -> x + y)", + "SELECT ARRAY(1) + ARRAY(2)", + {}, + ), + ("SELECT @REDUCE([[1],[2]], (x, y) -> x + y)", "SELECT ARRAY(1) + ARRAY(2)", {}), ( """@WITH(@do_with) all_cities as (select * from city) select all_cities""", "WITH all_cities AS (SELECT * FROM city) SELECT all_cities",