Skip to content

Commit 01ccf3e

Browse files
committed
fix: properly load dbt relation type for get_relation()
related functions fixes #1987
1 parent 803f7d8 commit 01ccf3e

File tree

2 files changed

+38
-16
lines changed

2 files changed

+38
-16
lines changed

sqlmesh/dbt/adapter.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from dbt.adapters.base import BaseRelation
1919
from dbt.adapters.base.column import Column
2020
from dbt.adapters.base.impl import AdapterResponse
21+
from sqlmesh.core.engine_adapter.base import DataObject
2122
from sqlmesh.dbt.relation import Policy
2223

2324

@@ -259,6 +260,9 @@ def load_relation(self, relation: BaseRelation) -> t.Optional[BaseRelation]:
259260
if not self.engine_adapter.table_exists(mapped_table):
260261
return None
261262

263+
if do := self.engine_adapter.get_data_object(mapped_table):
264+
return self._data_object_to_relation(do)
265+
262266
return self._table_to_relation(mapped_table)
263267

264268
def list_relations(self, database: t.Optional[str], schema: str) -> t.List[BaseRelation]:
@@ -269,24 +273,10 @@ def list_relations(self, database: t.Optional[str], schema: str) -> t.List[BaseR
269273
return self.list_relations_without_caching(self._table_to_relation(target_schema))
270274

271275
def list_relations_without_caching(self, schema_relation: BaseRelation) -> t.List[BaseRelation]:
272-
from sqlmesh.dbt.relation import RelationType
273-
274276
schema = self._normalize(self._schema(schema_relation))
275277

276278
relations = [
277-
self.relation_type.create(
278-
database=do.catalog,
279-
schema=do.schema_name,
280-
identifier=do.name,
281-
quote_policy=self.quote_policy,
282-
# DBT relation types aren't snake case and instead just one word without spaces so we remove underscores
283-
type=(
284-
RelationType.External
285-
if do.type.is_unknown
286-
else RelationType(do.type.lower().replace("_", ""))
287-
),
288-
)
289-
for do in self.engine_adapter.get_data_objects(schema)
279+
self._data_object_to_relation(do) for do in self.engine_adapter.get_data_objects(schema)
290280
]
291281
return relations
292282

@@ -381,6 +371,22 @@ def _map_table_name(self, table: exp.Table) -> exp.Table:
381371
def _relation_to_table(self, relation: BaseRelation) -> exp.Table:
382372
return exp.to_table(relation.render(), dialect=self.project_dialect)
383373

374+
def _data_object_to_relation(self, data_object: DataObject) -> BaseRelation:
375+
from sqlmesh.dbt.relation import RelationType
376+
377+
return self.relation_type.create(
378+
database=data_object.catalog,
379+
schema=data_object.schema_name,
380+
identifier=data_object.name,
381+
quote_policy=self.quote_policy,
382+
# DBT relation types aren't snake case and instead just one word without spaces so we remove underscores
383+
type=(
384+
RelationType.External
385+
if data_object.type.is_unknown
386+
else RelationType(data_object.type.lower().replace("_", ""))
387+
),
388+
)
389+
384390
def _table_to_relation(self, table: exp.Table) -> BaseRelation:
385391
return self.relation_type.create(
386392
database=table.catalog or None,

tests/dbt/test_adapter.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ def test_adapter_relation(sushi_test_project: Project, runtime_renderer: t.Calla
3838
engine_adapter.create_table(
3939
table_name="foo.another", columns_to_types={"col": exp.DataType.build("int")}
4040
)
41+
engine_adapter.create_view(
42+
view_name="foo.bar_view", query_or_df=parse_one("select * from foo.bar")
43+
)
4144
engine_adapter.create_table(
4245
table_name="ignored.ignore", columns_to_types={"col": exp.DataType.build("int")}
4346
)
@@ -46,11 +49,24 @@ def test_adapter_relation(sushi_test_project: Project, runtime_renderer: t.Calla
4649
renderer("{{ adapter.get_relation(database=None, schema='foo', identifier='bar') }}")
4750
== '"memory"."foo"."bar"'
4851
)
52+
53+
assert (
54+
renderer("{{ adapter.get_relation(database=None, schema='foo', identifier='bar').type }}")
55+
== "table"
56+
)
57+
58+
assert (
59+
renderer(
60+
"{{ adapter.get_relation(database=None, schema='foo', identifier='bar_view').type }}"
61+
)
62+
== "view"
63+
)
64+
4965
assert renderer(
5066
"{%- set relation = adapter.get_relation(database=None, schema='foo', identifier='bar') -%} {{ adapter.get_columns_in_relation(relation) }}"
5167
) == str([Column.from_description(name="baz", raw_data_type="INT")])
5268

53-
assert renderer("{{ adapter.list_relations(database=None, schema='foo')|length }}") == "2"
69+
assert renderer("{{ adapter.list_relations(database=None, schema='foo')|length }}") == "3"
5470

5571
assert renderer(
5672
"""

0 commit comments

Comments
 (0)