Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions sqlmesh/core/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,37 @@ def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] |
return None

if isinstance(result, (tuple, list)):
return [self.parse_one(item) for item in result if item is not None]
return self.parse_one(result)
result = [self.parse_one(item) for item in result if item is not None]

if (
len(result) == 1
and isinstance(result[0], (exp.Array, exp.Tuple))
and node.find_ancestor(MacroFunc)
):
"""
if:
- the output of evaluating this node is being passed as an argument to another macro function
- and that output is something that _norm_var_arg_lambda() will unpack into varargs
> (a list containing a single item of type exp.Tuple/exp.Array)
then we will get inconsistent behaviour depending on if this node emits a list with a single item vs multiple items.

In the first case, emitting a list containing a single array item will cause that array to get unpacked and its *members* passed to the calling macro
In the second case, emitting a list containing multiple array items will cause each item to get passed as-is to the calling macro

To prevent this inconsistency, we wrap this node output in an exp.Array so that _norm_var_arg_lambda() can "unpack" that into the
actual argument we want to pass to the parent macro function

Note we only do this for evaluation results that get passed as an argument to another macro, because when the final
result is given to something like SELECT, we still want that to be unpacked into a list of items like:
- SELECT ARRAY(1), ARRAY(2)
rather than a single item like:
- SELECT ARRAY(ARRAY(1), ARRAY(2))
"""
result = [exp.Array(expressions=result)]
else:
result = self.parse_one(result)

return result

def eval_expression(self, node: t.Any) -> t.Any:
"""Converts a SQLGlot expression into executable Python code and evals it.
Expand Down
13 changes: 13 additions & 0 deletions tests/core/test_macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,11 +363,24 @@ def test_ast_correctness(macro_evaluator):
"SELECT column LIKE a OR column LIKE b OR column LIKE c",
{},
),
("SELECT @REDUCE([1], (x, y) -> x + y)", "SELECT 1", {}),
("SELECT @REDUCE([1, 2], (x, y) -> x + y)", "SELECT 1 + 2", {}),
("SELECT @REDUCE([[1]], (x, y) -> x + y)", "SELECT ARRAY(1)", {}),
("SELECT @REDUCE([[1, 2]], (x, y) -> x + y)", "SELECT ARRAY(1, 2)", {}),
(
"""select @EACH([a, b, c], x -> column like x AS @SQL('@{x}_y', 'Identifier')), @x""",
"SELECT column LIKE a AS a_y, column LIKE b AS b_y, column LIKE c AS c_y, '3'",
{"x": "3"},
),
("SELECT @EACH([1], a -> [@a])", "SELECT ARRAY(1)", {}),
("SELECT @EACH([1, 2], a -> [@a])", "SELECT ARRAY(1), ARRAY(2)", {}),
("SELECT @REDUCE(@EACH([1], a -> [@a]), (x, y) -> x + y)", "SELECT ARRAY(1)", {}),
(
"SELECT @REDUCE(@EACH([1, 2], a -> [@a]), (x, y) -> x + y)",
"SELECT ARRAY(1) + ARRAY(2)",
{},
),
("SELECT @REDUCE([[1],[2]], (x, y) -> x + y)", "SELECT ARRAY(1) + ARRAY(2)", {}),
(
"""@WITH(@do_with) all_cities as (select * from city) select all_cities""",
"WITH all_cities AS (SELECT * FROM city) SELECT all_cities",
Expand Down