Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,9 @@ filterwarnings = [
]
retry_delay = 10

[tool.ruff]
line-length = 100

[tool.ruff.lint]
select = [
"F401",
Expand Down
7 changes: 7 additions & 0 deletions sqlmesh/core/engine_adapter/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,13 @@ def create_schema(
raise
logger.warning("Failed to create schema '%s': %s", schema_name, e)

def get_bq_schema(self, table_name: TableName) -> t.List[bigquery.SchemaField]:
    """Return the BigQuery schema fields describing ``table_name``.

    Args:
        table_name: The table to inspect; it is parsed with ``exp.to_table``.

    Returns:
        The ``schema`` of the table object returned by ``self._get_table``,
        or, for three-part names whose table component itself contains a
        ".", the ``schema`` of the results of a zero-row ``SELECT *``.
    """
    parsed = exp.to_table(table_name)

    # Common case: read the schema from the table object.
    if not (len(parsed.parts) == 3 and "." in parsed.name):
        return self._get_table(parsed).schema

    # NOTE(review): presumably dotted table names cannot be resolved through
    # the table lookup above, hence the empty probe query -- confirm.
    probe = exp.select("*").from_(parsed).limit(0)
    self.execute(probe)
    return self._query_job._query_results.schema

def columns(
self, table_name: TableName, include_pseudo_columns: bool = False
) -> t.Dict[str, exp.DataType]:
Expand Down
26 changes: 23 additions & 3 deletions sqlmesh/dbt/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,11 +291,31 @@ def list_relations_without_caching(self, schema_relation: BaseRelation) -> t.Lis
return relations

def get_columns_in_relation(self, relation: BaseRelation) -> t.List[Column]:
from dbt.adapters.base.column import Column

mapped_table = self._map_table_name(self._normalize(self._relation_to_table(relation)))

if self.project_dialect == "bigquery":
# dbt.adapters.bigquery.column.BigQueryColumn has a different constructor signature
# We need to use BigQueryColumn.create_from_field() to create the column instead
if (
hasattr(self.column_type, "create_from_field")
and callable(getattr(self.column_type, "create_from_field"))
and hasattr(self.engine_adapter, "get_bq_schema")
and callable(getattr(self.engine_adapter, "get_bq_schema"))
):
return [
self.column_type.create_from_field(field) # type: ignore
for field in self.engine_adapter.get_bq_schema(mapped_table) # type: ignore
]
from dbt.adapters.base.column import Column

return [
Column.from_description(
name=name, raw_data_type=dtype.sql(dialect=self.project_dialect)
)
for name, dtype in self.engine_adapter.columns(table_name=mapped_table).items()
]
return [
Column.from_description(
self.column_type.from_description(
name=name, raw_data_type=dtype.sql(dialect=self.project_dialect)
)
for name, dtype in self.engine_adapter.columns(table_name=mapped_table).items()
Expand Down
33 changes: 33 additions & 0 deletions tests/core/engine_adapter/integration/test_integration_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,39 @@ def test_compare_nested_values_in_table_diff(ctx: TestContext):
ctx.engine_adapter.drop_table(target_table)


def test_get_bq_schema(ctx: TestContext, engine_adapter: BigQueryEngineAdapter):
    """Create a five-column table and check the fields ``get_bq_schema`` returns."""
    from google.cloud.bigquery import SchemaField

    test_table = ctx.table("test")

    # The DDL text below is passed to the engine exactly as written.
    engine_adapter.execute(f"""
CREATE TABLE {test_table.sql(dialect=ctx.dialect)} (
id STRING NOT NULL,
user_data STRUCT<id STRING NOT NULL, name STRING NOT NULL, address STRING>,
tags ARRAY<STRING>,
score NUMERIC,
created_at DATETIME
)
""")

    actual_fields = engine_adapter.get_bq_schema(test_table)

    # Expected fields, in column declaration order.
    user_data_members = [
        SchemaField(name="id", field_type="STRING", mode="REQUIRED"),
        SchemaField(name="name", field_type="STRING", mode="REQUIRED"),
        SchemaField(name="address", field_type="STRING", mode="NULLABLE"),
    ]
    expected_fields = [
        SchemaField(name="id", field_type="STRING", mode="REQUIRED"),
        SchemaField(
            name="user_data",
            field_type="RECORD",
            mode="NULLABLE",
            fields=user_data_members,
        ),
        SchemaField(name="tags", field_type="STRING", mode="REPEATED"),
        SchemaField(name="score", field_type="NUMERIC", mode="NULLABLE"),
        SchemaField(name="created_at", field_type="DATETIME", mode="NULLABLE"),
    ]

    assert len(actual_fields) == 5
    for position, wanted in enumerate(expected_fields):
        assert actual_fields[position] == wanted


def test_column_types(ctx: TestContext):
model_name = ctx.table("test")
sqlmesh = ctx.create_context()
Expand Down
40 changes: 39 additions & 1 deletion tests/dbt/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from sqlmesh.dbt.adapter import ParsetimeAdapter
from sqlmesh.dbt.project import Project
from sqlmesh.dbt.relation import Policy
from sqlmesh.dbt.target import SnowflakeConfig
from sqlmesh.dbt.target import BigQueryConfig, SnowflakeConfig
from sqlmesh.utils.errors import ConfigError
from sqlmesh.utils.jinja import JinjaMacroRegistry

Expand Down Expand Up @@ -68,6 +68,44 @@ def test_adapter_relation(sushi_test_project: Project, runtime_renderer: t.Calla
)


def test_bigquery_get_columns_in_relation(
    sushi_test_project: Project,
    runtime_renderer: t.Callable,
    mocker: MockerFixture,
):
    """Render ``adapter.get_columns_in_relation`` against a mocked BigQuery engine adapter.

    The mocked adapter's ``get_bq_schema`` returns a fixed list of schema
    fields; the rendered output must equal the string form of the
    ``BigQueryColumn`` objects created from those same fields.
    """
    from dbt.adapters.bigquery import BigQueryColumn
    from google.cloud.bigquery import SchemaField

    project_context = sushi_test_project.context
    project_context.target = BigQueryConfig(name="test", schema="test", database="test")

    engine_mock = mocker.MagicMock()
    engine_mock.default_catalog = "test"
    engine_mock.dialect = "bigquery"

    # Sub-fields of the nested RECORD column.
    user_data_members = [
        SchemaField(name="id", field_type="STRING", mode="REQUIRED"),
        SchemaField(name="name", field_type="STRING", mode="REQUIRED"),
        SchemaField(name="address", field_type="STRING", mode="NULLABLE"),
    ]
    schema_fields = [
        SchemaField(name="id", field_type="STRING", mode="REQUIRED"),
        SchemaField(
            name="user_data",
            field_type="RECORD",
            mode="NULLABLE",
            fields=user_data_members,
        ),
        SchemaField(name="tags", field_type="STRING", mode="REPEATED"),
        SchemaField(name="score", field_type="NUMERIC", mode="NULLABLE"),
        SchemaField(name="created_at", field_type="TIMESTAMP", mode="NULLABLE"),
    ]
    engine_mock.get_bq_schema.return_value = schema_fields

    render = runtime_renderer(project_context, engine_adapter=engine_mock, dialect="bigquery")
    template = (
        "{%- set relation = api.Relation.create(database='test', schema='test', identifier='test_table') -%}"
        "{{ adapter.get_columns_in_relation(relation) }}"
    )

    rendered = render(template)
    expected = str([BigQueryColumn.create_from_field(field) for field in schema_fields])
    assert rendered == expected


@pytest.mark.cicdonly
def test_normalization(
sushi_test_project: Project, runtime_renderer: t.Callable, mocker: MockerFixture
Expand Down