66
77from pydantic import Field
88from sqlglot import exp
9+ from sqlglot .optimizer .normalize_identifiers import normalize_identifiers
10+ from sqlglot .optimizer .qualify_columns import quote_identifiers
911from sqlglot .time import format_time
1012
1113from sqlmesh .core import dialect as d
3133 from typing_extensions import Annotated , Literal
3234
3335
36+ if t .TYPE_CHECKING :
37+ MODEL_KIND = t .TypeVar ("MODEL_KIND" , bound = "_ModelKind" )
38+
39+
3440class ModelKindMixin :
3541 @property
3642 def model_kind_name (self ) -> t .Optional [ModelKindName ]:
@@ -153,53 +159,76 @@ def data_hash_values(self) -> t.List[t.Optional[str]]:
153159
154160
155161class TimeColumn (PydanticModel ):
156- column : str
162+ column : exp . Expression
157163 format : t .Optional [str ] = None
158164
@classmethod
def validator(cls) -> classmethod:
    """Build a "before" field validator for a ``time_column`` field.

    The returned validator accepts any of the following inputs:

    * a SQLGlot tuple whose first element is the column and whose optional
      second element supplies the format;
    * any other SQLGlot expression (a bare identifier is wrapped in a column);
    * a SQL string, parsed with the dialect found in the validated values;
    * a dict with a ``column`` key and an optional ``format`` key;
    * an existing ``TimeColumn``, which is returned unchanged.

    For every input except an existing ``TimeColumn`` the column expression
    is normalized and quoted for the dialect, and the dialect is recorded in
    the expression's ``meta``.

    Raises (from the returned validator):
        ConfigError: if the value has an unsupported type, is an empty tuple,
            or is a dict without a ``column`` key.
    """

    def _time_column_validator(v: t.Any, values: t.Any) -> TimeColumn:
        # `values` is either a plain dict of already-validated fields or an
        # object exposing that dict through `.data`.
        values = values if isinstance(values, dict) else values.data
        dialect = values.get("dialect")

        if isinstance(v, TimeColumn):
            return v

        time_format: t.Optional[str] = None
        if isinstance(v, exp.Tuple):
            if not v.expressions:
                # Fail with the same error type as every other invalid input
                # instead of leaking an IndexError.
                raise ConfigError(f"Invalid time_column: '{ v } '.")
            column_expr = v.expressions[0]
            column = (
                exp.column(column_expr)
                if isinstance(column_expr, exp.Identifier)
                else column_expr
            )
            if len(v.expressions) > 1:
                time_format = v.expressions[1].name
        elif isinstance(v, exp.Expression):
            column = exp.column(v) if isinstance(v, exp.Identifier) else v
        elif isinstance(v, str):
            column = d.parse_one(v, dialect=dialect)
            # Drop the "sql" meta entry if the parser attached one; tolerate
            # its absence rather than raising KeyError.
            column.meta.pop("sql", None)
        elif isinstance(v, dict):
            if "column" not in v:
                raise ConfigError(f"Invalid time_column: '{ v } '.")
            column_raw = v["column"]
            # NOTE(review): unlike the plain-string branch, the "sql" meta
            # entry is not removed here -- confirm whether that is intended.
            column = (
                d.parse_one(column_raw, dialect=dialect)
                if isinstance(column_raw, str)
                else column_raw
            )
            time_format = v.get("format")
        else:
            raise ConfigError(f"Invalid time_column: '{ v } '.")

        column = quote_identifiers(
            normalize_identifiers(column, dialect=dialect), dialect=dialect
        )
        column.meta["dialect"] = dialect

        return TimeColumn(column=column, format=time_format)

    return field_validator("time_column", mode="before")(_time_column_validator)
177206
@field_validator("column", mode="before")
@classmethod
def _column_validator(cls, v: t.Union[str, exp.Expression]) -> exp.Expression:
    """Reject an empty column and coerce a string column into an expression.

    Raises:
        ConfigError: if ``v`` is falsy (for example an empty string or None).
    """
    if v:
        # Strings are converted to column expressions; anything else is
        # passed through untouched.
        return exp.to_column(v) if isinstance(v, str) else v
    raise ConfigError("Time Column cannot be empty.")
184215
@property
def expression(self) -> exp.Expression:
    """Return this time column as a SQLGlot expression.

    When a format is set, the result is a tuple of the column and the format
    as a string literal; otherwise it is the column expression itself.
    """
    if self.format:
        format_literal = exp.Literal.string(self.format)
        return exp.Tuple(expressions=[self.column, format_literal])

    return self.column
193223
194- def to_expression (self , dialect : str ) -> exp .Column | exp . Tuple :
224+ def to_expression (self , dialect : str ) -> exp .Expression :
195225 """Convert this pydantic model into a time_column SQLGlot expression."""
196- column = exp .to_column (self .column )
197226 if not self .format :
198- return column
227+ return self . column
199228
200229 return exp .Tuple (
201230 expressions = [
202- column ,
231+ self . column ,
203232 exp .Literal .string (
204233 format_time (self .format , d .Dialect .get_or_raise (dialect ).INVERSE_TIME_MAPPING )
205234 ),
@@ -211,6 +240,7 @@ def to_property(self, dialect: str = "") -> exp.Property:
211240
212241
213242class _Incremental (_ModelKind ):
243+ dialect : str = ""
214244 batch_size : t .Optional [SQLGlotPositiveInt ] = None
215245 lookback : t .Optional [SQLGlotPositiveInt ] = None
216246 forward_only : SQLGlotBool = False
@@ -335,6 +365,7 @@ class FullKind(_ModelKind):
335365
336366
337367class _SCDType2Kind (_ModelKind ):
368+ dialect : str = ""
338369 unique_key : SQLGlotListOfFields
339370 valid_from_name : SQLGlotString = "valid_from"
340371 valid_to_name : SQLGlotString = "valid_to"
@@ -344,6 +375,16 @@ class _SCDType2Kind(_ModelKind):
344375 forward_only : SQLGlotBool = True
345376 disable_restatement : SQLGlotBool = True
346377
@field_validator("time_data_type", mode="before")
@classmethod
def _time_data_type_validator(
    cls, v: t.Union[str, exp.Expression], values: t.Any
) -> exp.Expression:
    """Coerce the raw ``time_data_type`` value into a SQLGlot data type.

    The type is built with the dialect found in the already-validated
    values, which are read from ``values`` directly when it is a dict and
    from ``values.data`` otherwise.
    """
    validated = values if isinstance(values, dict) else values.data
    # A non-DataType expression is reduced to its name so it can be handed
    # to DataType.build in place of the expression itself.
    is_plain_expression = isinstance(v, exp.Expression) and not isinstance(v, exp.DataType)
    type_spec = v.name if is_plain_expression else v
    return exp.DataType.build(type_spec, dialect=validated.get("dialect"))
387+
347388 @property
348389 def managed_columns (self ) -> t .Dict [str , exp .DataType ]:
349390 return {
@@ -425,20 +466,16 @@ def _model_kind_validator(cls: t.Type, v: t.Any, values: t.Dict[str, t.Any]) ->
425466 if isinstance (v , d .ModelKind )
426467 else v
427468 )
428- time_data_type = props .pop ("time_data_type" , None )
429- if isinstance (time_data_type , exp .Expression ) and not isinstance (
430- time_data_type , exp .DataType
431- ):
432- time_data_type = time_data_type .name
433- if time_data_type :
434- props ["time_data_type" ] = exp .DataType .build (time_data_type , dialect = dialect )
435469 name = v .this if isinstance (v , d .ModelKind ) else props .get ("name" )
436470 # We want to ensure whatever name is provided to construct the class is the same name that will be
437471 # found inside the class itself in order to avoid a change during plan/apply for legacy aliases.
438472 # Ex: Pass in `SCD_TYPE_2` then we want to ensure we get `SCD_TYPE_2` as the kind name
439473 # instead of `SCD_TYPE_2_BY_TIME`.
440474 props ["name" ] = name
441- return model_kind_type_from_name (name )(** props )
475+ kind_type = model_kind_type_from_name (name )
476+ if "dialect" in kind_type .all_fields () and props .get ("dialect" ) is None :
477+ props ["dialect" ] = dialect
478+ return kind_type (** props )
442479
443480 name = (v .name if isinstance (v , exp .Expression ) else str (v )).upper ()
444481 return model_kind_type_from_name (name )(name = name ) # type: ignore
0 commit comments