Skip to content

Commit baa2583

Browse files
committed
Feat(dbt): Add support for adapter.expand_target_column_types
1 parent fe64851 commit baa2583

File tree

2 files changed

+113
-0
lines changed

2 files changed

+113
-0
lines changed

sqlmesh/dbt/adapter.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from sqlmesh.utils.errors import ConfigError, ParsetimeAdapterCallError
1313
from sqlmesh.utils.jinja import JinjaMacroRegistry
1414
from sqlmesh.utils import AttributeDict
15+
from sqlmesh.core.schema_diff import TableAlterOperation
1516

1617
if t.TYPE_CHECKING:
1718
import agate
@@ -85,6 +86,12 @@ def drop_schema(self, relation: BaseRelation) -> None:
8586
def drop_relation(self, relation: BaseRelation) -> None:
8687
"""Drops a relation (table) in the target database."""
8788

89+
@abc.abstractmethod
90+
def expand_target_column_types(
91+
self, from_relation: BaseRelation, to_relation: BaseRelation
92+
) -> None:
93+
"""Expand to_relation's column types to match those of from_relation."""
94+
8895
@abc.abstractmethod
8996
def rename_relation(self, from_relation: BaseRelation, to_relation: BaseRelation) -> None:
9097
"""Renames a relation (table) in the target database."""
@@ -213,6 +220,11 @@ def drop_schema(self, relation: BaseRelation) -> None:
213220
def drop_relation(self, relation: BaseRelation) -> None:
214221
self._raise_parsetime_adapter_call_error("drop relation")
215222

223+
def expand_target_column_types(
224+
self, from_relation: BaseRelation, to_relation: BaseRelation
225+
) -> None:
226+
self._raise_parsetime_adapter_call_error("expand target column types")
227+
216228
def rename_relation(self, from_relation: BaseRelation, to_relation: BaseRelation) -> None:
217229
self._raise_parsetime_adapter_call_error("rename relation")
218230

@@ -355,6 +367,39 @@ def drop_relation(self, relation: BaseRelation) -> None:
355367
if relation.schema is not None and relation.identifier is not None:
356368
self.engine_adapter.drop_table(self._normalize(self._relation_to_table(relation)))
357369

370+
def expand_target_column_types(
371+
self, from_relation: BaseRelation, to_relation: BaseRelation
372+
) -> None:
373+
from_dbt_columns = {c.name: c for c in self.get_columns_in_relation(from_relation)}
374+
to_dbt_columns = {c.name: c for c in self.get_columns_in_relation(to_relation)}
375+
376+
from_table_name = self._normalize(self._relation_to_table(from_relation))
377+
to_table_name = self._normalize(self._relation_to_table(to_relation))
378+
379+
from_columns = self.engine_adapter.columns(from_table_name)
380+
to_columns = self.engine_adapter.columns(to_table_name)
381+
382+
current_columns = {}
383+
new_columns = {}
384+
for column_name, from_column in from_dbt_columns.items():
385+
target_column = to_dbt_columns.get(column_name)
386+
if target_column is not None and target_column.can_expand_to(from_column):
387+
current_columns[column_name] = to_columns[column_name]
388+
new_columns[column_name] = from_columns[column_name]
389+
390+
alter_expressions = t.cast(
391+
t.List[TableAlterOperation],
392+
self.engine_adapter.SCHEMA_DIFFER.compare_columns(
393+
to_table_name,
394+
current_columns,
395+
new_columns,
396+
ignore_destructive=True,
397+
),
398+
)
399+
400+
if alter_expressions:
401+
self.engine_adapter.alter_table(alter_expressions)
402+
358403
def rename_relation(self, from_relation: BaseRelation, to_relation: BaseRelation) -> None:
359404
old_table_name = self._normalize(self._relation_to_table(from_relation))
360405
new_table_name = self._normalize(self._relation_to_table(to_relation))

tests/dbt/test_adapter.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from sqlmesh.dbt.target import BigQueryConfig, SnowflakeConfig
1919
from sqlmesh.utils.errors import ConfigError
2020
from sqlmesh.utils.jinja import JinjaMacroRegistry
21+
from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterChangeColumnTypeOperation
2122

2223
pytestmark = pytest.mark.dbt
2324

@@ -349,3 +350,70 @@ def test_adapter_get_relation_normalization(
349350
renderer("{{ adapter.list_relations(database=None, schema='foo') }}")
350351
== '[<SnowflakeRelation "memory"."FOO"."BAR">]'
351352
)
353+
354+
355+
def test_adapter_expand_target_column_types(
356+
sushi_test_project: Project, runtime_renderer: t.Callable, mocker: MockerFixture
357+
):
358+
from sqlmesh.core.engine_adapter.base import DataObject, DataObjectType
359+
360+
data_object_from = DataObject(
361+
catalog="test", schema="foo", name="from_table", type=DataObjectType.TABLE
362+
)
363+
data_object_to = DataObject(
364+
catalog="test", schema="foo", name="to_table", type=DataObjectType.TABLE
365+
)
366+
from_columns = {
367+
"int_col": exp.DataType.build("int"),
368+
"same_text_col": exp.DataType.build("varchar(1)"), # varchar(1) -> varchar(1)
369+
"unexpandable_text_col": exp.DataType.build("varchar(2)"), # varchar(4) -> varchar(2)
370+
"expandable_text_col": exp.DataType.build("varchar(16)"), # varchar(8) -> varchar(16)
371+
}
372+
to_columns = {
373+
"int_col": exp.DataType.build("int"),
374+
"same_text_col": exp.DataType.build("varchar(1)"),
375+
"unexpandable_text_col": exp.DataType.build("varchar(4)"),
376+
"expandable_text_col": exp.DataType.build("varchar(8)"),
377+
}
378+
adapter_mock = mocker.MagicMock()
379+
adapter_mock.default_catalog = "test"
380+
adapter_mock.get_data_object.side_effect = [data_object_from, data_object_to]
381+
# columns() is called 4 times, twice by adapter.get_columns_in_relation() and twice by the engine_adapter
382+
adapter_mock.columns.side_effect = [
383+
from_columns,
384+
to_columns,
385+
] * 2
386+
adapter_mock.SCHEMA_DIFFER = SchemaDiffer()
387+
388+
context = sushi_test_project.context
389+
renderer = runtime_renderer(context, engine_adapter=adapter_mock)
390+
391+
renderer("""
392+
{%- set from_relation = adapter.get_relation(
393+
database=None,
394+
schema='foo',
395+
identifier='from_table') -%}
396+
397+
{% set to_relation = adapter.get_relation(
398+
database=None,
399+
schema='foo',
400+
identifier='to_table') -%}
401+
402+
{% do adapter.expand_target_column_types(from_relation, to_relation) %}
403+
""")
404+
adapter_mock.get_data_object.assert_has_calls(
405+
[
406+
call(exp.to_table('"test"."foo"."from_table"')),
407+
call(exp.to_table('"test"."foo"."to_table"')),
408+
]
409+
)
410+
assert len(adapter_mock.alter_table.call_args.args) == 1
411+
alter_expressions = adapter_mock.alter_table.call_args.args[0]
412+
assert len(alter_expressions) == 1
413+
alter_operation = alter_expressions[0]
414+
assert isinstance(alter_operation, TableAlterChangeColumnTypeOperation)
415+
assert alter_operation.expression == parse_one(
416+
"""ALTER TABLE "test"."foo"."to_table"
417+
ALTER COLUMN expandable_text_col
418+
SET DATA TYPE VARCHAR(16)"""
419+
)

0 commit comments

Comments
 (0)