diff --git a/docs/integrations/dbt.md b/docs/integrations/dbt.md index c5e4bdd2d9..4342f47779 100644 --- a/docs/integrations/dbt.md +++ b/docs/integrations/dbt.md @@ -344,18 +344,16 @@ Model documentation is available in the [SQLMesh UI](../quickstart/ui.md#2-open- SQLMesh supports running dbt projects using the majority of dbt jinja methods, including: -| Method | Method | Method | Method | -| ----------- | -------------- | ------------ | ------- | -| adapter (*) | env_var | project_name | target | -| as_bool | exceptions | ref | this | -| as_native | from_yaml | return | to_yaml | -| as_number | is_incremental | run_query | var | -| as_text | load_result | schema | zip | -| api | log | set | | -| builtins | modules | source | | -| config | print | statement | | - -\* `adapter.expand_target_column_types` is not currently supported. +| Method | Method | Method | Method | +| --------- | -------------- | ------------ | ------- | +| adapter | env_var | project_name | target | +| as_bool | exceptions | ref | this | +| as_native | from_yaml | return | to_yaml | +| as_number | is_incremental | run_query | var | +| as_text | load_result | schema | zip | +| api | log | set | | +| builtins | modules | source | | +| config | print | statement | | ## Unsupported dbt jinja methods @@ -363,7 +361,6 @@ The dbt jinja methods that are not currently supported are: * debug * selected_sources -* adapter.expand_target_column_types * graph.nodes.values * graph.metrics.values diff --git a/sqlmesh/dbt/adapter.py b/sqlmesh/dbt/adapter.py index 7de90a8ea5..2dc9890ca4 100644 --- a/sqlmesh/dbt/adapter.py +++ b/sqlmesh/dbt/adapter.py @@ -12,6 +12,7 @@ from sqlmesh.utils.errors import ConfigError, ParsetimeAdapterCallError from sqlmesh.utils.jinja import JinjaMacroRegistry from sqlmesh.utils import AttributeDict +from sqlmesh.core.schema_diff import TableAlterOperation if t.TYPE_CHECKING: import agate @@ -85,6 +86,12 @@ def drop_schema(self, relation: BaseRelation) -> None: def drop_relation(self, relation: BaseRelation) -> None: """Drops a relation (table) in the target database.""" + @abc.abstractmethod + def expand_target_column_types( + self, from_relation: BaseRelation, to_relation: BaseRelation + ) -> None: + """Expand to_relation's column types to match those of from_relation.""" + @abc.abstractmethod def rename_relation(self, from_relation: BaseRelation, to_relation: BaseRelation) -> None: """Renames a relation (table) in the target database.""" @@ -213,6 +220,11 @@ def drop_schema(self, relation: BaseRelation) -> None: def drop_relation(self, relation: BaseRelation) -> None: self._raise_parsetime_adapter_call_error("drop relation") + def expand_target_column_types( + self, from_relation: BaseRelation, to_relation: BaseRelation + ) -> None: + self._raise_parsetime_adapter_call_error("expand target column types") + def rename_relation(self, from_relation: BaseRelation, to_relation: BaseRelation) -> None: self._raise_parsetime_adapter_call_error("rename relation") @@ -355,6 +367,39 @@ def drop_relation(self, relation: BaseRelation) -> None: if relation.schema is not None and relation.identifier is not None: self.engine_adapter.drop_table(self._normalize(self._relation_to_table(relation))) + def expand_target_column_types( + self, from_relation: BaseRelation, to_relation: BaseRelation + ) -> None: + from_dbt_columns = {c.name: c for c in self.get_columns_in_relation(from_relation)} + to_dbt_columns = {c.name: c for c in self.get_columns_in_relation(to_relation)} + + from_table_name = self._normalize(self._relation_to_table(from_relation)) + to_table_name = self._normalize(self._relation_to_table(to_relation)) + + from_columns = self.engine_adapter.columns(from_table_name) + to_columns = self.engine_adapter.columns(to_table_name) + + current_columns = {} + new_columns = {} + for column_name, from_column in from_dbt_columns.items(): + target_column = to_dbt_columns.get(column_name) + if target_column is not None and target_column.can_expand_to(from_column): + current_columns[column_name] = to_columns[column_name] + new_columns[column_name] = from_columns[column_name] + + alter_expressions = t.cast( + t.List[TableAlterOperation], + self.engine_adapter.schema_differ.compare_columns( + to_table_name, + current_columns, + new_columns, + ignore_destructive=True, + ), + ) + + if alter_expressions: + self.engine_adapter.alter_table(alter_expressions) + def rename_relation(self, from_relation: BaseRelation, to_relation: BaseRelation) -> None: old_table_name = self._normalize(self._relation_to_table(from_relation)) new_table_name = self._normalize(self._relation_to_table(to_relation)) diff --git a/tests/dbt/test_adapter.py b/tests/dbt/test_adapter.py index 5a41d237d3..445e5f29c0 100644 --- a/tests/dbt/test_adapter.py +++ b/tests/dbt/test_adapter.py @@ -18,6 +18,7 @@ from sqlmesh.dbt.target import BigQueryConfig, SnowflakeConfig from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.jinja import JinjaMacroRegistry +from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterChangeColumnTypeOperation pytestmark = pytest.mark.dbt @@ -349,3 +350,81 @@ def test_adapter_get_relation_normalization( renderer("{{ adapter.list_relations(database=None, schema='foo') }}") == '[]' ) + + +def test_adapter_expand_target_column_types( + sushi_test_project: Project, runtime_renderer: t.Callable, mocker: MockerFixture +): + from sqlmesh.core.engine_adapter.base import DataObject, DataObjectType + + data_object_from = DataObject( + catalog="test", schema="foo", name="from_table", type=DataObjectType.TABLE + ) + data_object_to = DataObject( + catalog="test", schema="foo", name="to_table", type=DataObjectType.TABLE + ) + from_columns = { + "int_col": exp.DataType.build("int"), + "same_text_col": exp.DataType.build("varchar(1)"), # varchar(1) -> varchar(1) + "unexpandable_text_col": exp.DataType.build("varchar(2)"), # varchar(4) -> varchar(2) + "expandable_text_col1": exp.DataType.build("varchar(16)"), # varchar(8) -> varchar(16) + "expandable_text_col2": exp.DataType.build("varchar(64)"), # varchar(32) -> varchar(64) + } + to_columns = { + "int_col": exp.DataType.build("int"), + "same_text_col": exp.DataType.build("varchar(1)"), + "unexpandable_text_col": exp.DataType.build("varchar(4)"), + "expandable_text_col1": exp.DataType.build("varchar(8)"), + "expandable_text_col2": exp.DataType.build("varchar(32)"), + } + adapter_mock = mocker.MagicMock() + adapter_mock.default_catalog = "test" + adapter_mock.get_data_object.side_effect = [data_object_from, data_object_to] + # columns() is called 4 times, twice by adapter.get_columns_in_relation() and twice by the engine_adapter + adapter_mock.columns.side_effect = [ + from_columns, + to_columns, + from_columns, + to_columns, + ] + adapter_mock.schema_differ = SchemaDiffer() + + context = sushi_test_project.context + renderer = runtime_renderer(context, engine_adapter=adapter_mock) + + renderer(""" + {%- set from_relation = adapter.get_relation( + database=None, + schema='foo', + identifier='from_table') -%} + + {% set to_relation = adapter.get_relation( + database=None, + schema='foo', + identifier='to_table') -%} + + {% do adapter.expand_target_column_types(from_relation, to_relation) %} + """) + adapter_mock.get_data_object.assert_has_calls( + [ + call(exp.to_table('"test"."foo"."from_table"')), + call(exp.to_table('"test"."foo"."to_table"')), + ] + ) + assert len(adapter_mock.alter_table.call_args.args) == 1 + alter_expressions = adapter_mock.alter_table.call_args.args[0] + assert len(alter_expressions) == 2 + alter_operation1 = alter_expressions[0] + assert isinstance(alter_operation1, TableAlterChangeColumnTypeOperation) + assert alter_operation1.expression == parse_one( + """ALTER TABLE "test"."foo"."to_table" + ALTER COLUMN expandable_text_col1 + SET DATA TYPE VARCHAR(16)""" + ) + alter_operation2 = alter_expressions[1] + assert isinstance(alter_operation2, TableAlterChangeColumnTypeOperation) + assert alter_operation2.expression == parse_one( + """ALTER TABLE "test"."foo"."to_table" + ALTER COLUMN expandable_text_col2 + SET DATA TYPE VARCHAR(64)""" + )