Skip to content

Commit cce46b5

Browse files
authored
fix: properly load dbt relation type for get_relation() and related functions (#5144)
1 parent 4f704a6 commit cce46b5

File tree

2 files changed

+47
-23
lines changed

2 files changed

+47
-23
lines changed

sqlmesh/dbt/adapter.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from dbt.adapters.base import BaseRelation
1919
from dbt.adapters.base.column import Column
2020
from dbt.adapters.base.impl import AdapterResponse
21+
from sqlmesh.core.engine_adapter.base import DataObject
2122
from sqlmesh.dbt.relation import Policy
2223

2324

@@ -256,10 +257,9 @@ def get_relation(
256257

257258
def load_relation(self, relation: BaseRelation) -> t.Optional[BaseRelation]:
258259
mapped_table = self._map_table_name(self._normalize(self._relation_to_table(relation)))
259-
if not self.engine_adapter.table_exists(mapped_table):
260-
return None
261260

262-
return self._table_to_relation(mapped_table)
261+
data_object = self.engine_adapter.get_data_object(mapped_table)
262+
return self._data_object_to_relation(data_object) if data_object is not None else None
263263

264264
def list_relations(self, database: t.Optional[str], schema: str) -> t.List[BaseRelation]:
265265
target_schema = schema_(schema, catalog=database)
@@ -269,24 +269,10 @@ def list_relations(self, database: t.Optional[str], schema: str) -> t.List[BaseR
269269
return self.list_relations_without_caching(self._table_to_relation(target_schema))
270270

271271
def list_relations_without_caching(self, schema_relation: BaseRelation) -> t.List[BaseRelation]:
272-
from sqlmesh.dbt.relation import RelationType
273-
274272
schema = self._normalize(self._schema(schema_relation))
275273

276274
relations = [
277-
self.relation_type.create(
278-
database=do.catalog,
279-
schema=do.schema_name,
280-
identifier=do.name,
281-
quote_policy=self.quote_policy,
282-
# DBT relation types aren't snake case and instead just one word without spaces so we remove underscores
283-
type=(
284-
RelationType.External
285-
if do.type.is_unknown
286-
else RelationType(do.type.lower().replace("_", ""))
287-
),
288-
)
289-
for do in self.engine_adapter.get_data_objects(schema)
275+
self._data_object_to_relation(do) for do in self.engine_adapter.get_data_objects(schema)
290276
]
291277
return relations
292278

@@ -401,6 +387,24 @@ def _map_table_name(self, table: exp.Table) -> exp.Table:
401387
def _relation_to_table(self, relation: BaseRelation) -> exp.Table:
402388
return exp.to_table(relation.render(), dialect=self.project_dialect)
403389

390+
def _data_object_to_relation(self, data_object: DataObject) -> BaseRelation:
391+
from sqlmesh.dbt.relation import RelationType
392+
393+
if data_object.type.is_unknown:
394+
dbt_relation_type = RelationType.External
395+
elif data_object.type.is_managed_table:
396+
dbt_relation_type = RelationType.Table
397+
else:
398+
dbt_relation_type = RelationType(data_object.type.lower())
399+
400+
return self.relation_type.create(
401+
database=data_object.catalog,
402+
schema=data_object.schema_name,
403+
identifier=data_object.name,
404+
quote_policy=self.quote_policy,
405+
type=dbt_relation_type,
406+
)
407+
404408
def _table_to_relation(self, table: exp.Table) -> BaseRelation:
405409
return self.relation_type.create(
406410
database=table.catalog or None,

tests/dbt/test_adapter.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ def test_adapter_relation(sushi_test_project: Project, runtime_renderer: t.Calla
3636
engine_adapter.create_table(
3737
table_name="foo.another", columns_to_types={"col": exp.DataType.build("int")}
3838
)
39+
engine_adapter.create_view(
40+
view_name="foo.bar_view", query_or_df=parse_one("select * from foo.bar")
41+
)
3942
engine_adapter.create_table(
4043
table_name="ignored.ignore", columns_to_types={"col": exp.DataType.build("int")}
4144
)
@@ -44,11 +47,24 @@ def test_adapter_relation(sushi_test_project: Project, runtime_renderer: t.Calla
4447
renderer("{{ adapter.get_relation(database=None, schema='foo', identifier='bar') }}")
4548
== '"memory"."foo"."bar"'
4649
)
50+
51+
assert (
52+
renderer("{{ adapter.get_relation(database=None, schema='foo', identifier='bar').type }}")
53+
== "table"
54+
)
55+
56+
assert (
57+
renderer(
58+
"{{ adapter.get_relation(database=None, schema='foo', identifier='bar_view').type }}"
59+
)
60+
== "view"
61+
)
62+
4763
assert renderer(
4864
"{%- set relation = adapter.get_relation(database=None, schema='foo', identifier='bar') -%} {{ adapter.get_columns_in_relation(relation) }}"
4965
) == str([Column.from_description(name="baz", raw_data_type="INT")])
5066

51-
assert renderer("{{ adapter.list_relations(database=None, schema='foo')|length }}") == "2"
67+
assert renderer("{{ adapter.list_relations(database=None, schema='foo')|length }}") == "3"
5268

5369
assert renderer(
5470
"""
@@ -108,26 +124,30 @@ def test_bigquery_get_columns_in_relation(
108124
def test_normalization(
109125
sushi_test_project: Project, runtime_renderer: t.Callable, mocker: MockerFixture
110126
):
127+
from sqlmesh.core.engine_adapter.base import DataObject, DataObjectType
128+
111129
context = sushi_test_project.context
112130
assert context.target
131+
data_object = DataObject(catalog="test", schema="bla", name="bob", type=DataObjectType.TABLE)
113132

114133
# bla and bob will be normalized to lowercase since the target is duckdb
115134
adapter_mock = mocker.MagicMock()
116135
adapter_mock.default_catalog = "test"
117136
adapter_mock.dialect = "duckdb"
118-
137+
adapter_mock.get_data_object.return_value = data_object
119138
duckdb_renderer = runtime_renderer(context, engine_adapter=adapter_mock)
120139

121140
schema_bla = schema_("bla", "test", quoted=True)
122141
relation_bla_bob = exp.table_("bob", db="bla", catalog="test", quoted=True)
123142

124143
duckdb_renderer("{{ adapter.get_relation(database=None, schema='bla', identifier='bob') }}")
125-
adapter_mock.table_exists.assert_has_calls([call(relation_bla_bob)])
144+
adapter_mock.get_data_object.assert_has_calls([call(relation_bla_bob)])
126145

127146
# bla and bob will be normalized to uppercase since the target is Snowflake, even though the default dialect is duckdb
128147
adapter_mock = mocker.MagicMock()
129148
adapter_mock.default_catalog = "test"
130149
adapter_mock.dialect = "snowflake"
150+
adapter_mock.get_data_object.return_value = data_object
131151
context.target = SnowflakeConfig(
132152
account="test",
133153
user="test",
@@ -142,10 +162,10 @@ def test_normalization(
142162
relation_bla_bob = exp.table_("bob", db="bla", catalog="test", quoted=True)
143163

144164
renderer("{{ adapter.get_relation(database=None, schema='bla', identifier='bob') }}")
145-
adapter_mock.table_exists.assert_has_calls([call(relation_bla_bob)])
165+
adapter_mock.get_data_object.assert_has_calls([call(relation_bla_bob)])
146166

147167
renderer("{{ adapter.get_relation(database='custom_db', schema='bla', identifier='bob') }}")
148-
adapter_mock.table_exists.assert_has_calls(
168+
adapter_mock.get_data_object.assert_has_calls(
149169
[call(exp.table_("bob", db="bla", catalog="custom_db", quoted=True))]
150170
)
151171

0 commit comments

Comments
 (0)