@@ -409,13 +409,12 @@ def _parse_props(self: Parser) -> t.Optional[exp.Expression]:
409409 return None
410410
411411 name = key .name .lower ()
412- if name == "when_matched" :
413- value : t .Optional [t .Union [exp .Expression , t .List [exp .Expression ]]] = (
414- self ._parse_when_matched () # type: ignore
415- )
416- elif name == "time_data_type" :
412+ if name == "time_data_type" :
417413 # TODO: if we make *_data_type a convention to parse things into exp.DataType, we could make this more generic
418414 value = self ._parse_types (schema = True )
415+ elif name == "when_matched" :
416+ # Parentheses around the WHEN clauses can be used to disambiguate them from other properties
417+ value = self ._parse_wrapped (self ._parse_when_matched , optional = True )
419418 elif self ._match (TokenType .L_PAREN ):
420419 value = self .expression (exp .Tuple , expressions = self ._parse_csv (self ._parse_equality ))
421420 self ._match_r_paren ()
@@ -605,15 +604,11 @@ def _props_sql(self: Generator, expressions: t.List[exp.Expression]) -> str:
605604 size = len (expressions )
606605
607606 for i , prop in enumerate (expressions ):
608- value = prop .args .get ("value" )
609- if prop .name == "when_matched" and isinstance (value , list ):
610- output_value = ", " .join (self .sql (v ) for v in value )
611- else :
612- output_value = self .sql (prop , "value" )
613- sql = self .indent (f"{ prop .name } { output_value } " )
607+ sql = self .indent (f"{ prop .name } { self .sql (prop , 'value' )} " )
614608
615609 if i < size - 1 :
616610 sql += ","
611+
617612 props .append (self .maybe_comment (sql , expression = prop ))
618613
619614 return "\n " .join (props )
@@ -648,6 +643,15 @@ def _macro_func_sql(self: Generator, expression: MacroFunc) -> str:
648643 return self .maybe_comment (sql , expression )
649644
650645
646+ def _whens_sql (self : Generator , expression : exp .Whens ) -> str :
647+ if isinstance (expression .parent , exp .Merge ):
648+ return self .whens_sql (expression )
649+
650+ # If the `WHEN` clauses aren't part of a MERGE statement (e.g. they
651+ # appear in the `MODEL` DDL), then we will wrap them with parentheses.
652+ return self .wrap (self .expressions (expression , sep = " " , indent = False ))
653+
654+
651655def _override (klass : t .Type [Tokenizer | Parser ], func : t .Callable ) -> None :
652656 name = func .__name__
653657 setattr (klass , f"_{ name } " , getattr (klass , name ))
@@ -901,6 +905,7 @@ def extend_sqlglot() -> None:
901905 ModelKind : _model_kind_sql ,
902906 PythonCode : lambda self , e : self .expressions (e , sep = "\n " , indent = False ),
903907 StagedFilePath : lambda self , e : self .table_sql (e ),
908+ exp .Whens : _whens_sql ,
904909 }
905910 )
906911
0 commit comments