Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 22 additions & 18 deletions sqlmesh/dbt/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from dbt.adapters.base import BaseRelation
from dbt.adapters.base.column import Column
from dbt.adapters.base.impl import AdapterResponse
from sqlmesh.core.engine_adapter.base import DataObject
from sqlmesh.dbt.relation import Policy


Expand Down Expand Up @@ -256,10 +257,9 @@ def get_relation(

def load_relation(self, relation: BaseRelation) -> t.Optional[BaseRelation]:
mapped_table = self._map_table_name(self._normalize(self._relation_to_table(relation)))
if not self.engine_adapter.table_exists(mapped_table):
return None

return self._table_to_relation(mapped_table)
data_object = self.engine_adapter.get_data_object(mapped_table)
return self._data_object_to_relation(data_object) if data_object is not None else None

def list_relations(self, database: t.Optional[str], schema: str) -> t.List[BaseRelation]:
target_schema = schema_(schema, catalog=database)
Expand All @@ -269,24 +269,10 @@ def list_relations(self, database: t.Optional[str], schema: str) -> t.List[BaseR
return self.list_relations_without_caching(self._table_to_relation(target_schema))

def list_relations_without_caching(self, schema_relation: BaseRelation) -> t.List[BaseRelation]:
from sqlmesh.dbt.relation import RelationType

schema = self._normalize(self._schema(schema_relation))

relations = [
self.relation_type.create(
database=do.catalog,
schema=do.schema_name,
identifier=do.name,
quote_policy=self.quote_policy,
# DBT relation types aren't snake case and instead just one word without spaces so we remove underscores
type=(
RelationType.External
if do.type.is_unknown
else RelationType(do.type.lower().replace("_", ""))
),
)
for do in self.engine_adapter.get_data_objects(schema)
self._data_object_to_relation(do) for do in self.engine_adapter.get_data_objects(schema)
]
return relations

Expand Down Expand Up @@ -401,6 +387,24 @@ def _map_table_name(self, table: exp.Table) -> exp.Table:
def _relation_to_table(self, relation: BaseRelation) -> exp.Table:
return exp.to_table(relation.render(), dialect=self.project_dialect)

def _data_object_to_relation(self, data_object: DataObject) -> BaseRelation:
from sqlmesh.dbt.relation import RelationType

if data_object.type.is_unknown:
dbt_relation_type = RelationType.External
elif data_object.type.is_managed_table:
dbt_relation_type = RelationType.Table
else:
dbt_relation_type = RelationType(data_object.type.lower())

return self.relation_type.create(
database=data_object.catalog,
schema=data_object.schema_name,
identifier=data_object.name,
quote_policy=self.quote_policy,
type=dbt_relation_type,
)

def _table_to_relation(self, table: exp.Table) -> BaseRelation:
return self.relation_type.create(
database=table.catalog or None,
Expand Down
30 changes: 25 additions & 5 deletions tests/dbt/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def test_adapter_relation(sushi_test_project: Project, runtime_renderer: t.Calla
engine_adapter.create_table(
table_name="foo.another", columns_to_types={"col": exp.DataType.build("int")}
)
engine_adapter.create_view(
view_name="foo.bar_view", query_or_df=parse_one("select * from foo.bar")
)
engine_adapter.create_table(
table_name="ignored.ignore", columns_to_types={"col": exp.DataType.build("int")}
)
Expand All @@ -46,11 +49,24 @@ def test_adapter_relation(sushi_test_project: Project, runtime_renderer: t.Calla
renderer("{{ adapter.get_relation(database=None, schema='foo', identifier='bar') }}")
== '"memory"."foo"."bar"'
)

assert (
renderer("{{ adapter.get_relation(database=None, schema='foo', identifier='bar').type }}")
== "table"
)

assert (
renderer(
"{{ adapter.get_relation(database=None, schema='foo', identifier='bar_view').type }}"
)
== "view"
)

assert renderer(
"{%- set relation = adapter.get_relation(database=None, schema='foo', identifier='bar') -%} {{ adapter.get_columns_in_relation(relation) }}"
) == str([Column.from_description(name="baz", raw_data_type="INT")])

assert renderer("{{ adapter.list_relations(database=None, schema='foo')|length }}") == "2"
assert renderer("{{ adapter.list_relations(database=None, schema='foo')|length }}") == "3"

assert renderer(
"""
Expand Down Expand Up @@ -110,26 +126,30 @@ def test_bigquery_get_columns_in_relation(
def test_normalization(
sushi_test_project: Project, runtime_renderer: t.Callable, mocker: MockerFixture
):
from sqlmesh.core.engine_adapter.base import DataObject, DataObjectType

context = sushi_test_project.context
assert context.target
data_object = DataObject(catalog="test", schema="bla", name="bob", type=DataObjectType.TABLE)

# bla and bob will be normalized to lowercase since the target is duckdb
adapter_mock = mocker.MagicMock()
adapter_mock.default_catalog = "test"
adapter_mock.dialect = "duckdb"

adapter_mock.get_data_object.return_value = data_object
duckdb_renderer = runtime_renderer(context, engine_adapter=adapter_mock)

schema_bla = schema_("bla", "test", quoted=True)
relation_bla_bob = exp.table_("bob", db="bla", catalog="test", quoted=True)

duckdb_renderer("{{ adapter.get_relation(database=None, schema='bla', identifier='bob') }}")
adapter_mock.table_exists.assert_has_calls([call(relation_bla_bob)])
adapter_mock.get_data_object.assert_has_calls([call(relation_bla_bob)])

# bla and bob will be normalized to uppercase since the target is Snowflake, even though the default dialect is duckdb
adapter_mock = mocker.MagicMock()
adapter_mock.default_catalog = "test"
adapter_mock.dialect = "snowflake"
adapter_mock.get_data_object.return_value = data_object
context.target = SnowflakeConfig(
account="test",
user="test",
Expand All @@ -144,10 +164,10 @@ def test_normalization(
relation_bla_bob = exp.table_("bob", db="bla", catalog="test", quoted=True)

renderer("{{ adapter.get_relation(database=None, schema='bla', identifier='bob') }}")
adapter_mock.table_exists.assert_has_calls([call(relation_bla_bob)])
adapter_mock.get_data_object.assert_has_calls([call(relation_bla_bob)])

renderer("{{ adapter.get_relation(database='custom_db', schema='bla', identifier='bob') }}")
adapter_mock.table_exists.assert_has_calls(
adapter_mock.get_data_object.assert_has_calls(
[call(exp.table_("bob", db="bla", catalog="custom_db", quoted=True))]
)

Expand Down