Skip to content

Commit 385e079

Browse files
authored
Fix!: Quote identifiers when deserializing model SQL attributes (#2267)
1 parent 071d6a8 commit 385e079

File tree

15 files changed: 167 additions and 116 deletions.

sqlmesh/core/audit/definition.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -217,7 +217,7 @@ def render_query(
217217

218218
node = t.cast(_Model, node)
219219
if node.time_column:
220-
where = exp.column(node.time_column.column).between(
220+
where = node.time_column.column.between(
221221
node.convert_to_time_column(start or c.EPOCH, columns_to_types),
222222
node.convert_to_time_column(end or c.EPOCH, columns_to_types),
223223
)

sqlmesh/core/engine_adapter/base.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1071,7 +1071,7 @@ def insert_overwrite_by_time_partition(
10711071
time_formatter: t.Callable[
10721072
[TimeLike, t.Optional[t.Dict[str, exp.DataType]]], exp.Expression
10731073
],
1074-
time_column: TimeColumn | exp.Column | str,
1074+
time_column: TimeColumn | exp.Expression | str,
10751075
columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
10761076
**kwargs: t.Any,
10771077
) -> None:
@@ -1083,7 +1083,7 @@ def insert_overwrite_by_time_partition(
10831083
if isinstance(time_column, TimeColumn):
10841084
time_column = time_column.column
10851085
where = exp.Between(
1086-
this=exp.to_column(time_column),
1086+
this=exp.to_column(time_column) if isinstance(time_column, str) else time_column,
10871087
low=low,
10881088
high=high,
10891089
)
@@ -1939,10 +1939,12 @@ def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.An
19391939
**kwargs,
19401940
}
19411941

1942+
expression = expression.copy()
1943+
19421944
if quote:
19431945
quote_identifiers(expression)
19441946

1945-
return expression.sql(**sql_gen_kwargs) # type: ignore
1947+
return expression.sql(**sql_gen_kwargs, copy=False) # type: ignore
19461948

19471949
def _get_data_objects(
19481950
self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None

sqlmesh/core/model/definition.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -509,12 +509,12 @@ def convert_to_time_column(
509509
if columns_to_types is None:
510510
columns_to_types = self.columns_to_types_or_raise
511511

512-
if self.time_column.column not in columns_to_types:
512+
if self.time_column.column.name not in columns_to_types:
513513
raise ConfigError(
514-
f"Time column '{self.time_column.column}' not found in model '{self.name}'."
514+
f"Time column '{self.time_column.column.sql(dialect=self.dialect)}' not found in model '{self.name}'."
515515
)
516516

517-
time_column_type = columns_to_types[self.time_column.column]
517+
time_column_type = columns_to_types[self.time_column.column.name]
518518

519519
return to_time_column(time, time_column_type, self.time_column.format)
520520
return exp.convert(time)
@@ -726,7 +726,7 @@ def _data_hash_values(self) -> t.List[str]:
726726
data.append(gen(value))
727727

728728
if isinstance(self.kind, IncrementalByTimeRangeKind):
729-
data.append(self.kind.time_column.column)
729+
data.append(gen(self.kind.time_column.column))
730730
data.append(self.kind.time_column.format)
731731
elif isinstance(self.kind, IncrementalByUniqueKeyKind):
732732
data.extend((gen(k) for k in self.kind.unique_key))

sqlmesh/core/model/kind.py

Lines changed: 67 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,8 @@
66

77
from pydantic import Field
88
from sqlglot import exp
9+
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
10+
from sqlglot.optimizer.qualify_columns import quote_identifiers
911
from sqlglot.time import format_time
1012

1113
from sqlmesh.core import dialect as d
@@ -31,6 +33,10 @@
3133
from typing_extensions import Annotated, Literal
3234

3335

36+
if t.TYPE_CHECKING:
37+
MODEL_KIND = t.TypeVar("MODEL_KIND", bound="_ModelKind")
38+
39+
3440
class ModelKindMixin:
3541
@property
3642
def model_kind_name(self) -> t.Optional[ModelKindName]:
@@ -153,53 +159,76 @@ def data_hash_values(self) -> t.List[t.Optional[str]]:
153159

154160

155161
class TimeColumn(PydanticModel):
156-
column: str
162+
column: exp.Expression
157163
format: t.Optional[str] = None
158164

159165
@classmethod
160166
def validator(cls) -> classmethod:
161-
def _time_column_validator(v: t.Any) -> TimeColumn:
167+
def _time_column_validator(v: t.Any, values: t.Any) -> TimeColumn:
168+
values = values if isinstance(values, dict) else values.data
169+
dialect = values.get("dialect")
162170
if isinstance(v, exp.Tuple):
163-
kwargs = {
164-
key: v.expressions[i].name
165-
for i, key in enumerate(("column", "format")[: len(v.expressions)])
166-
}
167-
return TimeColumn(**kwargs)
168-
169-
if isinstance(v, exp.Expression):
170-
return TimeColumn(column=v.name)
171+
column_expr = v.expressions[0]
172+
column = (
173+
exp.column(column_expr)
174+
if isinstance(column_expr, exp.Identifier)
175+
else column_expr
176+
)
177+
format = v.expressions[1].name if len(v.expressions) > 1 else None
178+
elif isinstance(v, exp.Expression):
179+
column = exp.column(v) if isinstance(v, exp.Identifier) else v
180+
format = None
181+
elif isinstance(v, str):
182+
column = d.parse_one(v, dialect=dialect)
183+
column.meta.pop("sql")
184+
format = None
185+
elif isinstance(v, dict):
186+
column_raw = v["column"]
187+
column = (
188+
d.parse_one(column_raw, dialect=dialect)
189+
if isinstance(column_raw, str)
190+
else column_raw
191+
)
192+
format = v.get("format")
193+
elif isinstance(v, TimeColumn):
194+
return v
195+
else:
196+
raise ConfigError(f"Invalid time_column: '{v}'.")
197+
198+
column = quote_identifiers(
199+
normalize_identifiers(column, dialect=dialect), dialect=dialect
200+
)
201+
column.meta["dialect"] = dialect
171202

172-
if isinstance(v, str):
173-
return TimeColumn(column=v)
174-
return v
203+
return TimeColumn(column=column, format=format)
175204

176205
return field_validator("time_column", mode="before")(_time_column_validator)
177206

178207
@field_validator("column", mode="before")
179208
@classmethod
180-
def _column_validator(cls, v: str) -> str:
209+
def _column_validator(cls, v: t.Union[str, exp.Expression]) -> exp.Expression:
181210
if not v:
182211
raise ConfigError("Time Column cannot be empty.")
212+
if isinstance(v, str):
213+
return exp.to_column(v)
183214
return v
184215

185216
@property
186-
def expression(self) -> exp.Column | exp.Tuple:
217+
def expression(self) -> exp.Expression:
187218
"""Convert this pydantic model into a time_column SQLGlot expression."""
188-
column = exp.to_column(self.column)
189219
if not self.format:
190-
return column
220+
return self.column
191221

192-
return exp.Tuple(expressions=[column, exp.Literal.string(self.format)])
222+
return exp.Tuple(expressions=[self.column, exp.Literal.string(self.format)])
193223

194-
def to_expression(self, dialect: str) -> exp.Column | exp.Tuple:
224+
def to_expression(self, dialect: str) -> exp.Expression:
195225
"""Convert this pydantic model into a time_column SQLGlot expression."""
196-
column = exp.to_column(self.column)
197226
if not self.format:
198-
return column
227+
return self.column
199228

200229
return exp.Tuple(
201230
expressions=[
202-
column,
231+
self.column,
203232
exp.Literal.string(
204233
format_time(self.format, d.Dialect.get_or_raise(dialect).INVERSE_TIME_MAPPING)
205234
),
@@ -211,6 +240,7 @@ def to_property(self, dialect: str = "") -> exp.Property:
211240

212241

213242
class _Incremental(_ModelKind):
243+
dialect: str = ""
214244
batch_size: t.Optional[SQLGlotPositiveInt] = None
215245
lookback: t.Optional[SQLGlotPositiveInt] = None
216246
forward_only: SQLGlotBool = False
@@ -335,6 +365,7 @@ class FullKind(_ModelKind):
335365

336366

337367
class _SCDType2Kind(_ModelKind):
368+
dialect: str = ""
338369
unique_key: SQLGlotListOfFields
339370
valid_from_name: SQLGlotString = "valid_from"
340371
valid_to_name: SQLGlotString = "valid_to"
@@ -344,6 +375,16 @@ class _SCDType2Kind(_ModelKind):
344375
forward_only: SQLGlotBool = True
345376
disable_restatement: SQLGlotBool = True
346377

378+
@field_validator("time_data_type", mode="before")
379+
@classmethod
380+
def _time_data_type_validator(
381+
cls, v: t.Union[str, exp.Expression], values: t.Any
382+
) -> exp.Expression:
383+
values = values if isinstance(values, dict) else values.data
384+
if isinstance(v, exp.Expression) and not isinstance(v, exp.DataType):
385+
v = v.name
386+
return exp.DataType.build(v, dialect=values.get("dialect"))
387+
347388
@property
348389
def managed_columns(self) -> t.Dict[str, exp.DataType]:
349390
return {
@@ -425,20 +466,16 @@ def _model_kind_validator(cls: t.Type, v: t.Any, values: t.Dict[str, t.Any]) ->
425466
if isinstance(v, d.ModelKind)
426467
else v
427468
)
428-
time_data_type = props.pop("time_data_type", None)
429-
if isinstance(time_data_type, exp.Expression) and not isinstance(
430-
time_data_type, exp.DataType
431-
):
432-
time_data_type = time_data_type.name
433-
if time_data_type:
434-
props["time_data_type"] = exp.DataType.build(time_data_type, dialect=dialect)
435469
name = v.this if isinstance(v, d.ModelKind) else props.get("name")
436470
# We want to ensure whatever name is provided to construct the class is the same name that will be
437471
# found inside the class itself in order to avoid a change during plan/apply for legacy aliases.
438472
# Ex: Pass in `SCD_TYPE_2` then we want to ensure we get `SCD_TYPE_2` as the kind name
439473
# instead of `SCD_TYPE_2_BY_TIME`.
440474
props["name"] = name
441-
return model_kind_type_from_name(name)(**props)
475+
kind_type = model_kind_type_from_name(name)
476+
if "dialect" in kind_type.all_fields() and props.get("dialect") is None:
477+
props["dialect"] = dialect
478+
return kind_type(**props)
442479

443480
name = (v.name if isinstance(v, exp.Expression) else str(v)).upper()
444481
return model_kind_type_from_name(name)(name=name) # type: ignore

sqlmesh/core/model/meta.py

Lines changed: 2 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -267,19 +267,6 @@ def _kind_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
267267
for field in ("partitioned_by_", "clustered_by"):
268268
if values.get(field) and not kind.is_materialized:
269269
raise ValueError(f"{field} field cannot be set for {kind} models")
270-
271-
dialect = values.get("dialect")
272-
273-
if hasattr(kind, "time_column"):
274-
kind.time_column.column = normalize_identifiers(
275-
kind.time_column.column, dialect=dialect
276-
).name
277-
278-
if hasattr(kind, "unique_key"):
279-
kind.unique_key = [
280-
normalize_identifiers(key, dialect=dialect) for key in kind.unique_key
281-
]
282-
283270
return values
284271

285272
@property
@@ -298,9 +285,9 @@ def unique_key(self) -> t.List[exp.Expression]:
298285
@property
299286
def partitioned_by(self) -> t.List[exp.Expression]:
300287
if self.time_column and self.time_column.column not in [
301-
col.name for col in self._partition_by_columns
288+
col for col in self._partition_by_columns
302289
]:
303-
return [*[exp.to_column(self.time_column.column)], *self.partitioned_by_]
290+
return [self.time_column.column, *self.partitioned_by_]
304291
return self.partitioned_by_
305292

306293
@property

sqlmesh/dbt/model.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -184,7 +184,7 @@ def model_kind(self, context: DbtContext) -> ModelKind:
184184
if materialization == Materialization.VIEW:
185185
return ViewKind()
186186
if materialization == Materialization.INCREMENTAL:
187-
incremental_kwargs = {}
187+
incremental_kwargs = {"dialect": context.dialect}
188188
for field in ("batch_size", "lookback", "forward_only", "disable_restatement"):
189189
field_val = getattr(self, field, None) or self.meta.get(field, None)
190190
if field_val:
@@ -243,6 +243,7 @@ def model_kind(self, context: DbtContext) -> ModelKind:
243243
f"{self.canonical_name(context)}: SQLMesh snapshot strategy is required for snapshot materialization."
244244
)
245245
shared_kwargs = {
246+
"dialect": context.dialect,
246247
"unique_key": self.unique_key,
247248
"invalidate_hard_deletes": self.invalidate_hard_deletes,
248249
"valid_from_name": "dbt_valid_from",

[New file — its path was not captured in this extract]

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,5 @@
1+
"""Quoted identifiers in model SQL attributes."""
2+
3+
4+
def migrate(state_sync, **kwargs): # type: ignore
5+
pass

sqlmesh/utils/pydantic.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
from sqlglot import exp, parse_one
1111
from sqlglot.helper import ensure_list
1212
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
13+
from sqlglot.optimizer.qualify_columns import quote_identifiers
1314

1415
from sqlmesh.core import dialect as d
1516
from sqlmesh.utils import str_to_bool
@@ -309,10 +310,8 @@ def _get_fields(
309310
results = []
310311

311312
for expr in expressions:
312-
expr = normalize_identifiers(
313-
exp.column(expr) if isinstance(expr, exp.Identifier) else expr,
314-
dialect=dialect,
315-
)
313+
expr = exp.column(expr) if isinstance(expr, exp.Identifier) else expr
314+
expr = quote_identifiers(normalize_identifiers(expr, dialect=dialect), dialect=dialect)
316315
expr.meta["dialect"] = dialect
317316
results.append(expr)
318317

0 commit comments

Comments (0)