Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions docs/integrations/dbt.md
Original file line number Diff line number Diff line change
Expand Up @@ -344,26 +344,23 @@ Model documentation is available in the [SQLMesh UI](../quickstart/ui.md#2-open-

SQLMesh supports running dbt projects using the majority of dbt jinja methods, including:

| Method | Method | Method | Method |
| ----------- | -------------- | ------------ | ------- |
| adapter (*) | env_var | project_name | target |
| as_bool | exceptions | ref | this |
| as_native | from_yaml | return | to_yaml |
| as_number | is_incremental | run_query | var |
| as_text | load_result | schema | zip |
| api | log | set | |
| builtins | modules | source | |
| config | print | statement | |

\* `adapter.expand_target_column_types` is not currently supported.
| Method | Method | Method | Method |
| --------- | -------------- | ------------ | ------- |
| adapter | env_var | project_name | target |
| as_bool | exceptions | ref | this |
| as_native | from_yaml | return | to_yaml |
| as_number | is_incremental | run_query | var |
| as_text | load_result | schema | zip |
| api | log | set | |
| builtins | modules | source | |
| config | print | statement | |

## Unsupported dbt jinja methods

The dbt jinja methods that are not currently supported are:

* debug
* selected_sources
* adapter.expand_target_column_types
* graph.nodes.values
* graph.metrics.values

Expand Down
45 changes: 45 additions & 0 deletions sqlmesh/dbt/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sqlmesh.utils.errors import ConfigError, ParsetimeAdapterCallError
from sqlmesh.utils.jinja import JinjaMacroRegistry
from sqlmesh.utils import AttributeDict
from sqlmesh.core.schema_diff import TableAlterOperation

if t.TYPE_CHECKING:
import agate
Expand Down Expand Up @@ -85,6 +86,12 @@ def drop_schema(self, relation: BaseRelation) -> None:
def drop_relation(self, relation: BaseRelation) -> None:
"""Drops a relation (table) in the target database."""

@abc.abstractmethod
def expand_target_column_types(
    self, from_relation: BaseRelation, to_relation: BaseRelation
) -> None:
    """Expand to_relation's column types to match those of from_relation.

    Mirrors dbt's ``adapter.expand_target_column_types`` jinja method:
    columns present in both relations whose target type can be widened
    (e.g. a shorter varchar in ``to_relation``) are altered to accommodate
    the corresponding ``from_relation`` type.
    """

@abc.abstractmethod
def rename_relation(self, from_relation: BaseRelation, to_relation: BaseRelation) -> None:
"""Renames a relation (table) in the target database."""
Expand Down Expand Up @@ -213,6 +220,11 @@ def drop_schema(self, relation: BaseRelation) -> None:
def drop_relation(self, relation: BaseRelation) -> None:
    # Adapter calls are disallowed at parse time; raise a descriptive error.
    self._raise_parsetime_adapter_call_error("drop relation")

def expand_target_column_types(
    self, from_relation: BaseRelation, to_relation: BaseRelation
) -> None:
    # Adapter calls are disallowed at parse time; raise a descriptive error.
    self._raise_parsetime_adapter_call_error("expand target column types")

def rename_relation(self, from_relation: BaseRelation, to_relation: BaseRelation) -> None:
    # Adapter calls are disallowed at parse time; raise a descriptive error.
    self._raise_parsetime_adapter_call_error("rename relation")

Expand Down Expand Up @@ -355,6 +367,39 @@ def drop_relation(self, relation: BaseRelation) -> None:
if relation.schema is not None and relation.identifier is not None:
self.engine_adapter.drop_table(self._normalize(self._relation_to_table(relation)))

def expand_target_column_types(
    self, from_relation: BaseRelation, to_relation: BaseRelation
) -> None:
    """Widen to_relation's column types so they can hold from_relation's values.

    For every column that exists in both relations and whose target type can
    expand to the source type (per dbt's ``Column.can_expand_to``), an ALTER
    operation is generated via the engine adapter's schema differ and applied
    in a single ``alter_table`` call. Destructive changes are ignored.
    """
    # dbt-level column metadata, keyed by column name, for both relations.
    source_dbt_columns = {c.name: c for c in self.get_columns_in_relation(from_relation)}
    target_dbt_columns = {c.name: c for c in self.get_columns_in_relation(to_relation)}

    source_table = self._normalize(self._relation_to_table(from_relation))
    target_table = self._normalize(self._relation_to_table(to_relation))

    # Engine-level column types, used to build the before/after schemas.
    source_engine_columns = self.engine_adapter.columns(source_table)
    target_engine_columns = self.engine_adapter.columns(target_table)

    current_columns = {}
    new_columns = {}
    for name, source_column in source_dbt_columns.items():
        target_column = target_dbt_columns.get(name)
        # Skip columns missing from the target or not safely widenable.
        if target_column is None or not target_column.can_expand_to(source_column):
            continue
        current_columns[name] = target_engine_columns[name]
        new_columns[name] = source_engine_columns[name]

    alter_operations = t.cast(
        t.List[TableAlterOperation],
        self.engine_adapter.schema_differ.compare_columns(
            target_table,
            current_columns,
            new_columns,
            ignore_destructive=True,
        ),
    )

    if alter_operations:
        self.engine_adapter.alter_table(alter_operations)

def rename_relation(self, from_relation: BaseRelation, to_relation: BaseRelation) -> None:
old_table_name = self._normalize(self._relation_to_table(from_relation))
new_table_name = self._normalize(self._relation_to_table(to_relation))
Expand Down
79 changes: 79 additions & 0 deletions tests/dbt/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from sqlmesh.dbt.target import BigQueryConfig, SnowflakeConfig
from sqlmesh.utils.errors import ConfigError
from sqlmesh.utils.jinja import JinjaMacroRegistry
from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterChangeColumnTypeOperation

pytestmark = pytest.mark.dbt

Expand Down Expand Up @@ -349,3 +350,81 @@ def test_adapter_get_relation_normalization(
renderer("{{ adapter.list_relations(database=None, schema='foo') }}")
== '[<SnowflakeRelation "memory"."FOO"."BAR">]'
)


def test_adapter_expand_target_column_types(
    sushi_test_project: Project, runtime_renderer: t.Callable, mocker: MockerFixture
):
    """expand_target_column_types widens only the expandable columns via ALTER TABLE.

    Mocks the engine adapter and renders a jinja template that calls
    ``adapter.expand_target_column_types``; verifies exactly two
    ALTER COLUMN ... SET DATA TYPE operations are produced.
    """
    from sqlmesh.core.engine_adapter.base import DataObject, DataObjectType

    # Data objects the mocked engine adapter resolves for the two relations.
    data_object_from = DataObject(
        catalog="test", schema="foo", name="from_table", type=DataObjectType.TABLE
    )
    data_object_to = DataObject(
        catalog="test", schema="foo", name="to_table", type=DataObjectType.TABLE
    )
    # Source types; trailing comments show the expected target -> source transition.
    from_columns = {
        "int_col": exp.DataType.build("int"),
        "same_text_col": exp.DataType.build("varchar(1)"),  # varchar(1) -> varchar(1)
        "unexpandable_text_col": exp.DataType.build("varchar(2)"),  # varchar(4) -> varchar(2)
        "expandable_text_col1": exp.DataType.build("varchar(16)"),  # varchar(8) -> varchar(16)
        "expandable_text_col2": exp.DataType.build("varchar(64)"),  # varchar(32) -> varchar(64)
    }
    to_columns = {
        "int_col": exp.DataType.build("int"),
        "same_text_col": exp.DataType.build("varchar(1)"),
        "unexpandable_text_col": exp.DataType.build("varchar(4)"),
        "expandable_text_col1": exp.DataType.build("varchar(8)"),
        "expandable_text_col2": exp.DataType.build("varchar(32)"),
    }
    adapter_mock = mocker.MagicMock()
    adapter_mock.default_catalog = "test"
    # get_data_object is called once per relation, in from/to order.
    adapter_mock.get_data_object.side_effect = [data_object_from, data_object_to]
    # columns() is called 4 times, twice by adapter.get_columns_in_relation() and twice by the engine_adapter
    adapter_mock.columns.side_effect = [
        from_columns,
        to_columns,
        from_columns,
        to_columns,
    ]
    # Use a real SchemaDiffer so compare_columns yields genuine alter operations.
    adapter_mock.schema_differ = SchemaDiffer()

    context = sushi_test_project.context
    renderer = runtime_renderer(context, engine_adapter=adapter_mock)

    renderer("""
    {%- set from_relation = adapter.get_relation(
        database=None,
        schema='foo',
        identifier='from_table') -%}

    {% set to_relation = adapter.get_relation(
        database=None,
        schema='foo',
        identifier='to_table') -%}

    {% do adapter.expand_target_column_types(from_relation, to_relation) %}
    """)
    adapter_mock.get_data_object.assert_has_calls(
        [
            call(exp.to_table('"test"."foo"."from_table"')),
            call(exp.to_table('"test"."foo"."to_table"')),
        ]
    )
    # alter_table receives a single positional argument: the list of operations.
    assert len(adapter_mock.alter_table.call_args.args) == 1
    alter_expressions = adapter_mock.alter_table.call_args.args[0]
    # Only the two expandable varchar columns should be widened.
    assert len(alter_expressions) == 2
    alter_operation1 = alter_expressions[0]
    assert isinstance(alter_operation1, TableAlterChangeColumnTypeOperation)
    assert alter_operation1.expression == parse_one(
        """ALTER TABLE "test"."foo"."to_table"
        ALTER COLUMN expandable_text_col1
        SET DATA TYPE VARCHAR(16)"""
    )
    alter_operation2 = alter_expressions[1]
    assert isinstance(alter_operation2, TableAlterChangeColumnTypeOperation)
    assert alter_operation2.expression == parse_one(
        """ALTER TABLE "test"."foo"."to_table"
        ALTER COLUMN expandable_text_col2
        SET DATA TYPE VARCHAR(64)"""
    )