diff --git a/sqlmesh/dbt/seed.py b/sqlmesh/dbt/seed.py index fde5c7e569..10e98cf93c 100644 --- a/sqlmesh/dbt/seed.py +++ b/sqlmesh/dbt/seed.py @@ -58,29 +58,27 @@ def to_sqlmesh( kwargs = self.sqlmesh_model_kwargs(context) columns = kwargs.get("columns") or {} - descriptions = kwargs.get("column_descriptions") or {} - missing_types = (set(descriptions) | set(self.columns)) - set(columns) - if not columns or missing_types: - agate_table = ( - agate_helper.from_csv(seed_path, [], delimiter=self.delimiter) - if SUPPORTS_DELIMITER - else agate_helper.from_csv(seed_path, []) - ) - inferred_types = { - name: AGATE_TYPE_MAPPING[tpe.__class__] - for name, tpe in zip(agate_table.column_names, agate_table.column_types) - } - - # The columns list built from the mixture of supplied and inferred types needs to - # be in the same order as the data for assumptions elsewhere in the codebase to hold true - new_columns = {} - for column_name in agate_table.column_names: - if (column_name in missing_types) or (column_name not in columns): - new_columns[column_name] = inferred_types[column_name] - else: - new_columns[column_name] = columns[column_name] - - kwargs["columns"] = new_columns + + agate_table = ( + agate_helper.from_csv(seed_path, [], delimiter=self.delimiter) + if SUPPORTS_DELIMITER + else agate_helper.from_csv(seed_path, []) + ) + inferred_types = { + name: AGATE_TYPE_MAPPING[tpe.__class__] + for name, tpe in zip(agate_table.column_names, agate_table.column_types) + } + + # The columns list built from the mixture of supplied and inferred types needs to + # be in the same order as the data for assumptions elsewhere in the codebase to hold true + new_columns = {} + for column_name in agate_table.column_names: + if column_name not in columns: + new_columns[column_name] = inferred_types[column_name] + else: + new_columns[column_name] = columns[column_name] + + kwargs["columns"] = new_columns return create_seed_model( self.canonical_name(context), diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py index e8b355e9f5..a16cc16f43 100644 --- a/tests/dbt/test_transformation.py +++ b/tests/dbt/test_transformation.py @@ -471,22 +471,18 @@ def test_seed_columns(): package="package", path=Path("examples/sushi_dbt/seeds/waiter_names.csv"), columns={ - "address": ColumnConfig( - name="address", data_type="text", description="Business address" - ), - "zipcode": ColumnConfig( - name="zipcode", data_type="text", description="Business zipcode" - ), + "id": ColumnConfig(name="id", data_type="text", description="The ID"), + "name": ColumnConfig(name="name", data_type="text", description="The name"), }, ) expected_column_types = { - "address": exp.DataType.build("text"), - "zipcode": exp.DataType.build("text"), + "id": exp.DataType.build("text"), + "name": exp.DataType.build("text"), } expected_column_descriptions = { - "address": "Business address", - "zipcode": "Business zipcode", + "id": "The ID", + "name": "The name", } context = DbtContext() @@ -503,21 +499,21 @@ def test_seed_column_types(): package="package", path=Path("examples/sushi_dbt/seeds/waiter_names.csv"), column_types={ - "address": "text", - "zipcode": "text", + "id": "text", + "name": "text", }, columns={ - "zipcode": ColumnConfig(name="zipcode", description="Business zipcode"), + "name": ColumnConfig(name="name", description="The name"), }, quote_columns=True, ) expected_column_types = { - "address": exp.DataType.build("text"), - "zipcode": exp.DataType.build("text"), + "id": exp.DataType.build("text"), + "name": exp.DataType.build("text"), } expected_column_descriptions = { - "zipcode": "Business zipcode", + "name": "The name", } context = DbtContext()