Skip to content

Commit e3daa8f

Browse files
Fix!: Ensure correct datatypes are fetched for RisingWave dialect (#4903)
Co-authored-by: Jo <46752250+georgesittas@users.noreply.github.com>
1 parent f018216 commit e3daa8f

File tree

3 files changed

+153
-1
lines changed

3 files changed

+153
-1
lines changed

sqlmesh/core/engine_adapter/risingwave.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
CommentCreationTable,
1515
)
1616

17+
from sqlmesh.utils.errors import SQLMeshError
1718

1819
if t.TYPE_CHECKING:
1920
from sqlmesh.core._typing import TableName
@@ -32,5 +33,37 @@ class RisingwaveEngineAdapter(PostgresEngineAdapter):
3233
SUPPORTS_TRANSACTIONS = False
3334
MAX_IDENTIFIER_LENGTH = None
3435

36+
def columns(
    self, table_name: TableName, include_pseudo_columns: bool = False
) -> t.Dict[str, exp.DataType]:
    """Return the column names and data types of a table from the RisingWave catalog.

    The lookup joins ``rw_catalog.rw_columns`` to ``rw_catalog.rw_relations``
    and ``rw_catalog.rw_schemas``. Columns named ``_row_id`` and
    ``_rw_timestamp`` are always filtered out of the result.

    Args:
        table_name: The table to inspect. When it carries a schema (db) part,
            the lookup is additionally restricted to that schema.
        include_pseudo_columns: Not read by this implementation
            (presumably kept so the signature matches the parent adapter).

    Returns:
        A mapping of column name to the ``exp.DataType`` parsed from the
        catalog's ``data_type`` string using this adapter's dialect.

    Raises:
        SQLMeshError: If the catalog query returns no rows.
    """
    target = exp.to_table(table_name)

    # Match the relation by name and drop the two excluded column names.
    filters = exp.and_(
        exp.column("name", table="rw_relations").eq(target.alias_or_name),
        exp.column("name", table="rw_columns").neq("_row_id"),
        exp.column("name", table="rw_columns").neq("_rw_timestamp"),
    )
    query = (
        exp.select("rw_columns.name AS column_name", "rw_columns.data_type AS data_type")
        .from_("rw_catalog.rw_columns")
        .join("rw_catalog.rw_relations", on="rw_relations.id=rw_columns.relation_id")
        .join("rw_catalog.rw_schemas", on="rw_schemas.id=rw_relations.schema_id")
        .where(filters)
    )

    # Only constrain the schema when the table reference names one.
    if target.db:
        schema_filter = exp.column("name", table="rw_schemas").eq(target.db)
        query = query.where(schema_filter)

    self.execute(query)
    rows = self.cursor.fetchall()
    if not rows:
        raise SQLMeshError(f"Could not get columns for table {table_name}. Table not found.")

    column_types: t.Dict[str, exp.DataType] = {}
    for column_name, data_type in rows:
        # udt=True is forwarded to sqlglot's DataType.build together with the dialect.
        column_types[column_name] = exp.DataType.build(
            data_type, dialect=self.dialect, udt=True
        )
    return column_types
67+
3568
def _truncate_table(self, table_name: TableName) -> None:
3669
return self.execute(exp.Delete(this=exp.to_table(table_name)))
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import typing as t
2+
import pytest
3+
from sqlglot import exp
4+
from pytest import FixtureRequest
5+
from sqlmesh.core.engine_adapter import RisingwaveEngineAdapter
6+
from tests.core.engine_adapter.integration import (
7+
TestContext,
8+
generate_pytest_params,
9+
ENGINES_BY_NAME,
10+
IntegrationTestEngine,
11+
)
12+
13+
14+
@pytest.fixture(params=list(generate_pytest_params(ENGINES_BY_NAME["risingwave"])))
def ctx(
    request: FixtureRequest,
    create_test_context: t.Callable[[IntegrationTestEngine, str, str], t.Iterable],
) -> t.Iterable[TestContext]:
    """Yield a test context for each parameter set generated for the "risingwave" engine.

    The current parameter tuple is unpacked into ``create_test_context`` and
    everything that factory yields is re-yielded to the requesting test.
    """
    context_args = request.param
    yield from create_test_context(*context_args)
20+
21+
22+
@pytest.fixture
def engine_adapter(ctx: TestContext) -> RisingwaveEngineAdapter:
    """Return the context's engine adapter after asserting it is a RisingwaveEngineAdapter."""
    adapter = ctx.engine_adapter
    assert isinstance(adapter, RisingwaveEngineAdapter)
    return adapter
26+
27+
28+
@pytest.fixture
def risingwave_columns_with_datatypes(ctx: TestContext) -> t.Dict[str, exp.DataType]:
    """Build the column-name -> data-type mapping used by the column tests.

    The mapping contains, in order: six scalar columns, one array column per
    scalar column (element type = that scalar type), and a single
    ``struct_col`` whose fields are one nested column per scalar column.
    """
    kinds = exp.DataType.Type

    scalar_columns: t.Dict[str, exp.DataType] = {}
    for column, kind in (
        ("smallint_col", kinds.SMALLINT),
        ("int_col", kinds.INT),
        ("bigint_col", kinds.BIGINT),
        ("ts_col", kinds.TIMESTAMP),
        ("tstz_col", kinds.TIMESTAMPTZ),
        ("vchar_col", kinds.VARCHAR),
    ):
        scalar_columns[column] = exp.DataType.build(kind, nested=False)

    # One array column per scalar column, keyed off the scalar column's name.
    array_columns: t.Dict[str, exp.DataType] = {}
    for type_name, base_type in scalar_columns.items():
        array_columns[f"{type_name}_arr_col"] = exp.DataType.build(
            kinds.ARRAY,
            expressions=[base_type],
            nested=True,
        )

    # A single struct whose fields mirror every scalar column.
    struct_fields = []
    for type_name, base_type in scalar_columns.items():
        field_name = exp.Identifier(this=f"nested_{type_name}_col", quoted=False)
        struct_fields.append(exp.ColumnDef(this=field_name, kind=base_type))
    struct_column = exp.DataType.build(
        kinds.STRUCT,
        expressions=struct_fields,
        nested=True,
    )

    all_columns: t.Dict[str, exp.DataType] = dict(scalar_columns)
    all_columns.update(array_columns)
    all_columns["struct_col"] = struct_column
    return all_columns
62+
63+
64+
def test_engine_adapter(ctx: TestContext):
    """Check the context exposes a RisingwaveEngineAdapter and that "select 1" fetches (1,)."""
    adapter = ctx.engine_adapter
    assert isinstance(adapter, RisingwaveEngineAdapter)

    row = adapter.fetchone("select 1")
    assert row == (1,)
67+
68+
69+
def test_engine_adapter_columns(
    ctx: TestContext, risingwave_columns_with_datatypes: t.Dict[str, exp.DataType]
):
    """Check that ``columns()`` reports the types a table was created with.

    A SELECT of typed NULLs (one per expected column) is passed to the
    adapter's ``ctas`` for the TEST_COLUMNS table, and the adapter's
    ``columns()`` result is then compared to the expected mapping.
    """
    table = ctx.table("TEST_COLUMNS")

    # One `CAST(NULL AS <type>) AS <name>` projection per expected column.
    projections = []
    for name, dtype in risingwave_columns_with_datatypes.items():
        projections.append(exp.cast(exp.null(), dtype).as_(name))

    ctx.engine_adapter.ctas(table, exp.select(*projections))

    reported = ctx.engine_adapter.columns(table)
    assert reported == risingwave_columns_with_datatypes

tests/core/engine_adapter/test_risingwave.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from unittest.mock import call
44

55
import pytest
6-
from sqlglot import parse_one
6+
from sqlglot import parse_one, exp
77
from sqlmesh.core.engine_adapter.risingwave import RisingwaveEngineAdapter
88

99
pytestmark = [pytest.mark.engine, pytest.mark.risingwave]
@@ -15,6 +15,43 @@ def adapter(make_mocked_engine_adapter):
1515
return adapter
1616

1717

18+
def test_columns(adapter: t.Callable):
    """Check that ``columns()`` parses the type strings returned by the cursor.

    The mocked cursor returns (column name, type string) rows; the result of
    ``adapter.columns("db.table")`` must equal the corresponding sqlglot
    data types, including an array and a struct type.
    """
    catalog_rows = [
        ("smallint_col", "smallint"),
        ("int_col", "integer"),
        ("bigint_col", "bigint"),
        ("ts_col", "timestamp without time zone"),
        ("tstz_col", "timestamp with time zone"),
        ("int_array_col", "integer[]"),
        ("vchar_col", "character varying"),
        ("struct_col", "struct<nested_col integer>"),
    ]
    adapter.cursor.fetchall.return_value = catalog_rows

    resp = adapter.columns("db.table")

    kinds = exp.DataType.Type

    def scalar(kind: exp.DataType.Type) -> exp.DataType:
        # A fresh node per use, so no expression is shared between expected values.
        return exp.DataType.build(kind, nested=False)

    int_array = exp.DataType.build(
        kinds.ARRAY,
        expressions=[scalar(kinds.INT)],
        nested=True,
    )
    nested_field = exp.ColumnDef(
        this=exp.Identifier(this="nested_col", quoted=False),
        kind=scalar(kinds.INT),
    )
    struct = exp.DataType.build(
        kinds.STRUCT,
        expressions=[nested_field],
        nested=True,
    )

    expected = {
        "smallint_col": scalar(kinds.SMALLINT),
        "int_col": scalar(kinds.INT),
        "bigint_col": scalar(kinds.BIGINT),
        "ts_col": scalar(kinds.TIMESTAMP),
        "tstz_col": scalar(kinds.TIMESTAMPTZ),
        "int_array_col": int_array,
        "vchar_col": exp.DataType.build(kinds.VARCHAR),
        "struct_col": struct,
    }
    assert resp == expected
53+
54+
1855
def test_create_view(adapter: t.Callable):
1956
adapter.create_view("db.view", parse_one("SELECT 1"), replace=True)
2057
adapter.create_view("db.view", parse_one("SELECT 1"), replace=False)

0 commit comments

Comments
 (0)