Skip to content

Commit 0c70406

Browse files
authored
Fix!: sqlmesh.dbt.adapter.RuntimeAdapter.get_columns_in_relation() (#5115)
1 parent ee2a8bd commit 0c70406

File tree

5 files changed

+105
-4
lines changed

5 files changed

+105
-4
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,9 @@ filterwarnings = [
273273
]
274274
retry_delay = 10
275275

276+
[tool.ruff]
277+
line-length = 100
278+
276279
[tool.ruff.lint]
277280
select = [
278281
"F401",

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,13 @@ def create_schema(
292292
raise
293293
logger.warning("Failed to create schema '%s': %s", schema_name, e)
294294

295+
def get_bq_schema(self, table_name: TableName) -> t.List[bigquery.SchemaField]:
296+
table = exp.to_table(table_name)
297+
if len(table.parts) == 3 and "." in table.name:
298+
self.execute(exp.select("*").from_(table).limit(0))
299+
return self._query_job._query_results.schema
300+
return self._get_table(table).schema
301+
295302
def columns(
296303
self, table_name: TableName, include_pseudo_columns: bool = False
297304
) -> t.Dict[str, exp.DataType]:

sqlmesh/dbt/adapter.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -291,11 +291,31 @@ def list_relations_without_caching(self, schema_relation: BaseRelation) -> t.Lis
291291
return relations
292292

293293
def get_columns_in_relation(self, relation: BaseRelation) -> t.List[Column]:
294-
from dbt.adapters.base.column import Column
295-
296294
mapped_table = self._map_table_name(self._normalize(self._relation_to_table(relation)))
295+
296+
if self.project_dialect == "bigquery":
297+
# dbt.adapters.bigquery.column.BigQueryColumn has a different constructor signature
298+
# We need to use BigQueryColumn.create_from_field() to create the column instead
299+
if (
300+
hasattr(self.column_type, "create_from_field")
301+
and callable(getattr(self.column_type, "create_from_field"))
302+
and hasattr(self.engine_adapter, "get_bq_schema")
303+
and callable(getattr(self.engine_adapter, "get_bq_schema"))
304+
):
305+
return [
306+
self.column_type.create_from_field(field) # type: ignore
307+
for field in self.engine_adapter.get_bq_schema(mapped_table) # type: ignore
308+
]
309+
from dbt.adapters.base.column import Column
310+
311+
return [
312+
Column.from_description(
313+
name=name, raw_data_type=dtype.sql(dialect=self.project_dialect)
314+
)
315+
for name, dtype in self.engine_adapter.columns(table_name=mapped_table).items()
316+
]
297317
return [
298-
Column.from_description(
318+
self.column_type.from_description(
299319
name=name, raw_data_type=dtype.sql(dialect=self.project_dialect)
300320
)
301321
for name, dtype in self.engine_adapter.columns(table_name=mapped_table).items()

tests/core/engine_adapter/integration/test_integration_bigquery.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,39 @@ def test_compare_nested_values_in_table_diff(ctx: TestContext):
341341
ctx.engine_adapter.drop_table(target_table)
342342

343343

344+
def test_get_bq_schema(ctx: TestContext, engine_adapter: BigQueryEngineAdapter):
345+
from google.cloud.bigquery import SchemaField
346+
347+
table = ctx.table("test")
348+
349+
engine_adapter.execute(f"""
350+
CREATE TABLE {table.sql(dialect=ctx.dialect)} (
351+
id STRING NOT NULL,
352+
user_data STRUCT<id STRING NOT NULL, name STRING NOT NULL, address STRING>,
353+
tags ARRAY<STRING>,
354+
score NUMERIC,
355+
created_at DATETIME
356+
)
357+
""")
358+
359+
bg_schema = engine_adapter.get_bq_schema(table)
360+
assert len(bg_schema) == 5
361+
assert bg_schema[0] == SchemaField(name="id", field_type="STRING", mode="REQUIRED")
362+
assert bg_schema[1] == SchemaField(
363+
name="user_data",
364+
field_type="RECORD",
365+
mode="NULLABLE",
366+
fields=[
367+
SchemaField(name="id", field_type="STRING", mode="REQUIRED"),
368+
SchemaField(name="name", field_type="STRING", mode="REQUIRED"),
369+
SchemaField(name="address", field_type="STRING", mode="NULLABLE"),
370+
],
371+
)
372+
assert bg_schema[2] == SchemaField(name="tags", field_type="STRING", mode="REPEATED")
373+
assert bg_schema[3] == SchemaField(name="score", field_type="NUMERIC", mode="NULLABLE")
374+
assert bg_schema[4] == SchemaField(name="created_at", field_type="DATETIME", mode="NULLABLE")
375+
376+
344377
def test_column_types(ctx: TestContext):
345378
model_name = ctx.table("test")
346379
sqlmesh = ctx.create_context()

tests/dbt/test_adapter.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from sqlmesh.dbt.adapter import ParsetimeAdapter
1818
from sqlmesh.dbt.project import Project
1919
from sqlmesh.dbt.relation import Policy
20-
from sqlmesh.dbt.target import SnowflakeConfig
20+
from sqlmesh.dbt.target import BigQueryConfig, SnowflakeConfig
2121
from sqlmesh.utils.errors import ConfigError
2222
from sqlmesh.utils.jinja import JinjaMacroRegistry
2323

@@ -68,6 +68,44 @@ def test_adapter_relation(sushi_test_project: Project, runtime_renderer: t.Calla
6868
)
6969

7070

71+
def test_bigquery_get_columns_in_relation(
72+
sushi_test_project: Project,
73+
runtime_renderer: t.Callable,
74+
mocker: MockerFixture,
75+
):
76+
from dbt.adapters.bigquery import BigQueryColumn
77+
from google.cloud.bigquery import SchemaField
78+
79+
context = sushi_test_project.context
80+
context.target = BigQueryConfig(name="test", schema="test", database="test")
81+
82+
adapter_mock = mocker.MagicMock()
83+
adapter_mock.default_catalog = "test"
84+
adapter_mock.dialect = "bigquery"
85+
table_schema = [
86+
SchemaField(name="id", field_type="STRING", mode="REQUIRED"),
87+
SchemaField(
88+
name="user_data",
89+
field_type="RECORD",
90+
mode="NULLABLE",
91+
fields=[
92+
SchemaField(name="id", field_type="STRING", mode="REQUIRED"),
93+
SchemaField(name="name", field_type="STRING", mode="REQUIRED"),
94+
SchemaField(name="address", field_type="STRING", mode="NULLABLE"),
95+
],
96+
),
97+
SchemaField(name="tags", field_type="STRING", mode="REPEATED"),
98+
SchemaField(name="score", field_type="NUMERIC", mode="NULLABLE"),
99+
SchemaField(name="created_at", field_type="TIMESTAMP", mode="NULLABLE"),
100+
]
101+
adapter_mock.get_bq_schema.return_value = table_schema
102+
renderer = runtime_renderer(context, engine_adapter=adapter_mock, dialect="bigquery")
103+
assert renderer(
104+
"{%- set relation = api.Relation.create(database='test', schema='test', identifier='test_table') -%}"
105+
"{{ adapter.get_columns_in_relation(relation) }}"
106+
) == str([BigQueryColumn.create_from_field(field) for field in table_schema])
107+
108+
71109
@pytest.mark.cicdonly
72110
def test_normalization(
73111
sushi_test_project: Project, runtime_renderer: t.Callable, mocker: MockerFixture

0 commit comments

Comments
 (0)