From 46b811b709303eb0bc35f3ede2a812b0961d9900 Mon Sep 17 00:00:00 2001 From: David Dai Date: Tue, 12 Aug 2025 16:48:44 -0700 Subject: [PATCH 1/2] fix: properly load dbt relation type for get_relation() related functions fixes #1987 --- sqlmesh/dbt/adapter.py | 40 +++++++++++++++++++++------------------ tests/dbt/test_adapter.py | 18 +++++++++++++++++- 2 files changed, 39 insertions(+), 19 deletions(-) diff --git a/sqlmesh/dbt/adapter.py b/sqlmesh/dbt/adapter.py index 4178c960a7..00a1d86ba2 100644 --- a/sqlmesh/dbt/adapter.py +++ b/sqlmesh/dbt/adapter.py @@ -18,6 +18,7 @@ from dbt.adapters.base import BaseRelation from dbt.adapters.base.column import Column from dbt.adapters.base.impl import AdapterResponse + from sqlmesh.core.engine_adapter.base import DataObject from sqlmesh.dbt.relation import Policy @@ -256,10 +257,9 @@ def get_relation( def load_relation(self, relation: BaseRelation) -> t.Optional[BaseRelation]: mapped_table = self._map_table_name(self._normalize(self._relation_to_table(relation))) - if not self.engine_adapter.table_exists(mapped_table): - return None - return self._table_to_relation(mapped_table) + data_object = self.engine_adapter.get_data_object(mapped_table) + return self._data_object_to_relation(data_object) if data_object is not None else None def list_relations(self, database: t.Optional[str], schema: str) -> t.List[BaseRelation]: target_schema = schema_(schema, catalog=database) @@ -269,24 +269,10 @@ def list_relations(self, database: t.Optional[str], schema: str) -> t.List[BaseR return self.list_relations_without_caching(self._table_to_relation(target_schema)) def list_relations_without_caching(self, schema_relation: BaseRelation) -> t.List[BaseRelation]: - from sqlmesh.dbt.relation import RelationType - schema = self._normalize(self._schema(schema_relation)) relations = [ - self.relation_type.create( - database=do.catalog, - schema=do.schema_name, - identifier=do.name, - quote_policy=self.quote_policy, - # DBT relation types aren't snake case and instead just one word without spaces so we remove underscores - type=( - RelationType.External - if do.type.is_unknown - else RelationType(do.type.lower().replace("_", "")) - ), - ) - for do in self.engine_adapter.get_data_objects(schema) + self._data_object_to_relation(do) for do in self.engine_adapter.get_data_objects(schema) ] return relations @@ -401,6 +387,24 @@ def _map_table_name(self, table: exp.Table) -> exp.Table: def _relation_to_table(self, relation: BaseRelation) -> exp.Table: return exp.to_table(relation.render(), dialect=self.project_dialect) + def _data_object_to_relation(self, data_object: DataObject) -> BaseRelation: + from sqlmesh.dbt.relation import RelationType + + if data_object.type.is_unknown: + dbt_relation_type = RelationType.External + elif data_object.type.is_managed_table: + dbt_relation_type = RelationType.Table + else: + dbt_relation_type = RelationType(data_object.type.lower()) + + return self.relation_type.create( + database=data_object.catalog, + schema=data_object.schema_name, + identifier=data_object.name, + quote_policy=self.quote_policy, + type=dbt_relation_type, + ) + def _table_to_relation(self, table: exp.Table) -> BaseRelation: return self.relation_type.create( database=table.catalog or None, diff --git a/tests/dbt/test_adapter.py b/tests/dbt/test_adapter.py index 73a2e1f1f2..583c09dca3 100644 --- a/tests/dbt/test_adapter.py +++ b/tests/dbt/test_adapter.py @@ -38,6 +38,9 @@ def test_adapter_relation(sushi_test_project: Project, runtime_renderer: t.Calla engine_adapter.create_table( table_name="foo.another", columns_to_types={"col": exp.DataType.build("int")} ) + engine_adapter.create_view( + view_name="foo.bar_view", query_or_df=parse_one("select * from foo.bar") + ) engine_adapter.create_table( table_name="ignored.ignore", columns_to_types={"col": exp.DataType.build("int")} ) @@ -46,11 +49,24 @@ def test_adapter_relation(sushi_test_project: Project, runtime_renderer: t.Calla renderer("{{ adapter.get_relation(database=None, schema='foo', identifier='bar') }}") == '"memory"."foo"."bar"' ) + + assert ( + renderer("{{ adapter.get_relation(database=None, schema='foo', identifier='bar').type }}") + == "table" + ) + + assert ( + renderer( + "{{ adapter.get_relation(database=None, schema='foo', identifier='bar_view').type }}" + ) + == "view" + ) + assert renderer( "{%- set relation = adapter.get_relation(database=None, schema='foo', identifier='bar') -%} {{ adapter.get_columns_in_relation(relation) }}" ) == str([Column.from_description(name="baz", raw_data_type="INT")]) - assert renderer("{{ adapter.list_relations(database=None, schema='foo')|length }}") == "2" + assert renderer("{{ adapter.list_relations(database=None, schema='foo')|length }}") == "3" assert renderer( """ From fa1db22ac4fda38862f5bd69bebe98956f06b46b Mon Sep 17 00:00:00 2001 From: David Dai Date: Thu, 14 Aug 2025 14:37:41 -0700 Subject: [PATCH 2/2] fix normalization test --- tests/dbt/test_adapter.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/dbt/test_adapter.py b/tests/dbt/test_adapter.py index 583c09dca3..0e1c953c35 100644 --- a/tests/dbt/test_adapter.py +++ b/tests/dbt/test_adapter.py @@ -126,26 +126,30 @@ def test_bigquery_get_columns_in_relation( def test_normalization( sushi_test_project: Project, runtime_renderer: t.Callable, mocker: MockerFixture ): + from sqlmesh.core.engine_adapter.base import DataObject, DataObjectType + context = sushi_test_project.context assert context.target + data_object = DataObject(catalog="test", schema="bla", name="bob", type=DataObjectType.TABLE) # bla and bob will be normalized to lowercase since the target is duckdb adapter_mock = mocker.MagicMock() adapter_mock.default_catalog = "test" adapter_mock.dialect = "duckdb" - + adapter_mock.get_data_object.return_value = data_object duckdb_renderer = runtime_renderer(context, engine_adapter=adapter_mock) schema_bla = schema_("bla", "test", quoted=True) relation_bla_bob = exp.table_("bob", db="bla", catalog="test", quoted=True) duckdb_renderer("{{ adapter.get_relation(database=None, schema='bla', identifier='bob') }}") - adapter_mock.table_exists.assert_has_calls([call(relation_bla_bob)]) + adapter_mock.get_data_object.assert_has_calls([call(relation_bla_bob)]) # bla and bob will be normalized to uppercase since the target is Snowflake, even though the default dialect is duckdb adapter_mock = mocker.MagicMock() adapter_mock.default_catalog = "test" adapter_mock.dialect = "snowflake" + adapter_mock.get_data_object.return_value = data_object context.target = SnowflakeConfig( account="test", user="test", @@ -160,10 +164,10 @@ def test_normalization( relation_bla_bob = exp.table_("bob", db="bla", catalog="test", quoted=True) renderer("{{ adapter.get_relation(database=None, schema='bla', identifier='bob') }}") - adapter_mock.table_exists.assert_has_calls([call(relation_bla_bob)]) + adapter_mock.get_data_object.assert_has_calls([call(relation_bla_bob)]) renderer("{{ adapter.get_relation(database='custom_db', schema='bla', identifier='bob') }}") - adapter_mock.table_exists.assert_has_calls( + adapter_mock.get_data_object.assert_has_calls( [call(exp.table_("bob", db="bla", catalog="custom_db", quoted=True))] )