diff --git a/pyproject.toml b/pyproject.toml index 8e1c4b879f..118086d1fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -272,6 +272,9 @@ filterwarnings = [ ] retry_delay = 10 +[tool.ruff] +line-length = 100 + [tool.ruff.lint] select = [ "F401", diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index e953f4d1d0..78f2dee6aa 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -266,6 +266,13 @@ def create_schema( raise logger.warning("Failed to create schema '%s': %s", schema_name, e) + def get_bq_schema(self, table_name: TableName) -> t.List[bigquery.SchemaField]: + table = exp.to_table(table_name) + if len(table.parts) == 3 and "." in table.name: + self.execute(exp.select("*").from_(table).limit(0)) + return self._query_job._query_results.schema + return self._get_table(table).schema + def columns( self, table_name: TableName, include_pseudo_columns: bool = False ) -> t.Dict[str, exp.DataType]: diff --git a/sqlmesh/dbt/adapter.py b/sqlmesh/dbt/adapter.py index 92719abacc..4178c960a7 100644 --- a/sqlmesh/dbt/adapter.py +++ b/sqlmesh/dbt/adapter.py @@ -291,11 +291,31 @@ def list_relations_without_caching(self, schema_relation: BaseRelation) -> t.Lis return relations def get_columns_in_relation(self, relation: BaseRelation) -> t.List[Column]: - from dbt.adapters.base.column import Column - mapped_table = self._map_table_name(self._normalize(self._relation_to_table(relation))) + + if self.project_dialect == "bigquery": + # dbt.adapters.bigquery.column.BigQueryColumn has a different constructor signature + # We need to use BigQueryColumn.create_from_field() to create the column instead + if ( + hasattr(self.column_type, "create_from_field") + and callable(getattr(self.column_type, "create_from_field")) + and hasattr(self.engine_adapter, "get_bq_schema") + and callable(getattr(self.engine_adapter, "get_bq_schema")) + ): + return [ + self.column_type.create_from_field(field) # type: ignore + for field in self.engine_adapter.get_bq_schema(mapped_table) # type: ignore + ] + from dbt.adapters.base.column import Column + + return [ + Column.from_description( + name=name, raw_data_type=dtype.sql(dialect=self.project_dialect) + ) + for name, dtype in self.engine_adapter.columns(table_name=mapped_table).items() + ] return [ - Column.from_description( + self.column_type.from_description( name=name, raw_data_type=dtype.sql(dialect=self.project_dialect) ) for name, dtype in self.engine_adapter.columns(table_name=mapped_table).items() diff --git a/tests/core/engine_adapter/integration/test_integration_bigquery.py b/tests/core/engine_adapter/integration/test_integration_bigquery.py index c97c94d036..e1cfaded13 100644 --- a/tests/core/engine_adapter/integration/test_integration_bigquery.py +++ b/tests/core/engine_adapter/integration/test_integration_bigquery.py @@ -341,6 +341,39 @@ def test_compare_nested_values_in_table_diff(ctx: TestContext): ctx.engine_adapter.drop_table(target_table) +def test_get_bq_schema(ctx: TestContext, engine_adapter: BigQueryEngineAdapter): + from google.cloud.bigquery import SchemaField + + table = ctx.table("test") + + engine_adapter.execute(f""" + CREATE TABLE {table.sql(dialect=ctx.dialect)} ( + id STRING NOT NULL, + user_data STRUCT, + tags ARRAY, + score NUMERIC, + created_at DATETIME + ) + """) + + bg_schema = engine_adapter.get_bq_schema(table) + assert len(bg_schema) == 5 + assert bg_schema[0] == SchemaField(name="id", field_type="STRING", mode="REQUIRED") + assert bg_schema[1] == SchemaField( + name="user_data", + field_type="RECORD", + mode="NULLABLE", + fields=[ + SchemaField(name="id", field_type="STRING", mode="REQUIRED"), + SchemaField(name="name", field_type="STRING", mode="REQUIRED"), + SchemaField(name="address", field_type="STRING", mode="NULLABLE"), + ], + ) + assert bg_schema[2] == SchemaField(name="tags", field_type="STRING", mode="REPEATED") + assert bg_schema[3] == SchemaField(name="score", field_type="NUMERIC", mode="NULLABLE") + assert bg_schema[4] == SchemaField(name="created_at", field_type="DATETIME", mode="NULLABLE") + + def test_column_types(ctx: TestContext): model_name = ctx.table("test") sqlmesh = ctx.create_context() diff --git a/tests/dbt/test_adapter.py b/tests/dbt/test_adapter.py index 31428b953c..73a2e1f1f2 100644 --- a/tests/dbt/test_adapter.py +++ b/tests/dbt/test_adapter.py @@ -17,7 +17,7 @@ from sqlmesh.dbt.adapter import ParsetimeAdapter from sqlmesh.dbt.project import Project from sqlmesh.dbt.relation import Policy -from sqlmesh.dbt.target import SnowflakeConfig +from sqlmesh.dbt.target import BigQueryConfig, SnowflakeConfig from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.jinja import JinjaMacroRegistry @@ -68,6 +68,44 @@ def test_adapter_relation(sushi_test_project: Project, runtime_renderer: t.Calla ) +def test_bigquery_get_columns_in_relation( + sushi_test_project: Project, + runtime_renderer: t.Callable, + mocker: MockerFixture, +): + from dbt.adapters.bigquery import BigQueryColumn + from google.cloud.bigquery import SchemaField + + context = sushi_test_project.context + context.target = BigQueryConfig(name="test", schema="test", database="test") + + adapter_mock = mocker.MagicMock() + adapter_mock.default_catalog = "test" + adapter_mock.dialect = "bigquery" + table_schema = [ + SchemaField(name="id", field_type="STRING", mode="REQUIRED"), + SchemaField( + name="user_data", + field_type="RECORD", + mode="NULLABLE", + fields=[ + SchemaField(name="id", field_type="STRING", mode="REQUIRED"), + SchemaField(name="name", field_type="STRING", mode="REQUIRED"), + SchemaField(name="address", field_type="STRING", mode="NULLABLE"), + ], + ), + SchemaField(name="tags", field_type="STRING", mode="REPEATED"), + SchemaField(name="score", field_type="NUMERIC", mode="NULLABLE"), + SchemaField(name="created_at", field_type="TIMESTAMP", mode="NULLABLE"), + ] + adapter_mock.get_bq_schema.return_value = table_schema + renderer = runtime_renderer(context, engine_adapter=adapter_mock, dialect="bigquery") + assert renderer( + "{%- set relation = api.Relation.create(database='test', schema='test', identifier='test_table') -%}" + "{{ adapter.get_columns_in_relation(relation) }}" + ) == str([BigQueryColumn.create_from_field(field) for field in table_schema]) + + @pytest.mark.cicdonly def test_normalization( sushi_test_project: Project, runtime_renderer: t.Callable, mocker: MockerFixture