|
| 1 | +import typing as t |
| 2 | +import pytest |
| 3 | +from sqlglot import exp |
| 4 | +from pytest import FixtureRequest |
| 5 | +from sqlmesh.core.engine_adapter import RisingwaveEngineAdapter |
| 6 | +from tests.core.engine_adapter.integration import ( |
| 7 | + TestContext, |
| 8 | + generate_pytest_params, |
| 9 | + ENGINES_BY_NAME, |
| 10 | + IntegrationTestEngine, |
| 11 | +) |
| 12 | + |
| 13 | + |
| 14 | +@pytest.fixture(params=list(generate_pytest_params(ENGINES_BY_NAME["risingwave"]))) |
| 15 | +def ctx( |
| 16 | + request: FixtureRequest, |
| 17 | + create_test_context: t.Callable[[IntegrationTestEngine, str, str], t.Iterable], |
| 18 | +) -> t.Iterable[TestContext]: |
| 19 | + yield from create_test_context(*request.param) |
| 20 | + |
| 21 | + |
| 22 | +@pytest.fixture |
| 23 | +def engine_adapter(ctx: TestContext) -> RisingwaveEngineAdapter: |
| 24 | + assert isinstance(ctx.engine_adapter, RisingwaveEngineAdapter) |
| 25 | + return ctx.engine_adapter |
| 26 | + |
| 27 | + |
| 28 | +@pytest.fixture |
| 29 | +def risingwave_columns_with_datatypes(ctx: TestContext) -> t.Dict[str, exp.DataType]: |
| 30 | + base_types = { |
| 31 | + "smallint_col": exp.DataType.build(exp.DataType.Type.SMALLINT, nested=False), |
| 32 | + "int_col": exp.DataType.build(exp.DataType.Type.INT, nested=False), |
| 33 | + "bigint_col": exp.DataType.build(exp.DataType.Type.BIGINT, nested=False), |
| 34 | + "ts_col": exp.DataType.build(exp.DataType.Type.TIMESTAMP, nested=False), |
| 35 | + "tstz_col": exp.DataType.build(exp.DataType.Type.TIMESTAMPTZ, nested=False), |
| 36 | + "vchar_col": exp.DataType.build(exp.DataType.Type.VARCHAR, nested=False), |
| 37 | + } |
| 38 | + # generate all arrays of base types |
| 39 | + arr_types = { |
| 40 | + f"{type_name}_arr_col": exp.DataType.build( |
| 41 | + exp.DataType.Type.ARRAY, |
| 42 | + expressions=[base_type], |
| 43 | + nested=True, |
| 44 | + ) |
| 45 | + for type_name, base_type in base_types.items() |
| 46 | + } |
| 47 | + # generate struct with all base types as nested columns |
| 48 | + struct_types = { |
| 49 | + "struct_col": exp.DataType.build( |
| 50 | + exp.DataType.Type.STRUCT, |
| 51 | + expressions=[ |
| 52 | + exp.ColumnDef( |
| 53 | + this=exp.Identifier(this=f"nested_{type_name}_col", quoted=False), |
| 54 | + kind=base_type, |
| 55 | + ) |
| 56 | + for type_name, base_type in base_types.items() |
| 57 | + ], |
| 58 | + nested=True, |
| 59 | + ) |
| 60 | + } |
| 61 | + return {**base_types, **arr_types, **struct_types} |
| 62 | + |
| 63 | + |
| 64 | +def test_engine_adapter(ctx: TestContext): |
| 65 | + assert isinstance(ctx.engine_adapter, RisingwaveEngineAdapter) |
| 66 | + assert ctx.engine_adapter.fetchone("select 1") == (1,) |
| 67 | + |
| 68 | + |
| 69 | +def test_engine_adapter_columns( |
| 70 | + ctx: TestContext, risingwave_columns_with_datatypes: t.Dict[str, exp.DataType] |
| 71 | +): |
| 72 | + table = ctx.table("TEST_COLUMNS") |
| 73 | + query = exp.select( |
| 74 | + *[ |
| 75 | + exp.cast(exp.null(), dtype).as_(name) |
| 76 | + for name, dtype in risingwave_columns_with_datatypes.items() |
| 77 | + ] |
| 78 | + ) |
| 79 | + ctx.engine_adapter.ctas(table, query) |
| 80 | + |
| 81 | + column_result = ctx.engine_adapter.columns(table) |
| 82 | + assert column_result == risingwave_columns_with_datatypes |
0 commit comments