diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py index 88e09f4916..1a42480c13 100644 --- a/sqlmesh/core/dialect.py +++ b/sqlmesh/core/dialect.py @@ -419,6 +419,20 @@ def _parse_limit( return macro +def _parse_value(self: Parser, values: bool = True) -> t.Optional[exp.Expression]: + wrapped = self._match(TokenType.L_PAREN, advance=False) + + # The base _parse_value method always constructs a Tuple instance. This is problematic when + # generating values with a macro function, because it's impossible to tell whether the user's + # intention was to construct a row or a column with the VALUES expression. To avoid this, we + # amend the AST such that the Tuple is replaced by the macro function call itself. + expr = self.__parse_value() # type: ignore + if expr and not wrapped and isinstance(seq_get(expr.expressions, 0), MacroFunc): + return expr.expressions[0] + + return expr + + def _parse_macro_or_clause(self: Parser, parser: t.Callable) -> t.Optional[exp.Expression]: return _parse_macro(self) if self._match(TokenType.PARAMETER) else parser() @@ -1063,6 +1077,7 @@ def extend_sqlglot() -> None: _override(Parser, _parse_with) _override(Parser, _parse_having) _override(Parser, _parse_limit) + _override(Parser, _parse_value) _override(Parser, _parse_lambda) _override(Parser, _parse_types) _override(Parser, _parse_if) diff --git a/tests/core/test_macros.py b/tests/core/test_macros.py index c235430a69..f1beeeb3b5 100644 --- a/tests/core/test_macros.py +++ b/tests/core/test_macros.py @@ -575,6 +575,26 @@ def test_ast_correctness(macro_evaluator): "SELECT 3", {}, ), + ( + "SELECT * FROM (VALUES @EACH([1, 2, 3], v -> (v)) ) AS v", + "SELECT * FROM (VALUES (1), (2), (3)) AS v", + {}, + ), + ( + "SELECT * FROM (VALUES (@EACH([1, 2, 3], v -> (v))) ) AS v", + "SELECT * FROM (VALUES ((1), (2), (3))) AS v", + {}, + ), + ( + "SELECT * FROM (VALUES @EACH([1, 2, 3], v -> (v, @EVAL(@v + 1))) ) AS v", + "SELECT * FROM (VALUES (1, 2), (2, 3), (3, 4)) AS v", + {}, + ), + ( + "SELECT * FROM (VALUES (@EACH([1, 2, 3], v -> (v, @EVAL(@v + 1)))) ) AS v", + "SELECT * FROM (VALUES ((1, 2), (2, 3), (3, 4))) AS v", + {}, + ), ], ) def test_macro_functions(macro_evaluator: MacroEvaluator, assert_exp_eq, sql, expected, args):