Skip to content

Commit 8514949

Browse files
authored
[wip] Add Iceberg table options to the BigQuery engine adapter
1 parent a255e17 commit 8514949

File tree

1 file changed

+168
-29
lines changed

1 file changed

+168
-29
lines changed

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 168 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import logging
44
import typing as t
5-
from collections import defaultdict
5+
from collections import defaultdict, OrderedDict
66

77
from sqlglot import exp, parse_one
88
from sqlglot.transforms import remove_precision_parameterized_types
@@ -169,18 +169,17 @@ def _df_to_source_queries(
169169
)
170170

171171
def query_factory() -> Query:
172-
ordered_df = df[list(source_columns_to_types)]
173-
if bigframes_pd and isinstance(ordered_df, bigframes_pd.DataFrame):
174-
ordered_df.to_gbq(
172+
if bigframes_pd and isinstance(df, bigframes_pd.DataFrame):
173+
df.to_gbq(
175174
f"{temp_bq_table.project}.{temp_bq_table.dataset_id}.{temp_bq_table.table_id}",
176175
if_exists="replace",
177176
)
178177
elif not self.table_exists(temp_table):
179178
# Make mypy happy
180-
assert isinstance(ordered_df, pd.DataFrame)
179+
assert isinstance(df, pd.DataFrame)
181180
self._db_call(self.client.create_table, table=temp_bq_table, exists_ok=False)
182181
result = self.__load_pandas_to_table(
183-
temp_bq_table, ordered_df, source_columns_to_types, replace=False
182+
temp_bq_table, df, source_columns_to_types, replace=False
184183
)
185184
if result.errors:
186185
raise SQLMeshError(result.errors)
@@ -755,28 +754,6 @@ def table_exists(self, table_name: TableName) -> bool:
755754
except NotFound:
756755
return False
757756

758-
def get_table_last_modified_ts(self, table_names: t.List[TableName]) -> t.List[int]:
    """Return the last-modified timestamp of each table, as sqlmesh timestamps.

    Tables are grouped by dataset so that a single query against the per-dataset
    ``__TABLES__`` metadata view is issued per dataset instead of one per table.

    NOTE(review): results are ordered by dataset grouping, not by the order of
    `table_names` — confirm callers do not rely on positional alignment.

    Args:
        table_names: Names of the tables to look up; each is parsed with
            `exp.to_table` to extract its dataset (`db`) and table name.

    Returns:
        A list of timestamps (one per row returned by BigQuery), converted via
        `to_timestamp` from the `TIMESTAMP_MILLIS(last_modified_time)` values.
    """
    from sqlmesh.utils.date import to_timestamp

    # Group requested tables by dataset: `__TABLES__` is a per-dataset view.
    datasets_to_tables: t.DefaultDict[str, t.List[str]] = defaultdict(list)
    for table_name in table_names:
        table = exp.to_table(table_name)
        datasets_to_tables[table.db].append(table.name)

    results = []
    for dataset, tables in datasets_to_tables.items():
        # Table names come from model definitions, not end users, so plain
        # interpolation is accepted here. An IN list replaces the previous
        # hand-built chain of OR'ed equality predicates (same predicate,
        # built in O(n) with join instead of repeated string concatenation).
        in_list = ", ".join(f"'{name}'" for name in tables)
        query = (
            "SELECT TIMESTAMP_MILLIS(last_modified_time) "
            f"FROM `{dataset}.__TABLES__` WHERE TABLE_ID IN ({in_list})"
        )
        results.extend(self.fetchall(query))

    return [to_timestamp(row[0]) for row in results]
779-
780757
def _get_table(self, table_name: TableName) -> BigQueryTable:
781758
"""
782759
Returns a BigQueryTable object for the given table name.
@@ -891,6 +868,60 @@ def _build_partitioned_by_exp(
891868

892869
return exp.PartitionedByProperty(this=this)
893870

871+
def _create_table(
    self,
    table_name_or_schema: t.Union[exp.Schema, TableName],
    expression: t.Optional[exp.Expression],
    exists: bool = True,
    replace: bool = False,
    target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
    table_description: t.Optional[str] = None,
    column_descriptions: t.Optional[t.Dict[str, str]] = None,
    table_kind: t.Optional[str] = None,
    track_rows_processed: bool = True,
    **kwargs: t.Any,
) -> None:
    """Create a table, injecting a BigQuery ``WITH CONNECTION`` clause when required.

    Table properties are first normalized; if they contain a connection entry
    (used by BigLake / Iceberg tables), the CREATE statement is rendered to SQL
    and the ``WITH CONNECTION`` clause is spliced in textually before execution.
    Otherwise creation is delegated unchanged to the base implementation.
    """
    # Normalize OPTIONS-style properties and pull any connection entry out of
    # them; the connection is rendered as a clause, not a table property.
    normalized_properties, connection_property = self._prepare_create_table_properties(
        kwargs.get("table_properties"),
        kwargs.get("table_format"),
        kwargs.get("storage_format"),
    )
    kwargs["table_properties"] = normalized_properties

    if connection_property is None:
        # No connection clause needed: defer to the generic implementation.
        super()._create_table(
            table_name_or_schema,
            expression,
            exists=exists,
            replace=replace,
            target_columns_to_types=target_columns_to_types,
            table_description=table_description,
            column_descriptions=column_descriptions,
            table_kind=table_kind,
            track_rows_processed=track_rows_processed,
            **kwargs,
        )
        return

    # Connection path: build the CREATE expression ourselves and splice the
    # WITH CONNECTION clause into the rendered SQL string.
    # NOTE(review): `column_descriptions` is not forwarded on this path (the
    # super() path receives it) — confirm this is intentional.
    create_expression = self._build_create_table_exp(
        table_name_or_schema,
        expression=expression,
        exists=exists,
        replace=replace,
        target_columns_to_types=target_columns_to_types,
        # Only embed the description when the adapter supports schema-level
        # comments and comments are enabled; otherwise it is omitted here.
        table_description=(
            table_description
            if self.COMMENT_CREATION_TABLE.supports_schema_def and self.comments_enabled
            else None
        ),
        table_kind=table_kind,
        **kwargs,
    )
    sql = self._to_sql(create_expression)
    connection_sql = self._connection_clause_sql(connection_property)
    sql = self._inject_connection_clause(sql, connection_sql)
    self.execute(sql, track_rows_processed=track_rows_processed)
924+
894925
def _build_table_properties_exp(
895926
self,
896927
catalog_name: t.Optional[str] = None,
@@ -926,12 +957,120 @@ def _build_table_properties_exp(
926957
),
927958
)
928959

929-
properties.extend(self._table_or_view_properties_to_expressions(table_properties))
960+
if table_properties:
961+
for key, value in table_properties.items():
962+
properties.append(exp.Property(this=key, value=value.copy()))
930963

931964
if properties:
932965
return exp.Properties(expressions=properties)
933966
return None
934967

968+
def _prepare_create_table_properties(
    self,
    table_properties: t.Optional[t.Dict[str, exp.Expression]],
    table_format: t.Optional[str],
    storage_format: t.Optional[str],
) -> t.Tuple[OrderedDict[str, exp.Expression], t.Optional[exp.Expression]]:
    """Normalize table properties and extract the BigQuery connection, if any.

    - Deduplicates properties case-insensitively: the casing seen last wins
      and the property moves to the end, mirroring dict re-insertion order.
    - Pops any ``connection`` / ``with_connection`` entry out of the
      properties, since it is rendered as a separate ``WITH CONNECTION``
      clause rather than an OPTIONS entry.
    - For Iceberg tables, forces ``table_format`` / ``file_format`` to
      upper-cased string literals and validates that ``storage_uri`` and a
      connection were supplied.

    Args:
        table_properties: Raw physical properties from the model, or None.
        table_format: Explicit table format (e.g. "iceberg"), or None.
        storage_format: Explicit file format (e.g. "parquet"), or None.

    Returns:
        A tuple of (ordered normalized properties, connection expression or None).

    Raises:
        SQLMeshError: If an Iceberg table is missing `storage_uri` or a connection.
    """
    normalized_properties: OrderedDict[str, exp.Expression] = OrderedDict()
    # Case-insensitive index: lowercase key -> casing currently stored in
    # normalized_properties. This replaces the previous O(n) rescan of the
    # dict on every lookup/insert (O(n^2) overall) with O(1) operations.
    key_by_lower: t.Dict[str, str] = {}
    connection_property: t.Optional[exp.Expression] = None

    def _set_property(name: str, expression: exp.Expression) -> None:
        # Remove any case-insensitive duplicate first so the new entry (and
        # its casing) lands at the end, preserving re-insertion order.
        existing_key = key_by_lower.pop(name.lower(), None)
        if existing_key is not None:
            normalized_properties.pop(existing_key)
        normalized_properties[name] = expression
        key_by_lower[name.lower()] = name

    def _get_property(name: str) -> t.Optional[exp.Expression]:
        # `name` is expected to be lowercase.
        existing_key = key_by_lower.get(name)
        return None if existing_key is None else normalized_properties[existing_key]

    if table_properties:
        for key, value in table_properties.items():
            if value is None:
                continue
            if key.lower() in {"connection", "with_connection"}:
                # Rendered as a WITH CONNECTION clause, not an OPTIONS entry.
                connection_property = value
                continue
            _set_property(key, value.copy())

    normalized_table_format = table_format.lower() if table_format else None
    if not normalized_table_format:
        # Fall back to an explicit `table_format` string property, if present.
        existing_table_format = _get_property("table_format")
        if isinstance(existing_table_format, exp.Literal) and existing_table_format.is_string:
            normalized_table_format = existing_table_format.this.lower()
    is_iceberg = normalized_table_format == "iceberg"

    if is_iceberg:
        # BigQuery expects upper-cased option values, e.g. TABLE_FORMAT = 'ICEBERG'.
        _set_property(
            "table_format",
            self._ensure_upper_string_literal(
                _get_property("table_format"),
                default=normalized_table_format or "iceberg",
            ),
        )
        _set_property(
            "file_format",
            self._ensure_upper_string_literal(
                _get_property("file_format"),
                default=storage_format or "PARQUET",
            ),
        )

        if "storage_uri" not in key_by_lower:
            raise SQLMeshError(
                "BigQuery Iceberg tables require `storage_uri` to be set in physical_properties."
            )

        if connection_property is None:
            raise SQLMeshError(
                "BigQuery Iceberg tables require a `connection` entry in physical_properties."
            )

    return normalized_properties, connection_property
1039+
1040+
def _ensure_upper_string_literal(
    self,
    expression: t.Optional[exp.Expression],
    default: str,
) -> exp.Expression:
    """Coerce *expression* into an upper-cased string literal where possible.

    If no expression is given, an upper-cased literal built from *default* is
    returned. A string literal is re-created with its value upper-cased; any
    other expression type is returned as a copy, unchanged.
    """
    if expression is None:
        return exp.Literal.string(default.upper())

    copied = expression.copy()
    if not (isinstance(copied, exp.Literal) and copied.is_string):
        # Non-literal (or non-string literal) expressions pass through as-is.
        return copied
    return exp.Literal.string(copied.this.upper())
1052+
1053+
def _connection_clause_sql(self, connection_expression: exp.Expression) -> str:
    """Render the argument of a ``WITH CONNECTION`` clause to SQL.

    String literals are treated as connection names: the keyword DEFAULT
    passes through unquoted, anything else becomes a quoted identifier.
    Any other expression is rendered with the adapter's SQL generator.
    """
    rendered = connection_expression.copy()
    if not (isinstance(rendered, exp.Literal) and rendered.is_string):
        return self._to_sql(rendered)

    name = rendered.this.strip()
    if name.upper() == "DEFAULT":
        # `WITH CONNECTION DEFAULT` selects the default connection.
        return "DEFAULT"
    return exp.to_identifier(name, quoted=True).sql(dialect=self.dialect)
1062+
1063+
@staticmethod
1064+
def _inject_connection_clause(create_sql: str, connection_sql: str) -> str:
1065+
parts = create_sql.split("OPTIONS", 1)
1066+
if len(parts) == 2:
1067+
prefix, suffix = parts
1068+
if not prefix.endswith(" "):
1069+
prefix = f"{prefix} "
1070+
return f"{prefix}WITH CONNECTION {connection_sql} OPTIONS{suffix}"
1071+
separator = " " if not create_sql.endswith(" ") else ""
1072+
return f"{create_sql}{separator}WITH CONNECTION {connection_sql}"
1073+
9351074
def _build_column_def(
9361075
self,
9371076
col_name: str,

0 commit comments

Comments
 (0)