diff --git a/dataproxy-common/src/main/java/org/secretflow/dataproxy/common/utils/ArrowUtil.java b/dataproxy-common/src/main/java/org/secretflow/dataproxy/common/utils/ArrowUtil.java
index 95cc9ce..8e801d7 100644
--- a/dataproxy-common/src/main/java/org/secretflow/dataproxy/common/utils/ArrowUtil.java
+++ b/dataproxy-common/src/main/java/org/secretflow/dataproxy/common/utils/ArrowUtil.java
@@ -24,12 +24,19 @@
 /**
  * @author yuexie
  * @date 2024/11/8 15:51
+ * Arrow type utility for parsing Kuscia column types into Arrow types
 **/
 public class ArrowUtil {
 
+    /**
+     * Parse a Kuscia column type into the corresponding Arrow type.
+     * @param type column type string (e.g., "int32", "interval_year_month", "large_string")
+     * @return the corresponding ArrowType
+     */
     public static ArrowType parseKusciaColumnType(String type) {
-        // string integer float datetime timestamp
-        return switch (type) {
+        String typeLower = type.toLowerCase();
+        return switch (typeLower) {
+            // Integer types
             case "int8" -> Types.MinorType.TINYINT.getType();
             case "int16" -> Types.MinorType.SMALLINT.getType();
             case "int32" -> Types.MinorType.INT.getType();
@@ -38,13 +45,49 @@ public static ArrowType parseKusciaColumnType(String type) {
             case "uint16" -> Types.MinorType.UINT2.getType();
             case "uint32" -> Types.MinorType.UINT4.getType();
             case "uint64" -> Types.MinorType.UINT8.getType();
+
+            // Floating point types
             case "float32" -> Types.MinorType.FLOAT4.getType();
             case "float64", "float" -> Types.MinorType.FLOAT8.getType();
+
+            // Date types
             case "date32" -> Types.MinorType.DATEDAY.getType();
             case "date64" -> Types.MinorType.DATEMILLI.getType();
+
+            // Time types
+            case "time32" -> Types.MinorType.TIMEMILLI.getType();
+            case "time64" -> Types.MinorType.TIMEMICRO.getType();
+
+            // Timestamp types
+            case "timestamp" -> Types.MinorType.TIMESTAMPMICRO.getType();
+            case "timestamp_us" -> Types.MinorType.TIMESTAMPMICRO.getType();
+            case "timestamp_ms" -> Types.MinorType.TIMESTAMPMILLI.getType();
+            case "timestamp_ns" -> Types.MinorType.TIMESTAMPNANO.getType();
+            case "timestamp_tz" -> Types.MinorType.TIMESTAMPMICROTZ.getType();
+
+            // Boolean types
             case "bool" -> Types.MinorType.BIT.getType();
+
+            // String types
             case "string", "str" -> Types.MinorType.VARCHAR.getType();
+            case "large_string", "large_utf8", "utf8_large" -> Types.MinorType.LARGEVARCHAR.getType();
+
+            // Binary types
             case "binary" -> Types.MinorType.VARBINARY.getType();
+            case "large_binary", "large_varbinary", "varbinary_large" -> Types.MinorType.LARGEVARBINARY.getType();
+
+            // Decimal types
+            // Note: Types.MinorType.DECIMAL.getType() throws UnsupportedOperationException because
+            // a decimal requires precision/scale, so new ArrowType.Decimal(precision, scale, bitWidth) must be used
+            case "decimal" -> new ArrowType.Decimal(38, 10, 128);
+
+            // Interval types
+            case "interval_year_month", "interval_ym" ->
+                    Types.MinorType.INTERVALYEAR.getType();
+            case "interval_day_time", "interval_dt" ->
+                    Types.MinorType.INTERVALDAY.getType();
+            case "interval" -> Types.MinorType.INTERVALYEAR.getType();
+
             default -> throw DataproxyException.of(DataproxyErrorCode.PARAMS_UNRELIABLE, "Unsupported field type: " + type);
         };
     }
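For reference, a minimal sketch of how this mapping is typically consumed: building an Arrow Schema from column metadata. The `Column` record here is a hypothetical stand-in for Kuscia's column descriptors, not part of the patch; the integration test added below does the same thing with `Common.DataColumn`.

    import java.util.List;
    import java.util.stream.Collectors;
    import org.apache.arrow.vector.types.pojo.Field;
    import org.apache.arrow.vector.types.pojo.Schema;
    import org.secretflow.dataproxy.common.utils.ArrowUtil;

    // Hypothetical column metadata: (name, Kuscia type string) pairs.
    record Column(String name, String type) {}

    class SchemaExample {
        static Schema toArrowSchema(List<Column> cols) {
            // Each Kuscia type string is resolved to an ArrowType by the utility;
            // unknown types surface as DataproxyException (PARAMS_UNRELIABLE).
            return new Schema(cols.stream()
                    .map(c -> Field.nullable(c.name(), ArrowUtil.parseKusciaColumnType(c.type())))
                    .collect(Collectors.toList()));
        }
    }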
diff --git a/dataproxy-core/src/main/java/org/secretflow/dataproxy/core/visitor/ByteValueVisitor.java b/dataproxy-core/src/main/java/org/secretflow/dataproxy/core/visitor/ByteValueVisitor.java
index 1c4c002..cd5a38b 100644
--- a/dataproxy-core/src/main/java/org/secretflow/dataproxy/core/visitor/ByteValueVisitor.java
+++ b/dataproxy-core/src/main/java/org/secretflow/dataproxy/core/visitor/ByteValueVisitor.java
@@ -16,12 +16,16 @@
 package org.secretflow.dataproxy.core.visitor;
 
+import lombok.extern.slf4j.Slf4j;
+
 import javax.annotation.Nonnull;
+import java.math.BigDecimal;
 
 /**
  * @author yuexie
  * @date 2024/11/1 20:25
 **/
+@Slf4j
 public class ByteValueVisitor implements ValueVisitor<Byte> {
 
     @Override
@@ -54,4 +58,40 @@ public Byte visit(@Nonnull Double value) {
         return value.byteValue();
     }
 
+    @Override
+    public Byte visit(@Nonnull String value) {
+        try {
+            return Byte.valueOf(value);
+        } catch (NumberFormatException e) {
+            log.warn("Failed to parse string '{}' as Byte, using 0", value);
+            return (byte) 0;
+        }
+    }
+
+    @Override
+    public Byte visit(@Nonnull BigDecimal value) {
+        return value.byteValue();
+    }
+
+    @Override
+    public Byte visit(@Nonnull Object value) {
+        // Already a Byte: return it as-is
+        if (value instanceof Byte byteValue) {
+            return byteValue;
+        }
+
+        // Any Number (BigDecimal, Integer, Long, Short, Float, Double, etc.): narrow to byte
+        if (value instanceof Number number) {
+            return number.byteValue();
+        }
+
+        // String: delegate to the dedicated visit(String) overload
+        if (value instanceof String stringValue) {
+            return visit(stringValue);
+        }
+
+        // Anything else: fall back to parsing its string representation
+        return visit(value.toString());
+    }
+
 }
diff --git a/dataproxy-core/src/main/java/org/secretflow/dataproxy/core/visitor/IntegerValueVisitor.java b/dataproxy-core/src/main/java/org/secretflow/dataproxy/core/visitor/IntegerValueVisitor.java
index a0003a4..9e04c50 100644
--- a/dataproxy-core/src/main/java/org/secretflow/dataproxy/core/visitor/IntegerValueVisitor.java
+++ b/dataproxy-core/src/main/java/org/secretflow/dataproxy/core/visitor/IntegerValueVisitor.java
@@ -19,6 +19,7 @@
 import lombok.extern.slf4j.Slf4j;
 
 import javax.annotation.Nonnull;
+import java.sql.Time;
 import java.time.LocalDate;
 import java.time.LocalDateTime;
 import java.time.ZonedDateTime;
@@ -87,7 +88,21 @@ public Integer visit(@Nonnull Integer value) {
 
     @Override
     public Integer visit(@Nonnull Date value) {
-        return (int) value.getTime();
+        // Handle java.sql.Time: Time32Vector expects milliseconds since midnight,
+        // but java.sql.Time.getTime() returns milliseconds since the Unix epoch.
+        // Convert via LocalTime to get the milliseconds within the day.
+        if (value instanceof Time sqlTime) {
+            return (int) (sqlTime.toLocalTime().toNanoOfDay() / 1_000_000);
+        }
+
+        // Handle java.sql.Date: DateDayVector expects days since 1970-01-01
+        if (value instanceof java.sql.Date sqlDate) {
+            return (int) sqlDate.toLocalDate().toEpochDay();
+        }
+
+        // For plain java.util.Date, assume it represents a date and convert epoch milliseconds to days
+        return (int) (value.getTime() / (24 * 60 * 60 * 1000L));
     }
 
     @Override
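As a sanity check on the unit conversions above, a standalone sketch (the values are illustrative):

    import java.sql.Time;
    import java.time.LocalDate;

    class ConversionCheck {
        public static void main(String[] args) {
            // Days since epoch for a DateDayVector: 2024-01-01 -> 19723
            int days = (int) LocalDate.of(2024, 1, 1).toEpochDay();

            // Milliseconds since midnight for a Time32Vector: 12:30:45 -> 45_045_000
            Time t = Time.valueOf("12:30:45");
            int millisOfDay = (int) (t.toLocalTime().toNanoOfDay() / 1_000_000);

            System.out.println(days + " " + millisOfDay); // 19723 45045000
        }
    }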
diff --git a/dataproxy-core/src/main/java/org/secretflow/dataproxy/core/visitor/LongValueVisitor.java b/dataproxy-core/src/main/java/org/secretflow/dataproxy/core/visitor/LongValueVisitor.java
index 0f6e9e8..83577aa 100644
--- a/dataproxy-core/src/main/java/org/secretflow/dataproxy/core/visitor/LongValueVisitor.java
+++ b/dataproxy-core/src/main/java/org/secretflow/dataproxy/core/visitor/LongValueVisitor.java
@@ -19,12 +19,8 @@
 import lombok.extern.slf4j.Slf4j;
 
 import javax.annotation.Nonnull;
-import java.time.Instant;
-import java.time.LocalDate;
-import java.time.LocalDateTime;
-import java.time.ZoneId;
-import java.time.ZoneOffset;
-import java.time.ZonedDateTime;
+import java.sql.Time;
+import java.time.*;
 import java.util.Date;
 
 /**
@@ -50,6 +46,8 @@ public Long visit(@Nonnull Object value) {
         if (value instanceof Long longValue) {
             return visit(longValue);
+        } else if (value instanceof Time sqlTime) {
+            return this.visit(sqlTime);
         } else if (value instanceof Date dateValue) {
             return this.visit(dateValue);
         } else if (value instanceof LocalDateTime localDateTime) {
@@ -102,7 +100,10 @@ public Long visit(@Nonnull ZonedDateTime value) {
 
     @Override
     public Long visit(@Nonnull LocalDateTime value) {
-        return value.toInstant(ZoneOffset.of(ZoneId.systemDefault().getId())).toEpochMilli();
+        // LocalDateTime carries no timezone information, so treat it as local time in the system default zone.
+        // Use atZone() rather than toInstant(ZoneOffset.of(...)): ZoneOffset.of() requires an offset string
+        // (e.g., "+08:00") and cannot accept a zone ID directly.
+        return value.atZone(ZoneId.systemDefault()).toInstant().toEpochMilli();
     }
 
     @Override
@@ -115,4 +116,12 @@ public Long visit(@Nonnull Instant value) {
         log.debug("visit instant: {}", value.toEpochMilli());
         return value.toEpochMilli();
     }
+
+    public Long visit(@Nonnull Time value) {
+        long nanosSinceMidnight = value.toLocalTime().toNanoOfDay();
+        long microsSinceMidnight = nanosSinceMidnight / 1_000;
+        log.debug("Converting java.sql.Time {} (toLocalTime: {}) to microseconds since midnight: {} (nanos: {})",
+                value, value.toLocalTime(), microsSinceMidnight, nanosSinceMidnight);
+        return microsSinceMidnight;
+    }
 }
diff --git a/dataproxy-core/src/main/java/org/secretflow/dataproxy/core/visitor/ValueVisitor.java b/dataproxy-core/src/main/java/org/secretflow/dataproxy/core/visitor/ValueVisitor.java
index aedbc8d..be6eb06 100644
--- a/dataproxy-core/src/main/java/org/secretflow/dataproxy/core/visitor/ValueVisitor.java
+++ b/dataproxy-core/src/main/java/org/secretflow/dataproxy/core/visitor/ValueVisitor.java
@@ -17,6 +17,7 @@
 package org.secretflow.dataproxy.core.visitor;
 
 import javax.annotation.Nonnull;
+import java.math.BigDecimal;
 import java.time.Instant;
 import java.time.LocalDate;
 import java.time.LocalDateTime;
@@ -65,6 +66,10 @@ default T visit(@Nonnull byte[] value) {
         throw new UnsupportedOperationException("byte[] not supported");
     }
 
+    default T visit(@Nonnull BigDecimal value) {
+        throw new UnsupportedOperationException("BigDecimal not supported");
+    }
+
     default T visit(@Nonnull Object value) {
         throw new UnsupportedOperationException("Object not supported");
     }
diff --git a/dataproxy-integration-tests/pom.xml b/dataproxy-integration-tests/pom.xml
index 3091826..8d003b6 100644
--- a/dataproxy-integration-tests/pom.xml
+++ b/dataproxy-integration-tests/pom.xml
@@ -32,6 +32,10 @@
             <groupId>org.secretflow</groupId>
            <artifactId>dataproxy-plugin-odps</artifactId>
         </dependency>
+        <dependency>
+            <groupId>org.secretflow</groupId>
+            <artifactId>dataproxy-plugin-dameng</artifactId>
+        </dependency>
         <dependency>
             <groupId>org.projectlombok</groupId>
@@ -67,6 +71,24 @@
             <artifactId>mockito-junit-jupiter</artifactId>
             <scope>test</scope>
         </dependency>
+
+
+        <dependency>
+            <groupId>org.testcontainers</groupId>
+            <artifactId>testcontainers</artifactId>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.testcontainers</groupId>
+            <artifactId>junit-jupiter</artifactId>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>com.dameng</groupId>
+            <artifactId>DmJdbcDriver18</artifactId>
+            <version>8.1.3.62</version>
+            <scope>test</scope>
+        </dependency>
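The DmJdbcDriver18 artifact is Dameng's JDBC driver; the integration test below connects with a jdbc:dm:// URL (the format matches the String.format in the test's setup). A minimal standalone connection sketch, assuming placeholder host/credentials; 5236 is Dameng's conventional default port, while the test maps a container port dynamically:

    import java.sql.Connection;
    import java.sql.DriverManager;
    import java.sql.ResultSet;
    import java.sql.Statement;

    class DamengConnectSketch {
        public static void main(String[] args) throws Exception {
            // The driver registers itself via the JDBC service-provider mechanism,
            // so no explicit Class.forName is needed on modern JVMs.
            String url = "jdbc:dm://localhost:5236/SYSDBA"; // placeholder endpoint/schema
            try (Connection conn = DriverManager.getConnection(url, "SYSDBA", "SYSDBA001");
                 Statement stmt = conn.createStatement();
                 ResultSet rs = stmt.executeQuery("SELECT 1")) {
                if (rs.next()) {
                    System.out.println("connected: " + rs.getInt(1));
                }
            }
        }
    }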
diff --git a/dataproxy-integration-tests/src/test/java/org/secretflow/dataproxy/integration/tests/DamengIntegrationTest.java b/dataproxy-integration-tests/src/test/java/org/secretflow/dataproxy/integration/tests/DamengIntegrationTest.java
new file mode 100644
index 0000000..27f6302
--- /dev/null
+++ b/dataproxy-integration-tests/src/test/java/org/secretflow/dataproxy/integration/tests/DamengIntegrationTest.java
@@ -0,0 +1,1198 @@
+/*
+ * Copyright 2025 Ant Group Co., Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.secretflow.dataproxy.integration.tests;
+
+import com.google.protobuf.Any;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.arrow.flight.*;
+import org.apache.arrow.vector.*;
+import org.apache.arrow.vector.PeriodDuration;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.jupiter.api.*;
+import org.junit.jupiter.api.condition.EnabledIfSystemProperty;
+import org.secretflow.dataproxy.common.utils.ArrowUtil;
+import org.secretflow.dataproxy.integration.tests.utils.DamengTestUtil;
+import org.secretflow.dataproxy.core.config.FlightServerContext;
+import org.secretflow.dataproxy.server.DataProxyFlightServer;
+import org.secretflow.v1alpha1.common.Common;
+import org.secretflow.v1alpha1.kusciaapi.Domaindata;
+import org.secretflow.v1alpha1.kusciaapi.Domaindatasource;
+import org.secretflow.v1alpha1.kusciaapi.Flightdm;
+import org.secretflow.v1alpha1.kusciaapi.Flightinner;
+import org.testcontainers.containers.GenericContainer;
+import org.testcontainers.containers.wait.strategy.Wait;
+import org.testcontainers.junit.jupiter.Testcontainers;
+import org.testcontainers.utility.DockerImageName;
+
+import java.math.BigDecimal;
+import java.sql.Connection;
+import java.sql.DriverManager;
+import java.sql.Statement;
+import java.time.Duration;
+import java.time.LocalDate;
+import java.time.LocalTime;
+import java.time.Period;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+
+/**
+ * @author kongxiaoran
+ * @date 2025/11/07
+ * Comprehensive Dameng integration test covering all Arrow types supported by ArrowUtil.parseKusciaColumnType
+ */
+@Slf4j
+@Testcontainers
+@EnabledIfSystemProperty(named = "enableDamengIntegration", matches = "true")
+@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
+public class DamengIntegrationTest extends BaseArrowFlightServerTest {
+
+    // Credentials for the database built into the Dameng docker image
+    private static final String DOCKER_DAMENG_USER = "SYSDBA";
+    private static final String DOCKER_DAMENG_PASSWORD = "SYSDBA001";
+    private static final String DOCKER_DATABASE_NAME = "SYSDBA";
+    private static final int DOCKER_CONTAINER_PORT = 5237;
+    private static final String TABLE_NAME = "COMPREHENSIVE_TEST_TABLE";
+
+    // Whether to test against a Dameng docker container or a remote Dameng database
+    private static final boolean USE_DOCKER = Boolean.parseBoolean(
+            System.getProperty("useDamengDocker", "true"));
+
+    private static String damengJdbcUrl;
+    private static String damengHost;
+    private static int damengPort;
+    private static String damengUser;
+    private static String damengPassword;
+    private static String databaseName;
+
+    private static GenericContainer<?> dameng;
+    private static Domaindatasource.DomainDataSource domainDataSource;
+    private static Domaindata.DomainData domainDataWithTable;
+
+    // Comprehensive column definitions covering all supported types
+    // Note: uint types are skipped because Dameng does not support unsigned integers
+    private static final List<Common.DataColumn> columns = Arrays.asList(
+            // Integer types
+            Common.DataColumn.newBuilder().setName("col_int8").setType("int8").build(),
+            Common.DataColumn.newBuilder().setName("col_int16").setType("int16").build(),
+            Common.DataColumn.newBuilder().setName("col_int32").setType("int32").build(),
+            Common.DataColumn.newBuilder().setName("col_int64").setType("int64").build(),
+            // Floating point types
+            Common.DataColumn.newBuilder().setName("col_float32").setType("float32").build(),
+            Common.DataColumn.newBuilder().setName("col_float64").setType("float64").build(),
+            // Date types
+            Common.DataColumn.newBuilder().setName("col_date32").setType("date32").build(),
+            Common.DataColumn.newBuilder().setName("col_date64").setType("date64").build(),
+            // Time types
+            Common.DataColumn.newBuilder().setName("col_time32").setType("time32").build(),
+            Common.DataColumn.newBuilder().setName("col_time64").setType("time64").build(),
+            // Timestamp types
+            Common.DataColumn.newBuilder().setName("col_timestamp").setType("timestamp").build(),
+            Common.DataColumn.newBuilder().setName("col_timestamp_ms").setType("timestamp_ms").build(),
+            Common.DataColumn.newBuilder().setName("col_timestamp_us").setType("timestamp_us").build(),
+            // Boolean types
+            Common.DataColumn.newBuilder().setName("col_bool").setType("bool").build(),
+            // String types
+            Common.DataColumn.newBuilder().setName("col_string").setType("string").build(),
+            Common.DataColumn.newBuilder().setName("col_large_string").setType("large_string").build(),
+            // Binary types
+            Common.DataColumn.newBuilder().setName("col_binary").setType("binary").build(),
+            Common.DataColumn.newBuilder().setName("col_large_binary").setType("large_binary").build(),
+            // Decimal types
+            Common.DataColumn.newBuilder().setName("col_decimal").setType("decimal").build(),
+            // Interval types
+            Common.DataColumn.newBuilder().setName("col_interval_ym").setType("interval_year_month").build(),
+            Common.DataColumn.newBuilder().setName("col_interval_dt").setType("interval_day_time").build()
+    );
+
+    @BeforeAll
+    public static void startServer() {
+        log.info("Starting comprehensive Dameng integration test...");
+
+        // Start DataProxyFlightServer
+        dataProxyFlightServer = new DataProxyFlightServer(FlightServerContext.getInstance().getFlightServerConfig());
+        assertDoesNotThrow(() -> {
+            serverThread = new Thread(() -> {
+                try {
+                    dataProxyFlightServer.start();
+                    SERVER_START_LATCH.countDown();
+                    dataProxyFlightServer.awaitTermination();
+                } catch (Exception e) {
+                    fail("Exception was thrown during server start: " + e.getMessage());
+                }
+            });
+            serverThread.start();
+            SERVER_START_LATCH.await();
+        });
+
+        // Setup database
+        assertDoesNotThrow(() -> {
+            if (USE_DOCKER) {
+                log.info("Starting Docker container for Dameng database...");
+                @SuppressWarnings("resource")
+                GenericContainer<?> container = new GenericContainer<>(DockerImageName.parse("kongxr7/dameng:8.1"))
+                        .withExposedPorts(DOCKER_CONTAINER_PORT, 5237)
+                        .withPrivilegedMode(true)
+                        .withEnv("PAGE_SIZE", "16")
+                        .withEnv("LD_LIBRARY_PATH", "/opt/dmdbms/bin")
+                        .withEnv("INSTANCE_NAME", "dm8db")
+                        .withStartupTimeout(Duration.ofSeconds(600))
+                        .waitingFor(Wait.forListeningPort());
+                container.start();
+                dameng = container;
+
+                damengHost = dameng.getHost();
+                damengPort = dameng.getMappedPort(DOCKER_CONTAINER_PORT);
+                damengUser = DOCKER_DAMENG_USER;
+                damengPassword = DOCKER_DAMENG_PASSWORD;
+                databaseName = DOCKER_DATABASE_NAME;
+                log.info("Using Docker container - Host: {}, Port: {}", damengHost, damengPort);
+            } else {
+                String endpoint = DamengTestUtil.getDamengEndpoint();
+                String[] hostPort = endpoint.split(":");
+                damengHost = hostPort[0];
+                damengPort = Integer.parseInt(hostPort[1]);
+                damengUser = DamengTestUtil.getDamengUsername();
+                damengPassword = DamengTestUtil.getDamengPassword();
+                databaseName = DamengTestUtil.getDamengDatabase();
+                log.info("Using remote database - Host: {}, Port: {}", damengHost, damengPort);
+            }
+
+            damengJdbcUrl = String.format("jdbc:dm://%s:%d/%s", damengHost, damengPort, databaseName);
+
+            try (Connection conn = DriverManager.getConnection(damengJdbcUrl, damengUser, damengPassword);
+                 Statement stmt = conn.createStatement()) {
+                stmt.execute(String.format("DROP TABLE IF EXISTS %s", TABLE_NAME));
+
+                // Create table with all supported types
+                String createTableSql = String.format(
+                        "CREATE TABLE %s (" +
+                                "col_int8 TINYINT, " +
+                                "col_int16 SMALLINT, " +
+                                "col_int32 INT, " +
+                                "col_int64 BIGINT, " +
+                                "col_float32 FLOAT, " +
+                                "col_float64 DOUBLE, " +
+                                "col_date32 DATE, " +
+                                // Date(MILLISECOND) -> DATETIME(3)
+                                "col_date64 DATETIME(3), " +
+                                "col_time32 TIME(3), " +
+                                "col_time64 TIME(6), " +
+                                "col_timestamp TIMESTAMP, " +
+                                "col_timestamp_ms TIMESTAMP(3), " +
+                                "col_timestamp_us TIMESTAMP(6), " +
+                                "col_bool BIT, " +
+                                "col_string VARCHAR(100), " +
+                                "col_large_string CLOB, " +
+                                "col_binary VARBINARY(100), " +
+                                "col_large_binary BLOB, " +
+                                "col_decimal DECIMAL(38, 10), " +
+                                "col_interval_ym INTERVAL YEAR TO MONTH, " +
+                                "col_interval_dt INTERVAL DAY TO SECOND" +
+                                ")", TABLE_NAME);
+                stmt.execute(createTableSql);
+                log.info("Created comprehensive test table with {} columns", columns.size());
+            }
+        });
+
+        // Prepare protobuf messages
+        Domaindatasource.DatabaseDataSourceInfo damengDataSourceInfo =
+                Domaindatasource.DatabaseDataSourceInfo.newBuilder()
+                        .setEndpoint(String.format("%s:%d", damengHost, damengPort))
+                        .setUser(damengUser)
+                        .setPassword(damengPassword)
+                        .setDatabase(databaseName)
+                        .build();
+
+        domainDataSource = Domaindatasource.DomainDataSource.newBuilder()
+                .setDatasourceId("dameng-datasource")
+                .setName("dameng_db")
+                .setType("dameng")
+                .setInfo(Domaindatasource.DataSourceInfo.newBuilder().setDatabase(damengDataSourceInfo))
+                .build();
+
+        domainDataWithTable = Domaindata.DomainData.newBuilder()
+                .setDatasourceId("dameng-datasource")
+                .setName(TABLE_NAME)
+                .setRelativeUri(TABLE_NAME)
+                .setDomaindataId("dameng-table")
+                .setType("table")
+                .addAllColumns(columns)
+                .build();
+    }
+
+    @AfterAll
+    public static void stopServer() {
+        assertDoesNotThrow(() -> {
+            if (dataProxyFlightServer != null) dataProxyFlightServer.close();
+            if (serverThread != null) serverThread.interrupt();
+            if (USE_DOCKER && dameng != null) {
+                dameng.stop();
+                log.info("Docker container stopped");
+            }
+        });
+    }
+
+    @Test
+    @Order(1)
+    public void testDoPut() {
+        log.info("Testing DoPut with comprehensive data types...");
+
+        Flightinner.CommandDataMeshUpdate command = Flightinner.CommandDataMeshUpdate.newBuilder()
+                .setDatasource(domainDataSource)
+                .setDomaindata(domainDataWithTable)
+                .setUpdate(Flightdm.CommandDomainDataUpdate.newBuilder()
+                        .setContentType(Flightdm.ContentType.CSV))
+                .build();
+
+        assertDoesNotThrow(() -> {
+            FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(command).toByteArray());
+            FlightInfo info = client.getInfo(descriptor, CallOptions.timeout(10, TimeUnit.SECONDS));
+            Ticket ticket = info.getEndpoints().get(0).getTicket();
+
+            Schema schema = new Schema(columns.stream()
+                    .map(col -> Field.nullable(col.getName(), ArrowUtil.parseKusciaColumnType(col.getType())))
+                    .collect(Collectors.toList()));
+
+            try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+                root.allocateNew();
+
+                // Get all vectors
+                TinyIntVector int8Vector = (TinyIntVector) root.getVector("col_int8");
+                SmallIntVector int16Vector = (SmallIntVector) root.getVector("col_int16");
+                IntVector int32Vector = (IntVector) root.getVector("col_int32");
+                BigIntVector int64Vector = (BigIntVector) root.getVector("col_int64");
+                Float4Vector float32Vector = (Float4Vector) root.getVector("col_float32");
+                Float8Vector float64Vector = (Float8Vector) root.getVector("col_float64");
+                DateDayVector date32Vector = (DateDayVector) root.getVector("col_date32");
+                DateMilliVector date64Vector = (DateMilliVector) root.getVector("col_date64");
+                TimeMilliVector time32Vector = (TimeMilliVector) root.getVector("col_time32");
+                TimeMicroVector time64Vector = (TimeMicroVector) root.getVector("col_time64");
+                TimeStampMicroVector timestampVector = (TimeStampMicroVector) root.getVector("col_timestamp");
+                TimeStampMilliVector timestampMsVector = (TimeStampMilliVector) root.getVector("col_timestamp_ms");
+                TimeStampMicroVector timestampUsVector = (TimeStampMicroVector) root.getVector("col_timestamp_us");
+                BitVector boolVector = (BitVector) root.getVector("col_bool");
+                VarCharVector stringVector = (VarCharVector) root.getVector("col_string");
+                LargeVarCharVector largeStringVector = (LargeVarCharVector) root.getVector("col_large_string");
+                VarBinaryVector binaryVector = (VarBinaryVector) root.getVector("col_binary");
+                LargeVarBinaryVector largeBinaryVector = (LargeVarBinaryVector) root.getVector("col_large_binary");
+                DecimalVector decimalVector = (DecimalVector) root.getVector("col_decimal");
+                IntervalYearVector intervalYVector = (IntervalYearVector) root.getVector("col_interval_ym");
+                IntervalDayVector intervalDVector = (IntervalDayVector) root.getVector("col_interval_dt");
+
+                // Set test data for row 0
+                int8Vector.setSafe(0, 127);
+                int16Vector.setSafe(0, 32767);
+                int32Vector.setSafe(0, 2147483647);
+                int64Vector.setSafe(0, 9223372036854775807L);
+                float32Vector.setSafe(0, 3.14f);
+                float64Vector.setSafe(0, 2.718281828459045);
+                // Date32: days since epoch (1970-01-01)
+                date32Vector.setSafe(0, (int) LocalDate.of(2024, 1, 1).toEpochDay());
+                // Date64: milliseconds since epoch
+                date64Vector.setSafe(0, java.sql.Timestamp.valueOf("2024-01-01 12:00:00").getTime());
+                // Time32: milliseconds since midnight
+                time32Vector.setSafe(0, (int) LocalTime.of(12, 30, 45).toSecondOfDay() * 1000);
+                // Time64: microseconds since midnight
+                // Note: LocalTime.of(12, 30, 45, 123456000) is 12:30:45.123456 (the last argument is nanoseconds)
+                time64Vector.setSafe(0, LocalTime.of(12, 30, 45, 123456000).toNanoOfDay() / 1000);
+                // Timestamp: microseconds since epoch
+                long timestampMs = java.sql.Timestamp.valueOf("2024-01-01 12:00:00").getTime();
+                timestampVector.setSafe(0, timestampMs * 1000);
+                timestampMsVector.setSafe(0, timestampMs);
+                timestampUsVector.setSafe(0, timestampMs * 1000);
+                boolVector.setSafe(0, 1);
+                stringVector.setSafe(0, "Test String".getBytes());
+                largeStringVector.setSafe(0, ("Large String Content " + "A".repeat(100)).getBytes());
+                binaryVector.setSafe(0, new byte[]{0x01, 0x02, 0x03, 0x04});
+                largeBinaryVector.setSafe(0, new byte[]{0x05, 0x06, 0x07, 0x08, 0x09, 0x0A});
+                decimalVector.setSafe(0, new BigDecimal("1234567890123456789012345678.1234567890"));
+                // IntervalYear: total months (12 = 1 year)
+                intervalYVector.setSafe(0, 12);
+                // IntervalDay: days and milliseconds (1 day, 0 milliseconds)
+                intervalDVector.setSafe(0, 1, 0);
+
+                root.setRowCount(1);
+
+                FlightClient.ClientStreamListener listener = client.startPut(
+                        FlightDescriptor.command(ticket.getBytes()), root, new AsyncPutListener());
+                listener.putNext();
+                listener.completed();
+                listener.getResult();
+
+                log.info("Successfully wrote 1 row with all data types");
+            }
+        });
+    }
+
+    @Test
+    @Order(2)
+    public void testDoGet() {
+        log.info("Testing DoGet to verify all data types and values...");
+
+        Flightinner.CommandDataMeshQuery query = Flightinner.CommandDataMeshQuery.newBuilder()
+                .setDatasource(domainDataSource)
+                .setDomaindata(domainDataWithTable)
+                .setQuery(Flightdm.CommandDomainDataQuery.newBuilder().setContentType(Flightdm.ContentType.CSV))
+                .build();
+
+        assertDoesNotThrow(() -> {
+            FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(query).toByteArray());
+            FlightInfo info = client.getInfo(descriptor, CallOptions.timeout(10, TimeUnit.SECONDS));
+
+            try (FlightStream stream = client.getStream(info.getEndpoints().get(0).getTicket())) {
+                int rowCount = 0;
+                VectorSchemaRoot root = null;
+
+                while (stream.next()) {
+                    root = stream.getRoot();
+                    rowCount += root.getRowCount();
+
+                    if (rowCount > 0) {
+                        // Verify all column types are present
+                        assertEquals(columns.size(), root.getSchema().getFields().size(),
+                                "Schema should contain all defined columns");
+
+                        // Verify data can be read (basic sanity check)
+                        for (Common.DataColumn col : columns) {
+                            assertNotNull(root.getVector(col.getName()),
+                                    "Column " + col.getName() + " should exist in result");
+                        }
+
+                        log.info("Successfully read {} rows with {} columns", rowCount, columns.size());
+                    }
+                }
+
+                assertEquals(1, rowCount, "Should read 1 row from table");
+
+                // Verify data values match what was inserted
+                if (root != null && root.getRowCount() > 0) {
+                    verifyDataValues(root, 0);
+                }
+            }
+        });
+    }
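The verification method that follows dispatches on whatever vector class the reader actually produced. A hypothetical helper (not part of the patch) in the same spirit: normalizing a timestamp column to epoch milliseconds regardless of the precision vector returned.

    import org.apache.arrow.vector.FieldVector;
    import org.apache.arrow.vector.TimeStampMicroVector;
    import org.apache.arrow.vector.TimeStampMilliVector;
    import org.apache.arrow.vector.TimeStampNanoVector;

    final class TimestampReadHelper {
        static long toEpochMillis(FieldVector v, int row) {
            if (v instanceof TimeStampMilliVector ms) {
                return ms.get(row);             // already milliseconds
            }
            if (v instanceof TimeStampMicroVector us) {
                return us.get(row) / 1_000;     // microseconds -> milliseconds
            }
            if (v instanceof TimeStampNanoVector ns) {
                return ns.get(row) / 1_000_000; // nanoseconds -> milliseconds
            }
            throw new IllegalArgumentException("Unsupported vector type: " + v.getClass().getName());
        }
    }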
+    /**
+     * Verify that the data values read from the database match what was inserted.
+     *
+     * @param root     VectorSchemaRoot containing the read data
+     * @param rowIndex Row index to verify (should be 0 for the single-row test)
+     */
+    private void verifyDataValues(VectorSchemaRoot root, int rowIndex) {
+        log.info("Verifying data values for row {}", rowIndex);
+
+        // Integer types
+        TinyIntVector int8Vector = (TinyIntVector) root.getVector("col_int8");
+        assertEquals((byte) 127, int8Vector.get(rowIndex), "col_int8 should be 127");
+
+        SmallIntVector int16Vector = (SmallIntVector) root.getVector("col_int16");
+        assertEquals((short) 32767, int16Vector.get(rowIndex), "col_int16 should be 32767");
+
+        IntVector int32Vector = (IntVector) root.getVector("col_int32");
+        assertEquals(2147483647, int32Vector.get(rowIndex), "col_int32 should be 2147483647");
+
+        BigIntVector int64Vector = (BigIntVector) root.getVector("col_int64");
+        assertEquals(9223372036854775807L, int64Vector.get(rowIndex), "col_int64 should be 9223372036854775807");
+
+        // Floating point types (use a delta for floating point comparison)
+        Float4Vector float32Vector = (Float4Vector) root.getVector("col_float32");
+        assertEquals(3.14f, float32Vector.get(rowIndex), 0.001f, "col_float32 should be approximately 3.14");
+
+        Float8Vector float64Vector = (Float8Vector) root.getVector("col_float64");
+        assertEquals(2.718281828459045, float64Vector.get(rowIndex), 0.000000000000001,
+                "col_float64 should be approximately 2.718281828459045");
+
+        // Date types
+        DateDayVector date32Vector = (DateDayVector) root.getVector("col_date32");
+        LocalDate expectedDate = LocalDate.of(2024, 1, 1);
+        LocalDate actualDate = LocalDate.ofEpochDay(date32Vector.get(rowIndex));
+        assertEquals(expectedDate, actualDate, "col_date32 should be 2024-01-01");
+
+        // col_date64: a database TIMESTAMP may be read back as TimeStampMicroVector instead of DateMilliVector
+        FieldVector date64FieldVector = root.getVector("col_date64");
+        long expectedDate64Ms = java.sql.Timestamp.valueOf("2024-01-01 12:00:00").getTime();
+        long actualDate64Ms = 0;
+
+        if (date64FieldVector instanceof DateMilliVector) {
+            DateMilliVector date64Vector = (DateMilliVector) date64FieldVector;
+            actualDate64Ms = date64Vector.get(rowIndex);
+        } else if (date64FieldVector instanceof TimeStampMicroVector) {
+            // A database TIMESTAMP is often mapped to TimeStampMicroVector when reading
+            TimeStampMicroVector timestampVector = (TimeStampMicroVector) date64FieldVector;
+            // Convert microseconds to milliseconds
+            actualDate64Ms = timestampVector.get(rowIndex) / 1000;
+        } else if (date64FieldVector instanceof TimeStampMilliVector) {
+            TimeStampMilliVector timestampVector = (TimeStampMilliVector) date64FieldVector;
+            actualDate64Ms = timestampVector.get(rowIndex);
+        } else {
+            // Fallback: read as an object and convert
+            Object date64Obj = date64FieldVector.getObject(rowIndex);
+            if (date64Obj instanceof java.time.LocalDateTime) {
+                actualDate64Ms = java.sql.Timestamp.valueOf((java.time.LocalDateTime) date64Obj).getTime();
+            } else {
+                throw new AssertionError("Unexpected type for col_date64: "
+                        + (date64Obj != null ? date64Obj.getClass().getName() : "null"));
+            }
+        }
+        // Allow 1 second tolerance for timestamp conversion
+        assertEquals(expectedDate64Ms, actualDate64Ms, 1000,
+                "col_date64 should be approximately 2024-01-01 12:00:00");
+
+        // Time types
+        TimeMilliVector time32Vector = (TimeMilliVector) root.getVector("col_time32");
+        LocalTime expectedTime32 = LocalTime.of(12, 30, 45);
+        LocalTime actualTime32 = LocalTime.ofSecondOfDay(time32Vector.get(rowIndex) / 1000);
+        assertEquals(expectedTime32.getHour(), actualTime32.getHour(), "col_time32 hour should be 12");
+        assertEquals(expectedTime32.getMinute(), actualTime32.getMinute(), "col_time32 minute should be 30");
+        assertEquals(expectedTime32.getSecond(), actualTime32.getSecond(), "col_time32 second should be 45");
+
+        // col_time64: a database TIME(6) may be read back as TimeMilliVector instead of TimeMicroVector
+        FieldVector time64FieldVector = root.getVector("col_time64");
+        LocalTime expectedTime64 = LocalTime.of(12, 30, 45, 123456000);
+        LocalTime actualTime64 = null;
+
+        if (time64FieldVector instanceof TimeMicroVector) {
+            // TIME(6) -> TimeMicroVector (microseconds)
+            TimeMicroVector time64Vector = (TimeMicroVector) time64FieldVector;
+            long time64Micros = time64Vector.get(rowIndex);
+            log.info("col_time64 read as TimeMicroVector: {} microseconds", time64Micros);
+            actualTime64 = LocalTime.ofNanoOfDay(time64Micros * 1000);
+            log.info("col_time64 converted to LocalTime: {}", actualTime64);
+        } else if (time64FieldVector instanceof TimeMilliVector) {
+            // TIME(6) should be read as TimeMicroVector, not TimeMilliVector.
+            // If it arrives as TimeMilliVector, precision was lost and the test should fail.
+            TimeMilliVector time64Vector = (TimeMilliVector) time64FieldVector;
+            int time64Millis = time64Vector.get(rowIndex);
+            fail("col_time64 was read as TimeMilliVector instead of TimeMicroVector. " +
+                    "This indicates that TIME(6) precision was not correctly parsed. " +
+                    "Expected microseconds precision but got milliseconds. " +
+                    "Value: " + time64Millis + " milliseconds");
+        } else {
+            throw new AssertionError("Unexpected type for col_time64: "
+                    + (time64FieldVector != null ? time64FieldVector.getClass().getName() : "null"));
+        }
+
+        assertEquals(expectedTime64.getHour(), actualTime64.getHour(), "col_time64 hour should be 12");
+        assertEquals(expectedTime64.getMinute(), actualTime64.getMinute(), "col_time64 minute should be 30");
+        assertEquals(expectedTime64.getSecond(), actualTime64.getSecond(), "col_time64 second should be 45");
+        /*
+         * Note: the java.sql.Time returned by the JDBC driver only carries second precision.
+         * Even though the database TIME(6) column can store microseconds, the value arrives as
+         * java.sql.Time and the sub-second part is lost, so the time is only verified up to seconds.
+         * The actual nanosecond field will be 0 due to this driver limitation.
+         */
+        log.info("col_time64 nanoseconds: {} (expected: {}). Note: JDBC java.sql.Time only supports second precision",
+                actualTime64.getNano(), expectedTime64.getNano());
+
+        // Timestamp types
+        // col_timestamp: a database TIMESTAMP may be read back as vectors of different precision
+        FieldVector timestampFieldVector = root.getVector("col_timestamp");
+        long expectedTimestampMs = java.sql.Timestamp.valueOf("2024-01-01 12:00:00").getTime();
+        long actualTimestampMs = 0;
+
+        if (timestampFieldVector instanceof TimeStampMicroVector) {
+            // TIMESTAMP(6), or TIMESTAMP without precision -> TimeStampMicroVector (microseconds)
+            TimeStampMicroVector timestampVector = (TimeStampMicroVector) timestampFieldVector;
+            // Convert microseconds to milliseconds
+            actualTimestampMs = timestampVector.get(rowIndex) / 1000;
+        } else if (timestampFieldVector instanceof TimeStampMilliVector) {
+            // TIMESTAMP(3) -> TimeStampMilliVector (milliseconds)
+            TimeStampMilliVector timestampVector = (TimeStampMilliVector) timestampFieldVector;
+            actualTimestampMs = timestampVector.get(rowIndex);
+        } else {
+            // Fallback: read as an object and convert
+            Object timestampObj = timestampFieldVector.getObject(rowIndex);
+            if (timestampObj instanceof java.time.LocalDateTime) {
+                actualTimestampMs = java.sql.Timestamp.valueOf((java.time.LocalDateTime) timestampObj).getTime();
+            } else {
+                throw new AssertionError("Unexpected type for col_timestamp: "
+                        + (timestampObj != null ? timestampObj.getClass().getName() : "null"));
+            }
+        }
+        // Timestamp values are stored as UTC milliseconds since epoch.
+        // Allow 1 second tolerance for potential rounding or precision differences.
+        assertEquals(expectedTimestampMs, actualTimestampMs, 1000,
+                "col_timestamp should be approximately 2024-01-01 12:00:00 (within 1 second tolerance)");
+
+        TimeStampMilliVector timestampMsVector = (TimeStampMilliVector) root.getVector("col_timestamp_ms");
+        long actualTimestampMsValue = timestampMsVector.get(rowIndex);
+        assertEquals(expectedTimestampMs, actualTimestampMsValue, 1000,
+                "col_timestamp_ms should be approximately 2024-01-01 12:00:00 (within 1 second tolerance)");
+
+        TimeStampMicroVector timestampUsVector = (TimeStampMicroVector) root.getVector("col_timestamp_us");
+        long actualTimestampUsMs = timestampUsVector.get(rowIndex) / 1000;
+        assertEquals(expectedTimestampMs, actualTimestampUsMs, 1000,
+                "col_timestamp_us should be approximately 2024-01-01 12:00:00 (within 1 second tolerance)");
+
+        // Boolean type
+        BitVector boolVector = (BitVector) root.getVector("col_bool");
+        assertEquals(1, boolVector.get(rowIndex), "col_bool should be 1 (true)");
+
+        // String types
+        VarCharVector stringVector = (VarCharVector) root.getVector("col_string");
+        String expectedString = "Test String";
+        String actualString = new String(stringVector.get(rowIndex));
+        assertEquals(expectedString, actualString, "col_string should be 'Test String'");
+
+        LargeVarCharVector largeStringVector = (LargeVarCharVector) root.getVector("col_large_string");
+        String expectedLargeString = "Large String Content " + "A".repeat(100);
+        String actualLargeString = new String(largeStringVector.get(rowIndex));
+        assertEquals(expectedLargeString, actualLargeString, "col_large_string should match expected value");
+
+        // Binary types
+        VarBinaryVector binaryVector = (VarBinaryVector) root.getVector("col_binary");
+        byte[] expectedBinary = new byte[]{0x01, 0x02, 0x03, 0x04};
+        byte[] actualBinary = binaryVector.get(rowIndex);
+        assertArrayEquals(expectedBinary, actualBinary, "col_binary should match expected bytes");
+
+        LargeVarBinaryVector largeBinaryVector = (LargeVarBinaryVector) root.getVector("col_large_binary");
+        byte[] expectedLargeBinary = new byte[]{0x05, 0x06, 0x07, 0x08, 0x09, 0x0A};
+        byte[] actualLargeBinary = largeBinaryVector.get(rowIndex);
+        assertArrayEquals(expectedLargeBinary, actualLargeBinary, "col_large_binary should match expected bytes");
+
+        // Decimal type
+        DecimalVector decimalVector = (DecimalVector) root.getVector("col_decimal");
+        BigDecimal expectedDecimal = new BigDecimal("1234567890123456789012345678.1234567890");
+        BigDecimal actualDecimal = decimalVector.getObject(rowIndex);
+        assertEquals(0, expectedDecimal.compareTo(actualDecimal),
+                "col_decimal should match expected value: " + expectedDecimal);
+
+        // Interval types
+        IntervalYearVector intervalYVector = (IntervalYearVector) root.getVector("col_interval_ym");
+        // 1 year = 12 months
+        int expectedIntervalYM = 12;
+        int actualIntervalYM = intervalYVector.get(rowIndex);
+        assertEquals(expectedIntervalYM, actualIntervalYM, "col_interval_ym should be 12 months");
+
+        IntervalDayVector intervalDVector = (IntervalDayVector) root.getVector("col_interval_dt");
+        // IntervalDay stores days and milliseconds;
+        // getObject() returns a PeriodDuration containing a Period (days) and a Duration (sub-day part)
+        Object intervalObj = intervalDVector.getObject(rowIndex);
+        assertNotNull(intervalObj, "col_interval_dt should not be null");
+
+        // Extract days and milliseconds from the PeriodDuration
+        int expectedDays = 1;
+        int expectedMillis = 0;
+        int actualDays = 0;
+        int actualMillis = 0;
+
+        if (intervalObj instanceof PeriodDuration) {
+            PeriodDuration pd = (PeriodDuration) intervalObj;
+            Period period = pd.getPeriod();
+            Duration duration = pd.getDuration();
+
+            actualDays = period.getDays();
+            // Duration holds seconds and nanoseconds; convert to milliseconds
+            long totalMillis = duration.toMillis();
+            actualMillis = (int) totalMillis;
+        } else if (intervalObj instanceof Duration) {
+            // Handle java.time.Duration directly: IntervalDayVector.getObject() may return a plain
+            // Duration, which happens when the value was read from a JDBC ResultSet as a Duration
+            Duration duration = (Duration) intervalObj;
+            long days = duration.toDays();
+            long remainingMillis = duration.minusDays(days).toMillis();
+            actualDays = (int) days;
+            actualMillis = (int) remainingMillis;
+        } else {
+            // Fallback: use reflection for PeriodDuration-like objects
+            log.warn("IntervalDay value is not PeriodDuration or Duration, got: {}",
+                    intervalObj != null ? intervalObj.getClass().getName() : "null");
+            try {
+                java.lang.reflect.Method getPeriodMethod = intervalObj.getClass().getMethod("getPeriod");
+                java.lang.reflect.Method getDurationMethod = intervalObj.getClass().getMethod("getDuration");
+
+                Object period = getPeriodMethod.invoke(intervalObj);
+                Object duration = getDurationMethod.invoke(intervalObj);
+
+                java.lang.reflect.Method getDaysMethod = period.getClass().getMethod("getDays");
+                actualDays = (Integer) getDaysMethod.invoke(period);
+
+                java.lang.reflect.Method toMillisMethod = duration.getClass().getMethod("toMillis");
+                long millis = (Long) toMillisMethod.invoke(duration);
+                actualMillis = (int) millis;
+            } catch (Exception e) {
+                log.error("Failed to extract days and milliseconds from interval object of type: {}, value: {}",
+                        intervalObj != null ? intervalObj.getClass().getName() : "null", intervalObj);
+                fail("Failed to extract days and milliseconds from interval object: " + e.getMessage() +
+                        ". Object type: " + (intervalObj != null ? intervalObj.getClass().getName() : "null"));
+            }
+        }
+
+        assertEquals(expectedDays, actualDays, "col_interval_dt days should be 1");
+        assertEquals(expectedMillis, actualMillis, "col_interval_dt milliseconds should be 0");
+
+        log.info("✅ All data values verified successfully");
+    }
+
+    @Test
+    @Order(3)
+    public void testCommandDataSourceSqlQuery() {
+        log.info("Testing SQL query functionality with comprehensive data types...");
+
+        // Build a SQL query covering all major types.
+        // Note: uint types (not supported by Dameng) and timestamp_ns (Dameng does not support
+        // nanosecond precision) are skipped.
+        String sql = String.format(
+                "SELECT " +
+                        "col_int8, col_int16, col_int32, col_int64, " +
+                        "col_float32, col_float64, " +
+                        "col_date32, col_date64, " +
+                        "col_time32, col_time64, " +
+                        "col_timestamp, col_timestamp_ms, col_timestamp_us, " +
+                        "col_bool, " +
+                        "col_string, col_large_string, " +
+                        "col_binary, col_large_binary, " +
+                        "col_decimal, " +
+                        "col_interval_ym, col_interval_dt " +
+                        "FROM %s", TABLE_NAME);
+
+        Flightinner.CommandDataMeshSqlQuery query = Flightinner.CommandDataMeshSqlQuery.newBuilder()
+                .setDatasource(domainDataSource)
+                .setQuery(Flightdm.CommandDataSourceSqlQuery.newBuilder()
+                        .setSql(sql)
+                        .build())
+                .build();
+
+        assertDoesNotThrow(() -> {
+            FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(query).toByteArray());
+            FlightInfo info = client.getInfo(descriptor, CallOptions.timeout(10, TimeUnit.SECONDS));
+
+            try (FlightStream stream = client.getStream(info.getEndpoints().get(0).getTicket())) {
+                int rowCount = 0;
+                VectorSchemaRoot root = null;
+
+                while (stream.next()) {
+                    root = stream.getRoot();
+                    rowCount += root.getRowCount();
+
+                    // Verify the returned column count (21 columns: all defined types,
+                    // i.e., everything except the uint, timestamp_ns, and timestamp_tz types)
+                    int expectedColumnCount = 21;
+                    assertEquals(expectedColumnCount, root.getSchema().getFields().size(),
+                            "SQL query should return all selected columns");
+
+                    // Verify all columns exist and their types are correct
+                    verifySqlQuerySchema(root);
+
+                    log.info("Successfully read {} rows with {} columns from SQL query",
+                            rowCount, root.getSchema().getFields().size());
+                }
+
+                assertEquals(1, rowCount, "SQL query should return 1 row");
+
+                // Verify the returned data values (consistent with the written data)
+                if (root != null && root.getRowCount() > 0) {
+                    verifySqlQueryDataValues(root, 0);
+                }
+
+                log.info("SQL query test passed - all data types verified");
+            }
+        });
+    }
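A side note on the "decimal" mapping used throughout: Decimal(38, 10, 128) fixes the vector's scale, so round-tripped values are best compared with compareTo rather than equals. A standalone sketch (the allocator and vector are local to the example):

    import java.math.BigDecimal;
    import org.apache.arrow.memory.RootAllocator;
    import org.apache.arrow.vector.DecimalVector;

    class DecimalSketch {
        public static void main(String[] args) {
            try (RootAllocator alloc = new RootAllocator();
                 DecimalVector v = new DecimalVector("d", alloc, 38, 10)) {
                v.allocateNew(1);
                // setSafe accepts a BigDecimal whose scale matches the vector's scale (10)
                v.setSafe(0, new BigDecimal("1234567890123456789012345678.1234567890"));
                v.setValueCount(1);
                BigDecimal out = v.getObject(0);
                // compareTo == 0 even if the string representations differ
                System.out.println(out.compareTo(new BigDecimal("1234567890123456789012345678.1234567890")) == 0);
            }
        }
    }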
+    /**
+     * Verify the schema (column types) returned by the SQL query.
+     *
+     * @param root VectorSchemaRoot
+     */
+    private void verifySqlQuerySchema(VectorSchemaRoot root) {
+        assertNotNull(root.getVector("col_int8"), "col_int8 should exist");
+        assertNotNull(root.getVector("col_int16"), "col_int16 should exist");
+        assertNotNull(root.getVector("col_int32"), "col_int32 should exist");
+        assertNotNull(root.getVector("col_int64"), "col_int64 should exist");
+
+        assertNotNull(root.getVector("col_float32"), "col_float32 should exist");
+        assertNotNull(root.getVector("col_float64"), "col_float64 should exist");
+
+        assertNotNull(root.getVector("col_date32"), "col_date32 should exist");
+        assertNotNull(root.getVector("col_date64"), "col_date64 should exist");
+
+        assertNotNull(root.getVector("col_time32"), "col_time32 should exist");
+        assertNotNull(root.getVector("col_time64"), "col_time64 should exist");
+
+        assertNotNull(root.getVector("col_timestamp"), "col_timestamp should exist");
+        assertNotNull(root.getVector("col_timestamp_ms"), "col_timestamp_ms should exist");
+        assertNotNull(root.getVector("col_timestamp_us"), "col_timestamp_us should exist");
+
+        assertNotNull(root.getVector("col_bool"), "col_bool should exist");
+
+        assertNotNull(root.getVector("col_string"), "col_string should exist");
+        assertNotNull(root.getVector("col_large_string"), "col_large_string should exist");
+
+        assertNotNull(root.getVector("col_binary"), "col_binary should exist");
+        assertNotNull(root.getVector("col_large_binary"), "col_large_binary should exist");
+
+        assertNotNull(root.getVector("col_decimal"), "col_decimal should exist");
+
+        assertNotNull(root.getVector("col_interval_ym"), "col_interval_ym should exist");
+        assertNotNull(root.getVector("col_interval_dt"), "col_interval_dt should exist");
+
+        log.debug("All SQL query columns verified successfully");
+    }
+
+    /**
+     * Verify the data values returned by the SQL query (consistent with the written data).
+     *
+     * @param root     VectorSchemaRoot
+     * @param rowIndex Row index
+     */
+    private void verifySqlQueryDataValues(VectorSchemaRoot root, int rowIndex) {
+        log.info("Verifying SQL query data values for row {}", rowIndex);
+
+        TinyIntVector int8Vector = (TinyIntVector) root.getVector("col_int8");
+        assertEquals(127, int8Vector.get(rowIndex), "col_int8 should be 127");
+
+        SmallIntVector int16Vector = (SmallIntVector) root.getVector("col_int16");
+        assertEquals(32767, int16Vector.get(rowIndex), "col_int16 should be 32767");
+
+        IntVector int32Vector = (IntVector) root.getVector("col_int32");
+        assertEquals(2147483647, int32Vector.get(rowIndex), "col_int32 should be 2147483647");
+
+        BigIntVector int64Vector = (BigIntVector) root.getVector("col_int64");
+        assertEquals(9223372036854775807L, int64Vector.get(rowIndex), "col_int64 should be 9223372036854775807");
+
+        Float4Vector float32Vector = (Float4Vector) root.getVector("col_float32");
+        assertEquals(3.14f, float32Vector.get(rowIndex), 0.01f, "col_float32 should be approximately 3.14");
+
+        Float8Vector float64Vector = (Float8Vector) root.getVector("col_float64");
+        assertEquals(2.718281828459045, float64Vector.get(rowIndex), 0.000000000000001,
+                "col_float64 should be approximately 2.718281828459045");
+
+        DateDayVector date32Vector = (DateDayVector) root.getVector("col_date32");
+        int expectedDate32Days = (int) LocalDate.of(2024, 1, 1).toEpochDay();
+        assertEquals(expectedDate32Days, date32Vector.get(rowIndex), "col_date32 should be 2024-01-01");
+
+        // col_date64: check the type dynamically (may be DateMilliVector, TimeStampMicroVector, or TimeStampMilliVector)
+        FieldVector date64FieldVector = root.getVector("col_date64");
+        long expectedDate64Ms = java.sql.Timestamp.valueOf("2024-01-01 12:00:00").getTime();
+        long actualDate64Ms = 0;
+        if (date64FieldVector instanceof DateMilliVector) {
+            actualDate64Ms = ((DateMilliVector) date64FieldVector).get(rowIndex);
+        } else if (date64FieldVector instanceof TimeStampMicroVector) {
+            // Convert microseconds to milliseconds
+            actualDate64Ms = ((TimeStampMicroVector) date64FieldVector).get(rowIndex) / 1000;
+        } else if (date64FieldVector instanceof TimeStampMilliVector) {
+            actualDate64Ms = ((TimeStampMilliVector) date64FieldVector).get(rowIndex);
+        }
+        assertEquals(expectedDate64Ms, actualDate64Ms, 1000, "col_date64 should be approximately 2024-01-01 12:00:00");
+
+        TimeMilliVector time32Vector = (TimeMilliVector) root.getVector("col_time32");
+        LocalTime expectedTime32 = LocalTime.of(12, 30, 45);
+        LocalTime actualTime32 = LocalTime.ofSecondOfDay(time32Vector.get(rowIndex) / 1000);
+        assertEquals(expectedTime32.getHour(), actualTime32.getHour(), "col_time32 hour should be 12");
+        assertEquals(expectedTime32.getMinute(), actualTime32.getMinute(), "col_time32 minute should be 30");
+        assertEquals(expectedTime32.getSecond(), actualTime32.getSecond(), "col_time32 second should be 45");
+
+        // col_time64: check the type dynamically (TIME(6) should be read as TimeMicroVector)
+        FieldVector time64FieldVector = root.getVector("col_time64");
+        LocalTime expectedTime64 = LocalTime.of(12, 30, 45, 123456000);
+        LocalTime actualTime64 = null;
+        if (time64FieldVector instanceof TimeMicroVector) {
+            TimeMicroVector time64Vector = (TimeMicroVector) time64FieldVector;
+            long time64Micros = time64Vector.get(rowIndex);
+            actualTime64 = LocalTime.ofNanoOfDay(time64Micros * 1000);
+        } else if (time64FieldVector instanceof TimeMilliVector) {
+            // TIME(6) should be read as TimeMicroVector; arriving as TimeMilliVector indicates precision loss
+            TimeMilliVector time64Vector = (TimeMilliVector) time64FieldVector;
+            int time64Millis = time64Vector.get(rowIndex);
+            fail("col_time64 was read as TimeMilliVector instead of TimeMicroVector. " +
+                    "This indicates that TIME(6) precision was not correctly parsed. " +
+                    "Expected microseconds precision but got milliseconds. " +
+                    "Value: " + time64Millis + " milliseconds");
+        } else {
+            fail("Unexpected type for col_time64: " +
+                    (time64FieldVector != null ? time64FieldVector.getClass().getName() : "null"));
+        }
+        assertNotNull(actualTime64, "col_time64 should not be null");
+        assertEquals(expectedTime64.getHour(), actualTime64.getHour(), "col_time64 hour should be 12");
+        assertEquals(expectedTime64.getMinute(), actualTime64.getMinute(), "col_time64 minute should be 30");
+        assertEquals(expectedTime64.getSecond(), actualTime64.getSecond(), "col_time64 second should be 45");
+        /*
+         * Note: the JDBC driver returns java.sql.Time, which only carries second precision.
+         * Even though the database TIME(6) column can store microseconds, the sub-second part is
+         * lost in java.sql.Time, so only second precision is verified here; the actual nanosecond
+         * field will be 0 due to this driver limitation.
+         */
+        log.info("col_time64 nanoseconds: {} (expected: {}). Note: JDBC java.sql.Time only supports second precision",
+                actualTime64.getNano(), expectedTime64.getNano());
+
+        // Timestamp type verification (timestamps are stored as UTC milliseconds; allow 1 second tolerance)
+        long expectedTimestampMs = java.sql.Timestamp.valueOf("2024-01-01 12:00:00").getTime();
+
+        FieldVector timestampFieldVector = root.getVector("col_timestamp");
+        long actualTimestampMs = 0;
+        if (timestampFieldVector instanceof TimeStampMicroVector) {
+            // Convert microseconds to milliseconds
+            actualTimestampMs = ((TimeStampMicroVector) timestampFieldVector).get(rowIndex) / 1000;
+        } else if (timestampFieldVector instanceof TimeStampMilliVector) {
+            actualTimestampMs = ((TimeStampMilliVector) timestampFieldVector).get(rowIndex);
+        }
+        assertEquals(expectedTimestampMs, actualTimestampMs, 1000,
+                "col_timestamp should be approximately 2024-01-01 12:00:00 (within 1 second tolerance)");
+
+        TimeStampMilliVector timestampMsVector = (TimeStampMilliVector) root.getVector("col_timestamp_ms");
+        long actualTimestampMsValue = timestampMsVector.get(rowIndex);
+        assertEquals(expectedTimestampMs, actualTimestampMsValue, 1000,
+                "col_timestamp_ms should be approximately 2024-01-01 12:00:00 (within 1 second tolerance)");
+
+        TimeStampMicroVector timestampUsVector = (TimeStampMicroVector) root.getVector("col_timestamp_us");
+        long actualTimestampUsMs = timestampUsVector.get(rowIndex) / 1000;
+        assertEquals(expectedTimestampMs, actualTimestampUsMs, 1000,
+                "col_timestamp_us should be approximately 2024-01-01 12:00:00 (within 1 second tolerance)");
+
+        BitVector boolVector = (BitVector) root.getVector("col_bool");
+        assertEquals(1, boolVector.get(rowIndex), "col_bool should be true (1)");
+
+        VarCharVector stringVector = (VarCharVector) root.getVector("col_string");
+        String actualString = new String(stringVector.get(rowIndex));
+        assertEquals("Test String", actualString, "col_string should be 'Test String'");
+
+        LargeVarCharVector largeStringVector = (LargeVarCharVector) root.getVector("col_large_string");
+        String expectedLargeString = "Large String Content " + "A".repeat(100);
+        String actualLargeString = new String(largeStringVector.get(rowIndex));
+        assertEquals(expectedLargeString, actualLargeString, "col_large_string should match expected value");
+
+        VarBinaryVector binaryVector = (VarBinaryVector) root.getVector("col_binary");
+        byte[] expectedBinary = new byte[]{0x01, 0x02, 0x03, 0x04};
+        byte[] actualBinary = binaryVector.get(rowIndex);
+        assertArrayEquals(expectedBinary, actualBinary, "col_binary should match expected bytes");
+
+        LargeVarBinaryVector largeBinaryVector = (LargeVarBinaryVector) root.getVector("col_large_binary");
+        byte[] expectedLargeBinary = new byte[]{0x05, 0x06, 0x07, 0x08, 0x09, 0x0A};
+        byte[] actualLargeBinary = largeBinaryVector.get(rowIndex);
+        assertArrayEquals(expectedLargeBinary, actualLargeBinary, "col_large_binary should match expected bytes");
+
+        DecimalVector decimalVector = (DecimalVector) root.getVector("col_decimal");
+        BigDecimal expectedDecimal = new BigDecimal("1234567890123456789012345678.1234567890");
+        BigDecimal actualDecimal = decimalVector.getObject(rowIndex);
+        assertEquals(0, expectedDecimal.compareTo(actualDecimal),
+                "col_decimal should match expected value: " + expectedDecimal);
+
+        IntervalYearVector intervalYVector = (IntervalYearVector) root.getVector("col_interval_ym");
+        // 1 year = 12 months
+        int expectedIntervalYM = 12;
+        assertEquals(expectedIntervalYM, intervalYVector.get(rowIndex), "col_interval_ym should be 12 months");
+
+        IntervalDayVector intervalDVector = (IntervalDayVector) root.getVector("col_interval_dt");
+        Object intervalObj = intervalDVector.getObject(rowIndex);
+        assertNotNull(intervalObj, "col_interval_dt should not be null");
+
+        // Verify the specific values of the interval type
+        int expectedDays = 1;
+        int expectedMillis = 0;
+        int actualDays = 0;
+        int actualMillis = 0;
+
+        if (intervalObj instanceof PeriodDuration) {
+            PeriodDuration pd = (PeriodDuration) intervalObj;
+            Period period = pd.getPeriod();
+            Duration duration = pd.getDuration();
+
+            actualDays = period.getDays();
+            long totalMillis = duration.toMillis();
+            actualMillis = (int) totalMillis;
+        } else if (intervalObj instanceof Duration) {
+            // Handle java.time.Duration directly (JDBC may return Duration instead of PeriodDuration)
+            Duration duration = (Duration) intervalObj;
+            long days = duration.toDays();
+            long remainingMillis = duration.minusDays(days).toMillis();
+            actualDays = (int) days;
+            actualMillis = (int) remainingMillis;
+        } else {
+            // Fallback: use reflection for PeriodDuration-like objects
+            try {
+                java.lang.reflect.Method getPeriodMethod = intervalObj.getClass().getMethod("getPeriod");
+                java.lang.reflect.Method getDurationMethod = intervalObj.getClass().getMethod("getDuration");
+
+                Object period = getPeriodMethod.invoke(intervalObj);
+                Object duration = getDurationMethod.invoke(intervalObj);
+
+                java.lang.reflect.Method getDaysMethod = period.getClass().getMethod("getDays");
+                actualDays = (Integer) getDaysMethod.invoke(period);
+
+                java.lang.reflect.Method toMillisMethod = duration.getClass().getMethod("toMillis");
+                long millis = (Long) toMillisMethod.invoke(duration);
+                actualMillis = (int) millis;
+            } catch (Exception e) {
+                log.error("Failed to extract days and milliseconds from interval object of type: {}, value: {}",
+                        intervalObj != null ? intervalObj.getClass().getName() : "null", intervalObj);
+                fail("Failed to extract days and milliseconds from interval object: " + e.getMessage() +
+                        ". Object type: " + (intervalObj != null ? intervalObj.getClass().getName() : "null"));
+            }
+        }
+
+        assertEquals(expectedDays, actualDays, "col_interval_dt days should be 1");
+        assertEquals(expectedMillis, actualMillis, "col_interval_dt milliseconds should be 0");
+
+        log.info("✅ All SQL query data values verified successfully");
+    }
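The partition-overwrite behavior exercised by the next test is also observable through plain JDBC. Conceptually, an update with a partitionSpec must replace only the rows whose partition column matches the spec. The sketch below is a hypothetical illustration of the outcome the test asserts, not the plugin's actual implementation; url, user, and password are placeholders:

    import java.sql.Connection;
    import java.sql.DriverManager;
    import java.sql.Statement;

    class PartitionOverwriteSketch {
        // Table and column names mirror the test below.
        static void overwritePartition(String url, String user, String password) throws Exception {
            try (Connection conn = DriverManager.getConnection(url, user, password);
                 Statement stmt = conn.createStatement()) {
                // Replace only the targeted "partition" (rows matching the partition column value)
                stmt.executeUpdate("DELETE FROM PARTITION_TEST_TABLE WHERE dt = '20240101'");
                stmt.executeUpdate("INSERT INTO PARTITION_TEST_TABLE (id, name, dt) " +
                        "VALUES (10, 'Overwritten Batch', '20240101')");
            }
        }
    }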
+ FlightClient.ClientStreamListener listener = client.startPut( + FlightDescriptor.command(ticket.getBytes()), root, new AsyncPutListener()); + listener.putNext(); + listener.completed(); + listener.getResult(); + + log.info("Wrote 2 rows to partition {}", PARTITION_VALUE_1); + } + }); + + // Verify partition 1 has 2 rows + assertDoesNotThrow(() -> { + try (Connection conn = DriverManager.getConnection(damengJdbcUrl, damengUser, damengPassword); + Statement stmt = conn.createStatement(); + java.sql.ResultSet rs = stmt.executeQuery( + String.format("SELECT COUNT(*) as cnt FROM %s WHERE %s = '%s'", + PARTITION_TABLE_NAME, PARTITION_COLUMN, PARTITION_VALUE_1))) { + assertTrue(rs.next(), "Should have result"); + assertEquals(2, rs.getInt("cnt"), "Partition " + PARTITION_VALUE_1 + " should have 2 rows"); + log.info("Verified partition {} has 2 rows", PARTITION_VALUE_1); + } + }); + + // Test 2: Overwrite partition 1 with new data + assertDoesNotThrow(() -> { + Flightinner.CommandDataMeshUpdate command2 = Flightinner.CommandDataMeshUpdate.newBuilder() + .setDatasource(domainDataSource) + .setDomaindata(partitionDomainData) + .setUpdate(Flightdm.CommandDomainDataUpdate.newBuilder() + .setContentType(Flightdm.ContentType.CSV) + .setPartitionSpec(String.format("%s=%s", PARTITION_COLUMN, PARTITION_VALUE_1))) + .build(); + + FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(command2).toByteArray()); + FlightInfo info = client.getInfo(descriptor, CallOptions.timeout(10, TimeUnit.SECONDS)); + Ticket ticket = info.getEndpoints().get(0).getTicket(); + + Schema schema = new Schema(partitionColumns.stream() + .map(col -> Field.nullable(col.getName(), ArrowUtil.parseKusciaColumnType(col.getType()))) + .collect(Collectors.toList())); + + try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + root.allocateNew(); + + IntVector idVector = (IntVector) root.getVector("id"); + VarCharVector nameVector = (VarCharVector) root.getVector("name"); + VarCharVector dtVector = (VarCharVector) root.getVector(PARTITION_COLUMN); + + // Write new data to same partition (should overwrite) + idVector.setSafe(0, 10); + nameVector.setSafe(0, "Overwritten Batch".getBytes()); + dtVector.setSafe(0, PARTITION_VALUE_1.getBytes()); + idVector.setSafe(1, 11); + nameVector.setSafe(1, "Overwritten Batch 2".getBytes()); + dtVector.setSafe(1, PARTITION_VALUE_1.getBytes()); + idVector.setSafe(2, 12); + nameVector.setSafe(2, "Overwritten Batch 3".getBytes()); + dtVector.setSafe(2, PARTITION_VALUE_1.getBytes()); + + root.setRowCount(3); + + FlightClient.ClientStreamListener listener = client.startPut( + FlightDescriptor.command(ticket.getBytes()), root, new AsyncPutListener()); + listener.putNext(); + listener.completed(); + listener.getResult(); + + log.info("Overwrote partition {} with 3 new rows", PARTITION_VALUE_1); + } + }); + + // Verify partition 1 now has only 3 rows (overwritten) + assertDoesNotThrow(() -> { + try (Connection conn = DriverManager.getConnection(damengJdbcUrl, damengUser, damengPassword); + Statement stmt = conn.createStatement(); + java.sql.ResultSet rs = stmt.executeQuery( + String.format("SELECT COUNT(*) as cnt FROM %s WHERE %s = '%s'", + PARTITION_TABLE_NAME, PARTITION_COLUMN, PARTITION_VALUE_1))) { + assertTrue(rs.next(), "Should have result"); + assertEquals(3, rs.getInt("cnt"), + "Partition " + PARTITION_VALUE_1 + " should have 3 rows after overwrite"); + + // Verify old data is gone + java.sql.ResultSet rs2 = stmt.executeQuery( + String.format("SELECT COUNT(*) as cnt FROM %s 
WHERE %s = '%s' AND name = 'First Batch'", + PARTITION_TABLE_NAME, PARTITION_COLUMN, PARTITION_VALUE_1)); + assertTrue(rs2.next(), "Should have result"); + assertEquals(0, rs2.getInt("cnt"), + "Old data should be overwritten"); + + log.info("Verified partition {} was overwritten correctly", PARTITION_VALUE_1); + } + }); + + // Test 3: Write data to partition 2 (different partition should not be affected) + assertDoesNotThrow(() -> { + Flightinner.CommandDataMeshUpdate command3 = Flightinner.CommandDataMeshUpdate.newBuilder() + .setDatasource(domainDataSource) + .setDomaindata(partitionDomainData) + .setUpdate(Flightdm.CommandDomainDataUpdate.newBuilder() + .setContentType(Flightdm.ContentType.CSV) + .setPartitionSpec(String.format("%s=%s", PARTITION_COLUMN, PARTITION_VALUE_2))) + .build(); + + FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(command3).toByteArray()); + FlightInfo info = client.getInfo(descriptor, CallOptions.timeout(10, TimeUnit.SECONDS)); + Ticket ticket = info.getEndpoints().get(0).getTicket(); + + Schema schema = new Schema(partitionColumns.stream() + .map(col -> Field.nullable(col.getName(), ArrowUtil.parseKusciaColumnType(col.getType()))) + .collect(Collectors.toList())); + + try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + root.allocateNew(); + + IntVector idVector = (IntVector) root.getVector("id"); + VarCharVector nameVector = (VarCharVector) root.getVector("name"); + VarCharVector dtVector = (VarCharVector) root.getVector(PARTITION_COLUMN); + + idVector.setSafe(0, 20); + nameVector.setSafe(0, "Partition 2 Data".getBytes()); + dtVector.setSafe(0, PARTITION_VALUE_2.getBytes()); + + root.setRowCount(1); + + FlightClient.ClientStreamListener listener = client.startPut( + FlightDescriptor.command(ticket.getBytes()), root, new AsyncPutListener()); + listener.putNext(); + listener.completed(); + listener.getResult(); + + log.info("Wrote 1 row to partition {}", PARTITION_VALUE_2); + } + }); + + // Verify both partitions exist independently + assertDoesNotThrow(() -> { + try (Connection conn = DriverManager.getConnection(damengJdbcUrl, damengUser, damengPassword); + Statement stmt = conn.createStatement()) { + + // Verify partition 1 still has 3 rows + java.sql.ResultSet rs1 = stmt.executeQuery( + String.format("SELECT COUNT(*) as cnt FROM %s WHERE %s = '%s'", + PARTITION_TABLE_NAME, PARTITION_COLUMN, PARTITION_VALUE_1)); + assertTrue(rs1.next(), "Should have result"); + assertEquals(3, rs1.getInt("cnt"), + "Partition " + PARTITION_VALUE_1 + " should still have 3 rows"); + + // Verify partition 2 has 1 row + java.sql.ResultSet rs2 = stmt.executeQuery( + String.format("SELECT COUNT(*) as cnt FROM %s WHERE %s = '%s'", + PARTITION_TABLE_NAME, PARTITION_COLUMN, PARTITION_VALUE_2)); + assertTrue(rs2.next(), "Should have result"); + assertEquals(1, rs2.getInt("cnt"), + "Partition " + PARTITION_VALUE_2 + " should have 1 row"); + + log.info("Verified both partitions exist independently"); + } + }); + + // Cleanup + assertDoesNotThrow(() -> { + try (Connection conn = DriverManager.getConnection(damengJdbcUrl, damengUser, damengPassword); + Statement stmt = conn.createStatement()) { + stmt.execute(String.format("DROP TABLE IF EXISTS %s", PARTITION_TABLE_NAME)); + log.info("Cleaned up partition test table"); + } + }); + } +} diff --git a/dataproxy-integration-tests/src/test/java/org/secretflow/dataproxy/integration/tests/utils/DamengTestUtil.java 
b/dataproxy-integration-tests/src/test/java/org/secretflow/dataproxy/integration/tests/utils/DamengTestUtil.java new file mode 100644 index 0000000..7b65882 --- /dev/null +++ b/dataproxy-integration-tests/src/test/java/org/secretflow/dataproxy/integration/tests/utils/DamengTestUtil.java @@ -0,0 +1,71 @@ +/* + * Copyright 2025 Ant Group Co., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.secretflow.dataproxy.integration.tests.utils; + +import java.io.InputStream; +import java.util.Properties; + +/** + * Dameng database test utility class. + */ +public class DamengTestUtil { + + private static final Properties properties = new Properties(); + + static { + try (InputStream is = DamengTestUtil.class.getResourceAsStream("/test-dameng.conf")) { + if (is != null) { + properties.load(is); + } + } catch (Exception e) { + // If config file doesn't exist, use default values or environment variables + System.err.println("Warning: Could not load test-dameng.conf, using environment variables or defaults"); + } + } + + public static String getDamengEndpoint() { + String endpoint = properties.getProperty("test.dameng.endpoint"); + if (endpoint == null || endpoint.isEmpty()) { + endpoint = System.getenv("TEST_DAMENG_ENDPOINT"); + } + return endpoint != null ? endpoint : "localhost:5236"; + } + + public static String getDamengDatabase() { + String database = properties.getProperty("test.dameng.database"); + if (database == null || database.isEmpty()) { + database = System.getenv("TEST_DAMENG_DATABASE"); + } + return database != null ? database : "SYSDBA"; + } + + public static String getDamengUsername() { + String username = properties.getProperty("test.dameng.username"); + if (username == null || username.isEmpty()) { + username = System.getenv("TEST_DAMENG_USERNAME"); + } + return username != null ? username : "SYSDBA"; + } + + public static String getDamengPassword() { + String password = properties.getProperty("test.dameng.password"); + if (password == null || password.isEmpty()) { + password = System.getenv("TEST_DAMENG_PASSWORD"); + } + return password != null ? 
password : "SYSDBA";
+    }
+}
+
diff --git a/dataproxy-integration-tests/src/test/resources/test-dameng.conf b/dataproxy-integration-tests/src/test/resources/test-dameng.conf new file mode 100644 index 0000000..459de9a --- /dev/null +++ b/dataproxy-integration-tests/src/test/resources/test-dameng.conf @@ -0,0 +1,6 @@
+test.dameng.endpoint=
+test.dameng.database=
+test.dameng.username=
+test.dameng.password=
+
+
diff --git a/dataproxy-plugins/dataproxy-plugin-dameng/pom.xml b/dataproxy-plugins/dataproxy-plugin-dameng/pom.xml new file mode 100644 index 0000000..cccd4d0 --- /dev/null +++ b/dataproxy-plugins/dataproxy-plugin-dameng/pom.xml @@ -0,0 +1,77 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+
+    <parent>
+        <groupId>org.secretflow</groupId>
+        <artifactId>dataproxy-plugins</artifactId>
+        <version>0.0.1-SNAPSHOT</version>
+        <relativePath>../pom.xml</relativePath>
+    </parent>
+
+    <artifactId>dataproxy-plugin-dameng</artifactId>
+    <packaging>jar</packaging>
+
+    <name>dataproxy-plugin-dameng</name>
+    <url>https://maven.apache.org</url>
+
+    <dependencies>
+        <dependency>
+            <groupId>org.secretflow</groupId>
+            <artifactId>dataproxy-plugin-database</artifactId>
+        </dependency>
+
+        <dependency>
+            <groupId>com.dameng</groupId>
+            <artifactId>DmJdbcDriver18</artifactId>
+            <version>8.1.3.62</version>
+        </dependency>
+
+        <dependency>
+            <groupId>org.apache.arrow</groupId>
+            <artifactId>arrow-format</artifactId>
+        </dependency>
+
+        <dependency>
+            <groupId>org.junit.jupiter</groupId>
+            <artifactId>junit-jupiter-api</artifactId>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.junit.jupiter</groupId>
+            <artifactId>junit-jupiter-engine</artifactId>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.mockito</groupId>
+            <artifactId>mockito-inline</artifactId>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.mockito</groupId>
+            <artifactId>mockito-junit-jupiter</artifactId>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>uk.org.webcompere</groupId>
+            <artifactId>system-stubs-jupiter</artifactId>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.projectlombok</groupId>
+            <artifactId>lombok</artifactId>
+            <scope>compile</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.arrow</groupId>
+            <artifactId>arrow-memory-netty</artifactId>
+            <scope>test</scope>
+        </dependency>
+    </dependencies>
+</project>
\ No newline at end of file
diff --git a/dataproxy-plugins/dataproxy-plugin-dameng/src/main/java/org/secretflow/dataproxy/producer/DamengFlightProducer.java b/dataproxy-plugins/dataproxy-plugin-dameng/src/main/java/org/secretflow/dataproxy/producer/DamengFlightProducer.java new file mode 100644 index 0000000..78615d1 --- /dev/null +++ b/dataproxy-plugins/dataproxy-plugin-dameng/src/main/java/org/secretflow/dataproxy/producer/DamengFlightProducer.java @@ -0,0 +1,94 @@
+/*
+ * Copyright 2025 Ant Group Co., Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Dameng database Flight Producer implementation.
+ *
+ * <p>This producer extends AbstractDatabaseFlightProducer and provides
+ * Dameng-specific implementations for connection initialization, SQL building,
+ * and type conversion.</p>
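+ *
+ * <p>The producer is registered through the Java SPI file
+ * META-INF/services/org.secretflow.dataproxy.core.spi.producer.DataProxyFlightProducer
+ * under the producer name "dameng".</p>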
+ * + */ +package org.secretflow.dataproxy.producer; +import org.secretflow.dataproxy.plugin.database.config.DatabaseCommandConfig; +import org.secretflow.dataproxy.plugin.database.config.DatabaseWriteConfig; +import org.secretflow.dataproxy.plugin.database.producer.AbstractDatabaseFlightProducer; +import org.secretflow.dataproxy.plugin.database.reader.DatabaseDoGetContext; +import org.secretflow.dataproxy.plugin.database.writer.DatabaseRecordWriter; +import org.secretflow.dataproxy.util.DamengUtil; + +/** + * @author: kongxiaoran + * @date: 2025/11/5 + */ +public class DamengFlightProducer extends AbstractDatabaseFlightProducer { + + /** + * Returns the producer name used for SPI registration. + * + * @return "dameng" + */ + @Override + public String getProducerName() { + return "dameng"; + } + + /** + * Initializes the database read context with Dameng-specific implementations. + * + * @param config Command configuration + * @return DatabaseDoGetContext for reading data + */ + @Override + protected DatabaseDoGetContext initDoGetContext(DatabaseCommandConfig config) { + /* + * Initialize database read context with Dameng-specific implementations: + * 1. Initialize JDBC connection + * 2. Build SELECT SQL + * 3. Convert JDBC types to Arrow types + */ + return new DatabaseDoGetContext( + config, + DamengUtil::initDameng, + DamengUtil::buildQuerySql, + DamengUtil::jdbcType2ArrowType + ); + } + + /** + * Initializes the database write context with Dameng-specific implementations. + * + * @param config Write configuration + * @return DatabaseRecordWriter for writing data + */ + @Override + protected DatabaseRecordWriter initRecordWriter(DatabaseWriteConfig config) { + /* + * Initialize database write context with Dameng-specific implementations: + * 1. Initialize JDBC connection + * 2. Build CREATE TABLE SQL + * 3. Build batch INSERT SQL + * 4. Check if table exists + */ + return new DatabaseRecordWriter( + config, + DamengUtil::initDameng, + DamengUtil::buildCreateTableSql, + DamengUtil::buildMultiRowInsertSql, + DamengUtil::checkTableExists + ); + } +} diff --git a/dataproxy-plugins/dataproxy-plugin-dameng/src/main/java/org/secretflow/dataproxy/util/DamengUtil.java b/dataproxy-plugins/dataproxy-plugin-dameng/src/main/java/org/secretflow/dataproxy/util/DamengUtil.java new file mode 100644 index 0000000..0a09eb6 --- /dev/null +++ b/dataproxy-plugins/dataproxy-plugin-dameng/src/main/java/org/secretflow/dataproxy/util/DamengUtil.java @@ -0,0 +1,769 @@ +/* + * Copyright 2025 Ant Group Co., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.secretflow.dataproxy.util; + +import dm.jdbc.util.StringUtil; +import lombok.extern.slf4j.Slf4j; +import org.apache.arrow.vector.PeriodDuration; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.IntervalUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.secretflow.dataproxy.common.exceptions.DataproxyErrorCode; +import org.secretflow.dataproxy.common.exceptions.DataproxyException; +import org.secretflow.dataproxy.plugin.database.config.DatabaseConnectConfig; +import org.secretflow.dataproxy.plugin.database.writer.DatabaseRecordWriter; + +import java.sql.*; +import java.time.*; +import java.time.format.DateTimeFormatter; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import static org.secretflow.dataproxy.plugin.database.writer.DatabaseRecordWriter.parsePartition; + +/** + * Utility class for Dameng database operations. + */ +@Slf4j +public class DamengUtil { + + private static final Pattern IDENTIFIER_PATTERN = Pattern.compile("^[a-zA-Z][a-zA-Z0-9_]*$"); + private static final Pattern PRECISION_PATTERN = Pattern.compile("\\((\\d+)\\)"); + + /** + * Initialize Dameng database connection. + * + * @param config Database connection configuration + * @return JDBC connection object + * @throws RuntimeException if connection fails + */ + public static Connection initDameng(DatabaseConnectConfig config) { + String endpoint = config.endpoint(); + String ip; + // Default Dameng database port + int port = 5236; + + /* + * Parse endpoint address (host:port) + * Note: IPv6 format (e.g., [::1]:5236) is not supported, only IPv4 or hostname + */ + if (endpoint.contains(":")) { + // Limit split to avoid IPv6 address issues + String[] parts = endpoint.split(":", 2); + ip = parts[0]; + if (parts.length > 1 && !parts[1].isEmpty()) { + try { + port = Integer.parseInt(parts[1]); + if (port < 1 || port > 65535) { + throw DataproxyException.of(DataproxyErrorCode.PARAMS_UNRELIABLE, + "Invalid port number: " + port + ". 
Port must be between 1 and 65535."); + } + } catch (NumberFormatException e) { + throw DataproxyException.of(DataproxyErrorCode.PARAMS_UNRELIABLE, + "Invalid port format in endpoint: " + endpoint); + } + } + } else { + ip = endpoint; + } + + // Validate IP address or hostname to prevent JDBC URL injection + if (ip == null || !ip.matches("^[a-zA-Z0-9._-]+$") || ip.contains("..") || + ip.startsWith(".") || ip.endsWith(".")) { + throw DataproxyException.of(DataproxyErrorCode.PARAMS_UNRELIABLE, + "Invalid IP address or hostname: " + ip); + } + + // Validate database name to prevent JDBC URL injection + String database = config.database(); + if (database == null || !database.matches("^[a-zA-Z0-9_]+$")) { + throw DataproxyException.of(DataproxyErrorCode.PARAMS_UNRELIABLE, + "Invalid database name: " + database); + } + + Connection conn; + try { + Class.forName("dm.jdbc.driver.DmDriver"); + String url = String.format("jdbc:dm://%s:%d/%s", ip, port, database); + log.info("Connecting to Dameng database: {}", url.replaceAll("(password:)[^@]*", "$1****")); + + conn = DriverManager.getConnection(url, config.username(), config.password()); + log.info("Successfully connected to Dameng database"); + } catch (ClassNotFoundException e) { + log.error("Dameng JDBC driver not found", e); + throw new RuntimeException("Dameng JDBC driver not found. Please ensure DmJdbcDriver18 is in classpath.", e); + } catch (SQLException e) { + log.error("Failed to connect to Dameng database", e); + throw new RuntimeException("Failed to connect to Dameng database: " + e.getMessage(), e); + } catch (Exception e) { + log.error("Unexpected error connecting to Dameng database", e); + throw new RuntimeException(e); + } + return conn; + } + + + /** + * Build SELECT query SQL statement based on table name, columns, and partition specification. 
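+     *
+     * <p>For example (hypothetical table and columns),
+     * {@code buildQuerySql("t_user", List.of("id", "name"), "dt=20240101")} returns
+     * {@code SELECT id, name FROM t_user WHERE dt='20240101'}.</p>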
+     *
+     * @param tableName       Table name
+     * @param columns         Column name list
+     * @param partitionClause Partition specification (e.g., "dt=20240101")
+     * @return SELECT SQL statement
+     */
+    public static String buildQuerySql(String tableName, List<String> columns, String partitionClause) {
+        if (columns == null || columns.isEmpty()) {
+            throw DataproxyException.of(DataproxyErrorCode.PARAMS_UNRELIABLE, "columns cannot be empty");
+        }
+
+        if (!IDENTIFIER_PATTERN.matcher(tableName).matches()) {
+            throw DataproxyException.of(DataproxyErrorCode.PARAMS_UNRELIABLE, "Invalid tableName:" + tableName);
+        }
+
+        for (String field : columns) {
+            if (!IDENTIFIER_PATTERN.matcher(field).matches()) {
+                throw DataproxyException.of(DataproxyErrorCode.PARAMS_UNRELIABLE, "Invalid field name:" + field);
+            }
+        }
+
+        String sql = "SELECT " + String.join(", ", columns) + " FROM " + tableName;
+
+        if (partitionClause != null && !partitionClause.trim().isEmpty()) {
+            final Map<String, String> partitionSpec = parsePartition(partitionClause);
+            List<String> conditions = new ArrayList<>();
+
+            for (Map.Entry<String, String> entry : partitionSpec.entrySet()) {
+                String key = entry.getKey();
+                String value = entry.getValue();
+
+                if (!IDENTIFIER_PATTERN.matcher(key).matches()) {
+                    throw DataproxyException.of(DataproxyErrorCode.PARAMS_UNRELIABLE, "Invalid partition key:" + key);
+                }
+
+                if (!value.matches("^[a-zA-Z0-9_.-]+$")) {
+                    throw DataproxyException.of(DataproxyErrorCode.PARAMS_UNRELIABLE, "Invalid partition value:" + value);
+                }
+
+                conditions.add(key + "='" + escapeString(value) + "'");
+            }
+            sql += " WHERE " + String.join(" AND ", conditions);
+        }
+
+        log.info("Built query SQL: {}", sql);
+        return sql;
+    }
+
+    /**
+     * Build CREATE TABLE SQL statement.
+     *
+     * @param tableName Table name
+     * @param schema    Arrow Schema
+     * @param partition Partition specification map (only used to validate partition key names; in Dameng, partition columns must already be part of the column definitions)
+     * @return CREATE TABLE SQL statement
+     */
+    public static String buildCreateTableSql(String tableName, Schema schema, Map<String, String> partition) {
+        if (!IDENTIFIER_PATTERN.matcher(tableName).matches()) {
+            throw DataproxyException.of(DataproxyErrorCode.PARAMS_UNRELIABLE, "Invalid tableName:" + tableName);
+        }
+
+        List<Field> fields = schema.getFields();
+
+        if (fields.isEmpty()) {
+            throw DataproxyException.of(
+                    DataproxyErrorCode.PARAMS_UNRELIABLE,
+                    "Table must have at least one column. Empty schema is not allowed in Dameng database."
+ ); + } + + for (Field field : fields) { + String fieldName = field.getName(); + if (!IDENTIFIER_PATTERN.matcher(fieldName).matches()) { + throw DataproxyException.of(DataproxyErrorCode.PARAMS_UNRELIABLE, "Invalid field name:" + fieldName); + } + } + + if (partition != null) { + for (String partKey : partition.keySet()) { + if (!IDENTIFIER_PATTERN.matcher(partKey).matches()) { + throw DataproxyException.of(DataproxyErrorCode.PARAMS_UNRELIABLE, "Invalid partition key:" + partKey); + } + } + } + + // Note: In Dameng database, partition columns must be included in column definitions, unlike Hive where partition columns are defined separately in PARTITIONED BY clause + StringBuilder sb = new StringBuilder(); + sb.append("CREATE TABLE ").append(tableName).append(" (\n"); + + String columnDefinitions = fields.stream() + .map(field -> " " + field.getName() + " " + arrowTypeToJdbcType(field.getType())) + .collect(Collectors.joining(",\n")); + + sb.append(columnDefinitions).append("\n)"); + + log.info("Built CREATE TABLE SQL: {}", sb); + return sb.toString(); + } + + + /** + * Convert Arrow type to Dameng JDBC type. + * + * @param arrowType Arrow type + * @return JDBC type name + */ + public static String arrowTypeToJdbcType(ArrowType arrowType) { + if (arrowType instanceof ArrowType.Utf8) { + return "VARCHAR"; + } else if (arrowType instanceof ArrowType.LargeUtf8) { + return "CLOB"; + } + else if (arrowType instanceof ArrowType.Int intType) { + int bitWidth = intType.getBitWidth(); + return switch (bitWidth) { + case 8 -> "TINYINT"; + case 16 -> "SMALLINT"; + case 32 -> "INT"; + case 64 -> "BIGINT"; + default -> throw new IllegalArgumentException("Unsupported Int bitWidth: " + bitWidth); + }; + } + else if (arrowType instanceof ArrowType.FloatingPoint fp) { + return switch (fp.getPrecision()) { + case SINGLE -> "FLOAT"; + case DOUBLE -> "DOUBLE"; + default -> throw new IllegalArgumentException("Unsupported floating point type"); + }; + } + else if (arrowType instanceof ArrowType.Bool) { + // Dameng uses BIT type for boolean: 1=true, 0=false, NULL=null + return "BIT"; + } + else if (arrowType instanceof ArrowType.Date dateType) { + return switch (dateType.getUnit()) { + case DAY -> "DATE"; + // Date(MILLISECOND) maps to DATETIME(3), distinct from Timestamp + case MILLISECOND -> "DATETIME(3)"; + }; + } else if (arrowType instanceof ArrowType.Time timeType) { + /* + * Arrow Time type bitWidth must be 32 or 64 (standard), not precision + * For Dameng database, determine precision based on TimeUnit instead of bitWidth + */ + int precision = switch (timeType.getUnit()) { + // TIME(0) - seconds precision + case SECOND -> 0; + // TIME(3) - milliseconds precision + case MILLISECOND -> 3; + // TIME(6) - microseconds precision + case MICROSECOND -> 6; + case NANOSECOND -> { + log.warn("Dameng database does not support nanosecond precision for TIME type. 
Using microseconds (6) instead."); + // Fallback to microseconds + yield 6; + } + }; + return "TIME(" + precision + ")"; + } else if (arrowType instanceof ArrowType.Timestamp timestampType) { + /* + * Only precision is specified when creating table; timezone info is not specified at table creation + * It needs to be concatenated with TIMESTAMP when inserting data, format: 2002.12.12 09:10:21 -5:00 + */ + String damengType = "TIMESTAMP"; + if (StringUtil.isNotEmpty(timestampType.getTimezone())) { + damengType = "TIMESTAMP WITH TIME ZONE"; + } + return damengType + switch (timestampType.getUnit()) { + case SECOND -> "(0)"; + case MILLISECOND -> "(3)"; + case MICROSECOND -> "(6)"; + case NANOSECOND -> throw new IllegalArgumentException( + "Dameng currently does not support nanosecond level accuracy"); + }; + } + else if (arrowType instanceof ArrowType.Decimal dec) { + int precision = dec.getPrecision(); + int scale = dec.getScale(); + + // Validate Dameng database precision limits: precision range 1-38, scale range 0-precision + if (precision < 1 || precision > 38) { + throw new IllegalArgumentException( + String.format("DECIMAL precision %d out of range [1, 38] for Dameng database", precision)); + } + if (scale < 0 || scale > precision) { + throw new IllegalArgumentException( + String.format("DECIMAL scale %d out of range [0, %d]", scale, precision)); + } + return "DECIMAL(" + precision + ", " + scale + ")"; + } + else if (arrowType instanceof ArrowType.Binary) { + // Dameng database VARBINARY default length is 8188 bytes + return "VARBINARY"; + } else if (arrowType instanceof ArrowType.FixedSizeBinary fixedBinary) { + int byteWidth = fixedBinary.getByteWidth(); + if (byteWidth <= 8188) { + // Dameng database BINARY maximum length is 8188 bytes + return "BINARY(" + byteWidth + ")"; + } else { + log.warn("FixedSizeBinary byteWidth {} exceeds BINARY/VARBINARY limit (8188), using BLOB", byteWidth); + return "BLOB"; + } + } else if (arrowType instanceof ArrowType.LargeBinary) { + // Large binary object (maximum length 100G-1 bytes) + return "BLOB"; + } + else if (arrowType instanceof ArrowType.Interval intervalType) { + return switch (intervalType.getUnit()) { + case YEAR_MONTH -> "INTERVAL YEAR TO MONTH"; + case DAY_TIME -> "INTERVAL DAY TO SECOND"; + default -> throw new IllegalArgumentException("Unsupported Interval unit: " + intervalType.getUnit()); + }; + } + else if (arrowType instanceof ArrowType.Null) { + // NULL type is usually not used alone in actual table creation, return VARCHAR as placeholder + return "VARCHAR(1)"; + } + else { + throw new IllegalArgumentException("Unsupported Arrow type: " + arrowType.getClass().getName()); + } + } + + /** + * Extract precision information from JDBC type name. + * Example: TIMESTAMP(3) -> 3, DATETIME(3) -> 3 + * + * @param jdbcType JDBC type name (may contain precision, e.g., "TIMESTAMP(3)") + * @return Precision value, or null if no precision information + */ + private static Integer parsePrecisionFromTypeName(String jdbcType) { + if (jdbcType == null || jdbcType.isEmpty()) { + return null; + } + java.util.regex.Matcher matcher = PRECISION_PATTERN.matcher(jdbcType); + if (matcher.find()) { + try { + return Integer.parseInt(matcher.group(1)); + } catch (NumberFormatException e) { + log.warn("Failed to parse precision from JDBC type: {}", jdbcType); + return null; + } + } + return null; + } + + /** + * Extract base type name from JDBC type name (remove precision information). 
+ * Example: TIMESTAMP(3) -> TIMESTAMP, DATETIME(3) -> DATETIME + * + * @param jdbcType JDBC type name + * @return Base type name + */ + private static String extractBaseType(String jdbcType) { + if (jdbcType == null || jdbcType.isEmpty()) { + return jdbcType; + } + return jdbcType.replaceAll("\\(\\d+\\)", "").trim().toUpperCase(); + } + + /** + * Convert Dameng JDBC type to Arrow type. + * Note: This method attempts to parse precision information from TYPE_NAME (e.g., TIMESTAMP(3)) + * + * @param jdbcType JDBC type name (may contain precision, e.g., "TIMESTAMP(3)" or "DATETIME(3)") + * @return Arrow type + */ + public static ArrowType jdbcType2ArrowType(String jdbcType) { + if (jdbcType == null || jdbcType.isEmpty()) { + throw new IllegalArgumentException("JDBC type is null or empty"); + } + + String baseType = extractBaseType(jdbcType); + Integer precision = parsePrecisionFromTypeName(jdbcType); + + return switch (baseType) { + case "DECIMAL", "NUMERIC", "DEC", "NUMBER" -> { + // Use default precision if precision info is not provided (backward compatibility) + yield new ArrowType.Decimal(38, 10, 128); + } + + case "TIMESTAMP" -> { + org.apache.arrow.vector.types.TimeUnit timeUnit; + if (precision != null && precision <= 3) { + // TIMESTAMP(0-3): millisecond precision + timeUnit = org.apache.arrow.vector.types.TimeUnit.MILLISECOND; + } else if (precision != null && precision <= 6) { + // TIMESTAMP(4-6): microsecond precision + timeUnit = org.apache.arrow.vector.types.TimeUnit.MICROSECOND; + } else { + // Default to microsecond precision if precision is unknown (for compatibility) + timeUnit = org.apache.arrow.vector.types.TimeUnit.MICROSECOND; + } + yield new ArrowType.Timestamp(timeUnit, null); + } + + /* + * DATETIME(3) -> Date(MILLISECOND) + * DATETIME(6) -> Timestamp(MICROSECOND) (Date doesn't support microseconds, use Timestamp) + */ + case "DATETIME" -> { + if (precision != null && precision == 3) { + // DATETIME(3): millisecond precision -> Date(MILLISECOND) + yield new ArrowType.Date(org.apache.arrow.vector.types.DateUnit.MILLISECOND); + } else if (precision != null && precision == 6) { + // DATETIME(6): microsecond precision -> Timestamp(MICROSECOND) (Date doesn't support microseconds) + log.warn("DATETIME(6) precision not supported by Date type, mapping to Timestamp(MICROSECOND)"); + yield new ArrowType.Timestamp(org.apache.arrow.vector.types.TimeUnit.MICROSECOND, null); + } else { + // Default to millisecond precision if precision is unknown + yield new ArrowType.Date(org.apache.arrow.vector.types.DateUnit.MILLISECOND); + } + } + + /* + * Arrow standard: bitWidth must be 32 or 64, not precision + * Use Time32 (32-bit) for millisecond precision, Time64 (64-bit) for microsecond precision + * TIME(3) -> Time(MILLISECOND, 32) -> TimeMilliVector + * TIME(6) -> Time(MICROSECOND, 64) -> TimeMicroVector + */ + case "TIME" -> { + org.apache.arrow.vector.types.TimeUnit timeUnit; + int bitWidth; + if (precision != null && precision <= 3) { + // TIME(0-3): millisecond precision -> Time32 + timeUnit = org.apache.arrow.vector.types.TimeUnit.MILLISECOND; + bitWidth = 32; + } else if (precision != null && precision <= 6) { + // TIME(4-6): microsecond precision -> Time64 + timeUnit = org.apache.arrow.vector.types.TimeUnit.MICROSECOND; + bitWidth = 64; + } else { + // Default to millisecond precision if precision is unknown + timeUnit = org.apache.arrow.vector.types.TimeUnit.MILLISECOND; + bitWidth = 32; + } + yield new ArrowType.Time(timeUnit, bitWidth); + } + + case "TIMESTAMP_WITH_TIMEZONE", 
"TIMESTAMP WITH TIME ZONE", + "TIMESTAMP WITH LOCAL TIME ZONE", "DATETIME_TZ", "DATETIME WITH TIME ZONE" -> { + yield new ArrowType.Timestamp(org.apache.arrow.vector.types.TimeUnit.MICROSECOND, "UTC"); + } + + case "BINARY" -> { + yield new ArrowType.Binary(); + } + + case "CHAR", "CHARACTER", "VARCHAR", "VARCHAR2" -> ArrowType.Utf8.INSTANCE; + // Large string types - TEXT/LONG/LONGVARCHAR/CLOB all map to LargeUtf8 + case "TEXT", "LONG", "LONGVARCHAR", "CLOB" -> ArrowType.LargeUtf8.INSTANCE; + + // BYTE is an alias for TINYINT + case "TINYINT", "BYTE" -> new ArrowType.Int(8, true); + case "SMALLINT" -> new ArrowType.Int(16, true); + case "INT", "INTEGER", "PLS_INTEGER" -> new ArrowType.Int(32, true); + case "BIGINT" -> new ArrowType.Int(64, true); + + case "FLOAT", "REAL" -> new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE); + case "DOUBLE", "DOUBLE PRECISION" -> new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE); + + // Dameng uses BIT type for boolean: 1=true, 0=false, NULL=null + case "BOOLEAN", "BIT" -> ArrowType.Bool.INSTANCE; + + case "DATE" -> new ArrowType.Date(DateUnit.DAY); + + // RAW is an alias for VARBINARY + case "VARBINARY", "RAW" -> ArrowType.Binary.INSTANCE; + + // Large binary types - IMAGE/LONGVARBINARY/BLOB all map to LargeBinary + case "IMAGE", "LONGVARBINARY", "BLOB" -> ArrowType.LargeBinary.INSTANCE; + + // Year-month interval types: all year-month related interval types map to YEAR_MONTH + case "INTERVAL_YM", "INTERVAL YEAR TO MONTH", "INTERVAL YEAR", "INTERVAL MONTH" -> + new ArrowType.Interval(IntervalUnit.YEAR_MONTH); + + // Day-time interval types: all day-time related interval types map to DAY_TIME + case "INTERVAL_DT", "INTERVAL DAY TO TIME", "INTERVAL DAY TO SECOND", + "INTERVAL DAY", "INTERVAL DAY TO HOUR", "INTERVAL DAY TO MINUTE", + "INTERVAL HOUR", "INTERVAL HOUR TO MINUTE", "INTERVAL HOUR TO SECOND", + "INTERVAL MINUTE", "INTERVAL MINUTE TO SECOND", "INTERVAL SECOND" -> + new ArrowType.Interval(IntervalUnit.DAY_TIME); + + case "NULL" -> ArrowType.Null.INSTANCE; + + default -> { + log.warn("Unsupported JDBC type: {}, using Utf8 as fallback", jdbcType); + yield ArrowType.Utf8.INSTANCE; + } + }; + } + + + + public DamengUtil() { + } + + /** + * Build multi-row INSERT SQL statement (parameterized). 
+     *
+     * <p>For example, a two-row batch over columns (id, name, dt) produces
+     * {@code INSERT INTO t (id, name, dt) VALUES (?, ?, ?), (?, ?, ?)}; INTERVAL
+     * columns are rendered as SQL literals in place of {@code ?} placeholders.</p>
+     *
+     * @param tableName Table name
+     * @param schema    Arrow Schema
+     * @param dataList  Data row list
+     * @param partition Partition specification map
+     * @return SqlWithParams object (contains SQL and parameter list)
+     */
+    public static DatabaseRecordWriter.SqlWithParams buildMultiRowInsertSql(String tableName,
+                                                                            Schema schema,
+                                                                            List<Map<String, Object>> dataList,
+                                                                            Map<String, String> partition) {
+        if (dataList == null || dataList.isEmpty()) {
+            throw new IllegalArgumentException("No data to insert");
+        }
+
+        if (!IDENTIFIER_PATTERN.matcher(tableName).matches()) {
+            throw DataproxyException.of(DataproxyErrorCode.PARAMS_UNRELIABLE, "Invalid tableName:" + tableName);
+        }
+
+        for (Field f : schema.getFields()) {
+            if (!IDENTIFIER_PATTERN.matcher(f.getName()).matches()) {
+                throw DataproxyException.of(DataproxyErrorCode.PARAMS_UNRELIABLE, "Invalid field name:" + f.getName());
+            }
+        }
+
+        for (String k : partition.keySet()) {
+            if (!IDENTIFIER_PATTERN.matcher(k).matches()) {
+                throw DataproxyException.of(DataproxyErrorCode.PARAMS_UNRELIABLE, "Invalid partition key:" + k);
+            }
+        }
+
+        List<Field> fields = schema.getFields();
+        List<String> allColumns = fields.stream()
+                .map(Field::getName)
+                .collect(Collectors.toList());
+        Map<String, Field> fieldMap = fields.stream()
+                .collect(Collectors.toMap(Field::getName, f -> f));
+
+        // Build SQL statement (INTERVAL types use literal embedding)
+        StringBuilder sb = new StringBuilder();
+        sb.append("INSERT INTO ").append(tableName);
+        sb.append(" (").append(String.join(", ", allColumns)).append(")");
+        sb.append(" VALUES ");
+
+        // Build VALUES clause for each row (INTERVAL types directly embed literals)
+        List<Object> params = new ArrayList<>();
+        List<String> valueClauses = new ArrayList<>();
+        Set<String> partitionKeys = partition.keySet();
+        // Datetime format with timezone offset supported by Dameng
+        DateTimeFormatter dtf = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSS XXX");
+
+        for (Map<String, Object> row : dataList) {
+            List<String> rowValues = new ArrayList<>();
+            for (String colName : allColumns) {
+                Object value;
+                if (partitionKeys.contains(colName)) {
+                    value = partition.get(colName);
+                } else {
+                    // Use the lowercase column name to get the value, because DatabaseRecordWriter.write() uses lowercase keys
+                    value = row.get(colName.toLowerCase());
+                }
+
+                Field field = fieldMap.get(colName);
+                ArrowType fieldType = field.getType();
+
+                if (fieldType instanceof ArrowType.Timestamp tsType) {
+                    // Handle timestamp with timezone
+                    if (tsType.getTimezone() != null && !tsType.getTimezone().isEmpty() && value instanceof Long) {
+                        Instant instant = Instant.ofEpochMilli((Long) value);
+                        ZoneId zoneId = ZoneId.of(tsType.getTimezone());
+                        ZonedDateTime zonedDateTime = ZonedDateTime.ofInstant(instant, zoneId);
+                        // Convert to a string format supported by Dameng
+                        value = dtf.format(zonedDateTime);
+                    }
+                } else if (fieldType instanceof ArrowType.Bool) {
+                    if (value instanceof Boolean) {
+                        // Dameng BIT columns take 1/0 rather than true/false
+                        value = ((Boolean) value) ? 1 : 0;
+                    }
+                } else if (fieldType instanceof ArrowType.Interval intervalType) {
+                    // INTERVAL type: use a literal directly in SQL to avoid JDBC driver parsing errors
+                    String intervalLiteral;
+                    switch (intervalType.getUnit()) {
+                        case YEAR_MONTH:
+                            Integer totalMonths = null;
+                            if (value instanceof Integer) {
+                                totalMonths = (Integer) value;
+                            } else if (value instanceof Long) {
+                                totalMonths = ((Long) value).intValue();
+                            } else if (value instanceof String) {
+                                try {
+                                    totalMonths = Integer.parseInt((String) value);
+                                } catch (NumberFormatException e) {
+                                    log.warn("Failed to parse YEAR_MONTH string value: {}", value);
+                                }
+                            } else if (value instanceof Period period) {
+                                // Period may not be normalized, so reduce it to total months first
+                                totalMonths = period.getYears() * 12 + period.getMonths();
+                            }
+
+                            if (totalMonths != null) {
+                                // Normalize on the absolute value so that e.g. -13 months becomes
+                                // INTERVAL '-1-1' YEAR TO MONTH (Java's / and % truncate toward zero)
+                                int absMonths = Math.abs(totalMonths);
+                                intervalLiteral = String.format("INTERVAL '%s%d-%d' YEAR TO MONTH",
+                                        totalMonths < 0 ? "-" : "", absMonths / 12, absMonths % 12);
+                            } else {
+                                log.warn("Unexpected type for YEAR_MONTH interval, expected Integer/Long/String or Period but got {}. Setting to null.",
+                                        value != null ? value.getClass().getName() : "null");
+                                intervalLiteral = "NULL";
+                            }
+                            break;
+                        case DAY_TIME:
+                            Duration duration = null;
+                            if (value instanceof PeriodDuration pd) {
+                                duration = pd.getDuration();
+                            } else if (value instanceof Duration) {
+                                duration = (Duration) value;
+                            } else {
+                                log.warn("Unexpected type for DAY_TIME interval, expected PeriodDuration or Duration but got {}. Setting to null.",
+                                        value != null ? value.getClass().getName() : "null");
+                                intervalLiteral = "NULL";
+                                break;
+                            }
+
+                            // Format on the absolute value and prepend the sign, so negative
+                            // durations do not produce negative %02d fields
+                            boolean negative = duration.isNegative();
+                            Duration abs = duration.abs();
+                            long totalSeconds = abs.getSeconds();
+                            long days = totalSeconds / (24 * 60 * 60);
+                            long hours = (totalSeconds % (24 * 60 * 60)) / (60 * 60);
+                            long minutes = (totalSeconds % (60 * 60)) / 60;
+                            long seconds = totalSeconds % 60;
+                            long millis = abs.getNano() / 1_000_000;
+
+                            intervalLiteral = String.format("INTERVAL '%s%d %02d:%02d:%02d.%03d' DAY TO SECOND(3)",
+                                    negative ? "-" : "", days, hours, minutes, seconds, millis);
+                            break;
+                        case MONTH_DAY_NANO:
+                            // Dameng does not support the MONTH_DAY_NANO interval type
+                            log.warn("Dameng database does not support MONTH_DAY_NANO interval type. Setting value to null.");
+                            intervalLiteral = "NULL";
+                            break;
+                        default:
+                            log.warn("Unsupported interval unit: {}. 
Setting value to null.", intervalType.getUnit()); + intervalLiteral = "NULL"; + break; + } + // Use literal directly, do not add to parameter list + rowValues.add(intervalLiteral); + // Skip adding to params + continue; + } + + // General type conversion: handle Java types not supported by Dameng JDBC driver + if (value != null && (value instanceof Period || value instanceof Duration || value instanceof PeriodDuration)) { + log.warn("Unhandled temporal type: {}, setting to null", value.getClass().getName()); + value = null; + } + + // Non-INTERVAL types: use parameter placeholder + rowValues.add("?"); + params.add(value); + } + valueClauses.add("(" + String.join(", ", rowValues) + ")"); + } + + sb.append(String.join(", ", valueClauses)); + + return new DatabaseRecordWriter.SqlWithParams(sb.toString(), params); + } + + /** + * Check if table exists in the database. + * + * @param connection Database connection + * @param tableName Table name to check + * @return true if table exists, false otherwise + */ + public static boolean checkTableExists(Connection connection, String tableName) { + // Validate table name (prevent SQL injection) + if (!IDENTIFIER_PATTERN.matcher(tableName).matches()) { + throw DataproxyException.of(DataproxyErrorCode.PARAMS_UNRELIABLE, "Invalid tableName:" + tableName); + } + + /* + * Dameng SQL to check table existence - use user table view + * Note: Dameng database USER_TABLES.TABLE_NAME is uppercase, need to use UPPER() for case-insensitive comparison + * Table name has been validated by identifier pattern, safe to concatenate SQL + */ + String upperTableName = tableName.toUpperCase(); + String sql = "SELECT COUNT(*) FROM USER_TABLES WHERE UPPER(TABLE_NAME) = '" + upperTableName + "'"; + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery(sql)) { + return rs.next() && rs.getInt(1) > 0; + } catch (SQLException e) { + log.error("Error checking table existence: {}", e.getMessage(), e); + throw new RuntimeException("Failed to check table existence: " + e.getMessage(), e); + } + } + + + /** + * Escape SQL string to prevent SQL injection. + * + * @param str String to escape + * @return Escaped string, or empty string if input is null + */ + private static String escapeString(String str) { + if (str == null) { + return ""; + } + return str.replace("'", "''"); + } + +} diff --git a/dataproxy-plugins/dataproxy-plugin-dameng/src/main/resources/META-INF/services/org.secretflow.dataproxy.core.spi.producer.DataProxyFlightProducer b/dataproxy-plugins/dataproxy-plugin-dameng/src/main/resources/META-INF/services/org.secretflow.dataproxy.core.spi.producer.DataProxyFlightProducer new file mode 100644 index 0000000..7da6102 --- /dev/null +++ b/dataproxy-plugins/dataproxy-plugin-dameng/src/main/resources/META-INF/services/org.secretflow.dataproxy.core.spi.producer.DataProxyFlightProducer @@ -0,0 +1 @@ +org.secretflow.dataproxy.producer.DamengFlightProducer \ No newline at end of file diff --git a/dataproxy-plugins/dataproxy-plugin-dameng/src/test/java/org/secretflow/dataproxy/producer/DamengFlightProducerTest.java b/dataproxy-plugins/dataproxy-plugin-dameng/src/test/java/org/secretflow/dataproxy/producer/DamengFlightProducerTest.java new file mode 100644 index 0000000..9771f36 --- /dev/null +++ b/dataproxy-plugins/dataproxy-plugin-dameng/src/test/java/org/secretflow/dataproxy/producer/DamengFlightProducerTest.java @@ -0,0 +1,228 @@ +/* + * Copyright 2025 Ant Group Co., Ltd. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.secretflow.dataproxy.producer; + +import com.google.protobuf.Any; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightProducer; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.secretflow.v1alpha1.kusciaapi.Domaindata; +import org.secretflow.v1alpha1.kusciaapi.Domaindatasource; +import org.secretflow.v1alpha1.kusciaapi.Flightdm; +import org.secretflow.v1alpha1.kusciaapi.Flightinner; + +import java.nio.charset.StandardCharsets; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.when; + + +/** + * @author kongxiaoran + * @date 2025/11/07 + */ +@ExtendWith(MockitoExtension.class) +public class DamengFlightProducerTest { + + @InjectMocks + private DamengFlightProducer damengFlightProducer; + + @Mock + private FlightProducer.CallContext context; + + @Mock + private FlightDescriptor descriptor; + + private final Domaindatasource.DatabaseDataSourceInfo damengDataSourceInfo = + Domaindatasource.DatabaseDataSourceInfo + .newBuilder() + .setEndpoint("jdbc:dm://localhost:5236") + .setDatabase("database") + .setUser("user") + .setPassword("password") + .build(); + private final Domaindatasource.DataSourceInfo dataSourceInfo = + Domaindatasource.DataSourceInfo.newBuilder().setDatabase(damengDataSourceInfo).build(); + + private final Domaindatasource.DomainDataSource domainDataSource = + Domaindatasource.DomainDataSource.newBuilder() + .setDatasourceId("datasourceId") + .setName("datasourceName") + .setType("dameng") + .setInfo(dataSourceInfo) + .build(); + + private final Domaindata.DomainData domainData = + Domaindata.DomainData.newBuilder() + .setDatasourceId("datasourceId") + .setName("domainDataName") + .setRelativeUri("table_name") + .setDomaindataId("domainDataId") + .setType("table") + .build(); + + private final Flightdm.CommandDomainDataQuery commandDomainDataQueryWithCSV = + Flightdm.CommandDomainDataQuery.newBuilder() + .setContentType(Flightdm.ContentType.CSV) + .build(); + + @Test + public void testGetProducerName() { + String producerName = damengFlightProducer.getProducerName(); + assertEquals("dameng", producerName); + } + + @Test + public void testGetFlightInfoWithTableCommand() { + Flightinner.CommandDataMeshQuery dataMeshQuery = + Flightinner.CommandDataMeshQuery.newBuilder() + .setQuery(commandDomainDataQueryWithCSV) + .setDatasource(domainDataSource) + .setDomaindata(domainData) + .build(); + when(descriptor.getCommand()).thenReturn(Any.pack(dataMeshQuery).toByteArray()); + + assertDoesNotThrow(() -> { + FlightInfo flightInfo = damengFlightProducer.getFlightInfo(context, descriptor); + + assertNotNull(flightInfo); + assertFalse(flightInfo.getEndpoints().isEmpty()); + + assertNotNull(flightInfo.getEndpoints().get(0).getLocations()); + 
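+            // A usable FlightInfo carries at least one endpoint location and a non-empty
+            // ticket; these are what a client needs to actually fetch the stream.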
assertFalse(flightInfo.getEndpoints().get(0).getLocations().isEmpty()); + + assertNotNull(flightInfo.getEndpoints().get(0).getTicket()); + assertNotNull(flightInfo.getEndpoints().get(0).getTicket().getBytes()); + + }); + } + + @Test + public void testGetFlightInfoWithUnsupportedType() { + when(descriptor.getCommand()).thenReturn("testCommand".getBytes(StandardCharsets.UTF_8)); + assertThrows(RuntimeException.class, () -> damengFlightProducer.getFlightInfo(context, descriptor)); + } + + @Test + public void testGetFlightInfoWithSqlQuery() { + Flightdm.CommandDataSourceSqlQuery commandDataSourceSqlQuery = + Flightdm.CommandDataSourceSqlQuery.newBuilder() + .setSql("SELECT * FROM test_table") + .setDatasourceId("datasourceId") + .build(); + + Flightinner.CommandDataMeshSqlQuery sqlQuery = + Flightinner.CommandDataMeshSqlQuery.newBuilder() + .setQuery(commandDataSourceSqlQuery) + .setDatasource(domainDataSource) + .build(); + + when(descriptor.getCommand()).thenReturn(Any.pack(sqlQuery).toByteArray()); + + assertDoesNotThrow(() -> { + FlightInfo flightInfo = damengFlightProducer.getFlightInfo(context, descriptor); + assertNotNull(flightInfo); + assertFalse(flightInfo.getEndpoints().isEmpty()); + }); + } + + @Test + public void testGetFlightInfoWithUpdateCommand() { + Flightdm.CommandDomainDataUpdate commandDomainDataUpdate = + Flightdm.CommandDomainDataUpdate.newBuilder() + .setContentType(Flightdm.ContentType.CSV) + .build(); + + Flightinner.CommandDataMeshUpdate updateCommand = + Flightinner.CommandDataMeshUpdate.newBuilder() + .setUpdate(commandDomainDataUpdate) + .setDatasource(domainDataSource) + .setDomaindata(domainData) + .build(); + + when(descriptor.getCommand()).thenReturn(Any.pack(updateCommand).toByteArray()); + + assertDoesNotThrow(() -> { + FlightInfo flightInfo = damengFlightProducer.getFlightInfo(context, descriptor); + assertNotNull(flightInfo); + assertFalse(flightInfo.getEndpoints().isEmpty()); + }); + } + + @Test + public void testGetFlightInfoWithRawContentType() { + Flightdm.CommandDomainDataQuery commandDomainDataQueryWithRaw = + Flightdm.CommandDomainDataQuery.newBuilder() + .setContentType(Flightdm.ContentType.RAW) + .build(); + + Flightinner.CommandDataMeshQuery dataMeshQuery = + Flightinner.CommandDataMeshQuery.newBuilder() + .setQuery(commandDomainDataQueryWithRaw) + .setDatasource(domainDataSource) + .setDomaindata(domainData) + .build(); + + when(descriptor.getCommand()).thenReturn(Any.pack(dataMeshQuery).toByteArray()); + + assertDoesNotThrow(() -> { + FlightInfo flightInfo = damengFlightProducer.getFlightInfo(context, descriptor); + assertNotNull(flightInfo); + assertFalse(flightInfo.getEndpoints().isEmpty()); + }); + } + + @Test + public void testGetFlightInfoWithPartitionSpec() { + Flightdm.CommandDomainDataQuery commandDomainDataQuery = + Flightdm.CommandDomainDataQuery.newBuilder() + .setContentType(Flightdm.ContentType.CSV) + .setPartitionSpec("dt=20240101") + .build(); + + Flightinner.CommandDataMeshQuery dataMeshQuery = + Flightinner.CommandDataMeshQuery.newBuilder() + .setQuery(commandDomainDataQuery) + .setDatasource(domainDataSource) + .setDomaindata(domainData) + .build(); + + when(descriptor.getCommand()).thenReturn(Any.pack(dataMeshQuery).toByteArray()); + + assertDoesNotThrow(() -> { + FlightInfo flightInfo = damengFlightProducer.getFlightInfo(context, descriptor); + assertNotNull(flightInfo); + assertFalse(flightInfo.getEndpoints().isEmpty()); + }); + } + + @Test + public void testGetFlightInfoWithInvalidCommandType() { + 
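+        // A bare CommandDomainDataQuery is packed without its CommandDataMeshQuery wrapper,
+        // so the producer is expected to reject the command.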
when(descriptor.getCommand()).thenReturn( + Any.pack(Flightdm.CommandDomainDataQuery.newBuilder().build()) + .toByteArray()); + + assertThrows(RuntimeException.class, () -> { + damengFlightProducer.getFlightInfo(context, descriptor); + }); + } +} diff --git a/dataproxy-plugins/dataproxy-plugin-dameng/src/test/java/org/secretflow/dataproxy/util/DamengUtilTest.java b/dataproxy-plugins/dataproxy-plugin-dameng/src/test/java/org/secretflow/dataproxy/util/DamengUtilTest.java new file mode 100644 index 0000000..507da29 --- /dev/null +++ b/dataproxy-plugins/dataproxy-plugin-dameng/src/test/java/org/secretflow/dataproxy/util/DamengUtilTest.java @@ -0,0 +1,806 @@ +/* + * Copyright 2025 Ant Group Co., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.secretflow.dataproxy.util; + +import org.apache.arrow.vector.PeriodDuration; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.IntervalUnit; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.secretflow.dataproxy.common.exceptions.DataproxyException; +import org.secretflow.dataproxy.plugin.database.config.DatabaseConnectConfig; +import org.secretflow.dataproxy.plugin.database.writer.DatabaseRecordWriter; + +import java.sql.*; +import java.time.Duration; +import java.time.Period; +import java.util.*; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.*; + +/** + * Unit test class for DamengUtil. 
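+ *
+ * <p>Covers connection parameter validation, SQL building, identifier validation,
+ * and Arrow/JDBC type mapping in both directions.</p>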
+ * + * @author kongxiaoran + * @date 2025/11/07 + */ +@ExtendWith(MockitoExtension.class) +public class DamengUtilTest { + + @Mock + private Connection mockConnection; + + @Mock + private Statement mockStatement; + + @Mock + private ResultSet mockResultSet; + + @Mock + private DatabaseMetaData mockDatabaseMetaData; + + private DatabaseConnectConfig validConfig; + + @BeforeEach + public void setUp() { + validConfig = new DatabaseConnectConfig( + "testuser", + "testpass", + "localhost:5236", + "TESTDB" + ); + } + + @Test + public void testInitDameng_ValidConfig() { + // Note: This test requires real JDBC driver, may not run in unit tests + // Actual tests should be in integration tests + // Here mainly tests parameter validation logic + assertThrows(RuntimeException.class, () -> { + DamengUtil.initDameng(validConfig); + }); + } + + @Test + public void testInitDameng_InvalidEndpoint_Empty() { + DatabaseConnectConfig config = new DatabaseConnectConfig( + "user", "pass", "", "DB" + ); + assertThrows(DataproxyException.class, () -> { + DamengUtil.initDameng(config); + }); + } + + @Test + public void testInitDameng_InvalidEndpoint_InvalidIP() { + DatabaseConnectConfig config = new DatabaseConnectConfig( + "user", "pass", "invalid..ip:5236", "DB" + ); + assertThrows(DataproxyException.class, () -> { + DamengUtil.initDameng(config); + }); + } + + @Test + public void testInitDameng_InvalidPort_OutOfRange() { + DatabaseConnectConfig config = new DatabaseConnectConfig( + "user", "pass", "localhost:70000", "DB" + ); + assertThrows(DataproxyException.class, () -> { + DamengUtil.initDameng(config); + }); + } + + @Test + public void testInitDameng_InvalidDatabaseName() { + DatabaseConnectConfig config = new DatabaseConnectConfig( + "user", "pass", "localhost:5236", "invalid-db-name!" 
+ ); + assertThrows(DataproxyException.class, () -> { + DamengUtil.initDameng(config); + }); + } + + @Test + public void testInitDameng_EndpointWithoutPort() { + DatabaseConnectConfig config = new DatabaseConnectConfig( + "user", "pass", "localhost", "DB" + ); + // Should use default port 5236 + assertThrows(RuntimeException.class, () -> { + DamengUtil.initDameng(config); + }); + } + + + @Test + public void testBuildQuerySql_SimpleQuery() { + List columns = Arrays.asList("col1", "col2", "col3"); + String sql = DamengUtil.buildQuerySql("test_table", columns, null); + + assertEquals("SELECT col1, col2, col3 FROM test_table", sql); + } + + @Test + public void testBuildQuerySql_WithPartition() { + List columns = Arrays.asList("col1", "col2"); + String partition = "dt=20240101"; + String sql = DamengUtil.buildQuerySql("test_table", columns, partition); + + assertTrue(sql.contains("SELECT col1, col2 FROM test_table")); + assertTrue(sql.contains("WHERE")); + assertTrue(sql.contains("dt='20240101'")); + } + + @Test + public void testBuildQuerySql_WithMultiplePartitions() { + List columns = Arrays.asList("col1"); + String partition = "dt=20240101,region=us"; + String sql = DamengUtil.buildQuerySql("test_table", columns, partition); + + assertTrue(sql.contains("WHERE")); + assertTrue(sql.contains("dt='20240101'")); + assertTrue(sql.contains("region='us'")); + assertTrue(sql.contains("AND")); + } + + @Test + public void testBuildQuerySql_EmptyColumns() { + assertThrows(DataproxyException.class, () -> { + DamengUtil.buildQuerySql("test_table", Collections.emptyList(), null); + }); + } + + @Test + public void testBuildQuerySql_InvalidTableName() { + List columns = Arrays.asList("col1"); + assertThrows(DataproxyException.class, () -> { + DamengUtil.buildQuerySql("invalid-table-name!", columns, null); + }); + } + + @Test + public void testBuildQuerySql_InvalidColumnName() { + List columns = Arrays.asList("col1", "invalid-column!"); + assertThrows(DataproxyException.class, () -> { + DamengUtil.buildQuerySql("test_table", columns, null); + }); + } + + @Test + public void testBuildQuerySql_InvalidPartitionKey() { + List columns = Arrays.asList("col1"); + String partition = "invalid-key!=value"; + assertThrows(DataproxyException.class, () -> { + DamengUtil.buildQuerySql("test_table", columns, partition); + }); + } + + @Test + public void testBuildQuerySql_PartitionValueWithSpecialChars() { + List columns = Arrays.asList("col1"); + String partition = "dt=2024-01-01"; + String sql = DamengUtil.buildQuerySql("test_table", columns, partition); + assertTrue(sql.contains("dt='2024-01-01'")); + } + + @Test + public void testBuildQuerySql_PartitionValueWithSingleQuote() { + List columns = Arrays.asList("col1"); + String partition = "dt=test'value"; + String sql = DamengUtil.buildQuerySql("test_table", columns, partition); + // Should escape single quotes + assertTrue(sql.contains("dt='test''value'")); + } + + + @Test + public void testBuildCreateTableSql_SimpleTable() { + Schema schema = new Schema(Arrays.asList( + new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null), + new Field("name", FieldType.nullable(ArrowType.Utf8.INSTANCE), null) + )); + + String sql = DamengUtil.buildCreateTableSql("test_table", schema, null); + + assertTrue(sql.contains("CREATE TABLE test_table")); + assertTrue(sql.contains("id INT")); + assertTrue(sql.contains("name VARCHAR")); + } + + @Test + public void testBuildCreateTableSql_WithPartition() { + Schema schema = new Schema(Arrays.asList( + new Field("id", 
+                new Field("dt", FieldType.nullable(ArrowType.Utf8.INSTANCE), null)
+        ));
+
+        Map<String, String> partition = Map.of("dt", "20240101");
+        String sql = DamengUtil.buildCreateTableSql("test_table", schema, partition);
+
+        assertTrue(sql.contains("CREATE TABLE test_table"));
+        assertTrue(sql.contains("dt VARCHAR"));
+    }
+
+    @Test
+    public void testBuildCreateTableSql_EmptySchema() {
+        Schema schema = new Schema(Collections.emptyList());
+
+        assertThrows(DataproxyException.class, () -> {
+            DamengUtil.buildCreateTableSql("test_table", schema, null);
+        });
+    }
+
+    @Test
+    public void testBuildCreateTableSql_InvalidTableName() {
+        Schema schema = new Schema(Arrays.asList(
+                new Field("col1", FieldType.nullable(ArrowType.Utf8.INSTANCE), null)
+        ));
+
+        assertThrows(DataproxyException.class, () -> {
+            DamengUtil.buildCreateTableSql("invalid-table!", schema, null);
+        });
+    }
+
+    @Test
+    public void testBuildCreateTableSql_InvalidFieldName() {
+        Schema schema = new Schema(Arrays.asList(
+                new Field("invalid-field!", FieldType.nullable(ArrowType.Utf8.INSTANCE), null)
+        ));
+
+        assertThrows(DataproxyException.class, () -> {
+            DamengUtil.buildCreateTableSql("test_table", schema, null);
+        });
+    }
+
+    @Test
+    public void testArrowTypeToJdbcType_StringTypes() {
+        assertEquals("VARCHAR", DamengUtil.arrowTypeToJdbcType(ArrowType.Utf8.INSTANCE));
+        assertEquals("CLOB", DamengUtil.arrowTypeToJdbcType(ArrowType.LargeUtf8.INSTANCE));
+    }
+
+    @Test
+    public void testArrowTypeToJdbcType_IntegerTypes() {
+        assertEquals("TINYINT", DamengUtil.arrowTypeToJdbcType(new ArrowType.Int(8, true)));
+        assertEquals("SMALLINT", DamengUtil.arrowTypeToJdbcType(new ArrowType.Int(16, true)));
+        assertEquals("INT", DamengUtil.arrowTypeToJdbcType(new ArrowType.Int(32, true)));
+        assertEquals("BIGINT", DamengUtil.arrowTypeToJdbcType(new ArrowType.Int(64, true)));
+    }
+
+    @Test
+    public void testArrowTypeToJdbcType_InvalidIntBitWidth() {
+        assertThrows(IllegalArgumentException.class, () -> {
+            DamengUtil.arrowTypeToJdbcType(new ArrowType.Int(128, true));
+        });
+    }
+
+    @Test
+    public void testArrowTypeToJdbcType_FloatingPointTypes() {
+        assertEquals("FLOAT", DamengUtil.arrowTypeToJdbcType(
+                new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)));
+        assertEquals("DOUBLE", DamengUtil.arrowTypeToJdbcType(
+                new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)));
+    }
+
+    @Test
+    public void testArrowTypeToJdbcType_Bool() {
+        assertEquals("BIT", DamengUtil.arrowTypeToJdbcType(ArrowType.Bool.INSTANCE));
+    }
+
+    @Test
+    public void testArrowTypeToJdbcType_DateTypes() {
+        assertEquals("DATE", DamengUtil.arrowTypeToJdbcType(
+                new ArrowType.Date(DateUnit.DAY)));
+        assertEquals("DATETIME(3)", DamengUtil.arrowTypeToJdbcType(
+                new ArrowType.Date(DateUnit.MILLISECOND)));
+    }
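
Note: the next two tests fix the TimeUnit-to-precision mapping. For reference, the fractional-second digit counts they imply (a sketch of the presumed rule, not the DamengUtil source; TIME clamps nanoseconds to 6 digits, while TIMESTAMP rejects them outright, as a later test shows):

    static int fractionalDigits(TimeUnit unit) {
        return switch (unit) {
            case SECOND -> 0;
            case MILLISECOND -> 3;
            case MICROSECOND, NANOSECOND -> 6; // TIME(6) is the finest DM supports
        };
    }

+
+    @Test
+    public void testArrowTypeToJdbcType_TimeTypes() {
+        assertEquals("TIME(0)", DamengUtil.arrowTypeToJdbcType(
+                new ArrowType.Time(TimeUnit.SECOND, 32)));
+        assertEquals("TIME(3)", DamengUtil.arrowTypeToJdbcType(
+                new ArrowType.Time(TimeUnit.MILLISECOND, 32)));
+        assertEquals("TIME(6)", DamengUtil.arrowTypeToJdbcType(
+                new ArrowType.Time(TimeUnit.MICROSECOND, 64)));
+        assertEquals("TIME(6)", DamengUtil.arrowTypeToJdbcType(
+                new ArrowType.Time(TimeUnit.NANOSECOND, 64)));
+    }
+
+    @Test
+    public void testArrowTypeToJdbcType_TimestampTypes() {
+        assertEquals("TIMESTAMP(0)", DamengUtil.arrowTypeToJdbcType(
+                new ArrowType.Timestamp(TimeUnit.SECOND, null)));
+        assertEquals("TIMESTAMP(3)", DamengUtil.arrowTypeToJdbcType(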
+                new ArrowType.Timestamp(TimeUnit.MILLISECOND, null)));
+        assertEquals("TIMESTAMP(6)", DamengUtil.arrowTypeToJdbcType(
+                new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)));
+    }
+
+    @Test
+    public void testArrowTypeToJdbcType_TimestampWithTimezone() {
+        ArrowType.Timestamp ts = new ArrowType.Timestamp(TimeUnit.MILLISECOND, "UTC");
+        String result = DamengUtil.arrowTypeToJdbcType(ts);
+        assertTrue(result.contains("TIMESTAMP WITH TIME ZONE"));
+    }
+
+    @Test
+    public void testArrowTypeToJdbcType_TimestampNanosecond() {
+        assertThrows(IllegalArgumentException.class, () -> {
+            DamengUtil.arrowTypeToJdbcType(new ArrowType.Timestamp(TimeUnit.NANOSECOND, null));
+        });
+    }
+
+    @Test
+    public void testArrowTypeToJdbcType_Decimal() {
+        assertEquals("DECIMAL(10, 2)", DamengUtil.arrowTypeToJdbcType(
+                new ArrowType.Decimal(10, 2, 128)));
+        assertEquals("DECIMAL(38, 10)", DamengUtil.arrowTypeToJdbcType(
+                new ArrowType.Decimal(38, 10, 128)));
+    }
+
+    @Test
+    public void testArrowTypeToJdbcType_Decimal_InvalidPrecision() {
+        assertThrows(IllegalArgumentException.class, () -> {
+            DamengUtil.arrowTypeToJdbcType(new ArrowType.Decimal(0, 0, 128));
+        });
+        assertThrows(IllegalArgumentException.class, () -> {
+            DamengUtil.arrowTypeToJdbcType(new ArrowType.Decimal(39, 10, 128));
+        });
+    }
+
+    @Test
+    public void testArrowTypeToJdbcType_Decimal_InvalidScale() {
+        assertThrows(IllegalArgumentException.class, () -> {
+            DamengUtil.arrowTypeToJdbcType(new ArrowType.Decimal(10, -1, 128));
+        });
+        assertThrows(IllegalArgumentException.class, () -> {
+            DamengUtil.arrowTypeToJdbcType(new ArrowType.Decimal(10, 11, 128));
+        });
+    }
+
+    @Test
+    public void testArrowTypeToJdbcType_BinaryTypes() {
+        assertEquals("VARBINARY", DamengUtil.arrowTypeToJdbcType(ArrowType.Binary.INSTANCE));
+        assertEquals("BLOB", DamengUtil.arrowTypeToJdbcType(ArrowType.LargeBinary.INSTANCE));
+    }
+
+    @Test
+    public void testArrowTypeToJdbcType_FixedSizeBinary() {
+        assertEquals("BINARY(10)", DamengUtil.arrowTypeToJdbcType(
+                new ArrowType.FixedSizeBinary(10)));
+        assertEquals("BLOB", DamengUtil.arrowTypeToJdbcType(
+                new ArrowType.FixedSizeBinary(10000)));
+    }
+
+    @Test
+    public void testArrowTypeToJdbcType_IntervalTypes() {
+        assertEquals("INTERVAL YEAR TO MONTH", DamengUtil.arrowTypeToJdbcType(
+                new ArrowType.Interval(IntervalUnit.YEAR_MONTH)));
+        assertEquals("INTERVAL DAY TO SECOND", DamengUtil.arrowTypeToJdbcType(
+                new ArrowType.Interval(IntervalUnit.DAY_TIME)));
+    }
+
+    @Test
+    public void testArrowTypeToJdbcType_Null() {
+        assertEquals("VARCHAR(1)", DamengUtil.arrowTypeToJdbcType(ArrowType.Null.INSTANCE));
+    }
+
+    @Test
+    public void testArrowTypeToJdbcType_UnsupportedType() {
+        // ArrowType is abstract and every concrete subclass is handled above, so an
+        // unsupported instance cannot be constructed from test code; the default branch
+        // is left to code review. Kept as a placeholder documenting that limitation.
+    }
+
+    @Test
+    public void testJdbcType2ArrowType_DecimalTypes() {
+        ArrowType type1 = DamengUtil.jdbcType2ArrowType("DECIMAL");
+        assertInstanceOf(ArrowType.Decimal.class, type1);
+
+        ArrowType type2 = DamengUtil.jdbcType2ArrowType("NUMERIC");
+        assertInstanceOf(ArrowType.Decimal.class, type2);
+
+        ArrowType type3 = DamengUtil.jdbcType2ArrowType("NUMBER");
+        assertInstanceOf(ArrowType.Decimal.class, type3);
+    }
+
+    @Test
+    public void testJdbcType2ArrowType_TimestampTypes() {
+        ArrowType type1 = DamengUtil.jdbcType2ArrowType("TIMESTAMP(3)");
+        assertInstanceOf(ArrowType.Timestamp.class, type1);
+        assertEquals(TimeUnit.MILLISECOND, ((ArrowType.Timestamp) type1).getUnit());
+
+        ArrowType type2 =
DamengUtil.jdbcType2ArrowType("TIMESTAMP(6)"); + assertInstanceOf(ArrowType.Timestamp.class, type2); + assertEquals(TimeUnit.MICROSECOND, ((ArrowType.Timestamp) type2).getUnit()); + } + + @Test + public void testJdbcType2ArrowType_DatetimeTypes() { + ArrowType type1 = DamengUtil.jdbcType2ArrowType("DATETIME(3)"); + assertInstanceOf(ArrowType.Date.class, type1); + + ArrowType type2 = DamengUtil.jdbcType2ArrowType("DATETIME(6)"); + assertInstanceOf(ArrowType.Timestamp.class, type2); + } + + @Test + public void testJdbcType2ArrowType_TimeTypes() { + ArrowType type1 = DamengUtil.jdbcType2ArrowType("TIME(3)"); + assertInstanceOf(ArrowType.Time.class, type1); + ArrowType.Time time1 = (ArrowType.Time) type1; + assertEquals(TimeUnit.MILLISECOND, time1.getUnit()); + assertEquals(32, time1.getBitWidth()); + + ArrowType type2 = DamengUtil.jdbcType2ArrowType("TIME(6)"); + assertInstanceOf(ArrowType.Time.class, type2); + ArrowType.Time time2 = (ArrowType.Time) type2; + assertEquals(TimeUnit.MICROSECOND, time2.getUnit()); + assertEquals(64, time2.getBitWidth()); + } + + @Test + public void testJdbcType2ArrowType_TimestampWithTimezone() { + ArrowType type1 = DamengUtil.jdbcType2ArrowType("TIMESTAMP WITH TIME ZONE"); + assertInstanceOf(ArrowType.Timestamp.class, type1); + assertEquals("UTC", ((ArrowType.Timestamp) type1).getTimezone()); + } + + @Test + public void testJdbcType2ArrowType_StringTypes() { + assertEquals(ArrowType.Utf8.INSTANCE, DamengUtil.jdbcType2ArrowType("VARCHAR")); + assertEquals(ArrowType.Utf8.INSTANCE, DamengUtil.jdbcType2ArrowType("CHAR")); + assertEquals(ArrowType.LargeUtf8.INSTANCE, DamengUtil.jdbcType2ArrowType("CLOB")); + assertEquals(ArrowType.LargeUtf8.INSTANCE, DamengUtil.jdbcType2ArrowType("TEXT")); + } + + @Test + public void testJdbcType2ArrowType_IntegerTypes() { + ArrowType type1 = DamengUtil.jdbcType2ArrowType("TINYINT"); + assertInstanceOf(ArrowType.Int.class, type1); + assertEquals(8, ((ArrowType.Int) type1).getBitWidth()); + + ArrowType type2 = DamengUtil.jdbcType2ArrowType("SMALLINT"); + assertInstanceOf(ArrowType.Int.class, type2); + assertEquals(16, ((ArrowType.Int) type2).getBitWidth()); + + ArrowType type3 = DamengUtil.jdbcType2ArrowType("INT"); + assertInstanceOf(ArrowType.Int.class, type3); + assertEquals(32, ((ArrowType.Int) type3).getBitWidth()); + + ArrowType type4 = DamengUtil.jdbcType2ArrowType("BIGINT"); + assertInstanceOf(ArrowType.Int.class, type4); + assertEquals(64, ((ArrowType.Int) type4).getBitWidth()); + } + + @Test + public void testJdbcType2ArrowType_FloatingPointTypes() { + ArrowType type1 = DamengUtil.jdbcType2ArrowType("FLOAT"); + assertInstanceOf(ArrowType.FloatingPoint.class, type1); + assertEquals(FloatingPointPrecision.SINGLE, + ((ArrowType.FloatingPoint) type1).getPrecision()); + + ArrowType type2 = DamengUtil.jdbcType2ArrowType("DOUBLE"); + assertInstanceOf(ArrowType.FloatingPoint.class, type2); + assertEquals(FloatingPointPrecision.DOUBLE, + ((ArrowType.FloatingPoint) type2).getPrecision()); + } + + @Test + public void testJdbcType2ArrowType_Bool() { + assertEquals(ArrowType.Bool.INSTANCE, DamengUtil.jdbcType2ArrowType("BIT")); + assertEquals(ArrowType.Bool.INSTANCE, DamengUtil.jdbcType2ArrowType("BOOLEAN")); + } + + @Test + public void testJdbcType2ArrowType_Date() { + ArrowType type = DamengUtil.jdbcType2ArrowType("DATE"); + assertInstanceOf(ArrowType.Date.class, type); + assertEquals(DateUnit.DAY, ((ArrowType.Date) type).getUnit()); + } + + @Test + public void testJdbcType2ArrowType_BinaryTypes() { + 
assertEquals(ArrowType.Binary.INSTANCE, DamengUtil.jdbcType2ArrowType("VARBINARY"));
+        assertEquals(ArrowType.Binary.INSTANCE, DamengUtil.jdbcType2ArrowType("BINARY"));
+        assertEquals(ArrowType.LargeBinary.INSTANCE, DamengUtil.jdbcType2ArrowType("BLOB"));
+    }
+
+    @Test
+    public void testJdbcType2ArrowType_IntervalTypes() {
+        ArrowType type1 = DamengUtil.jdbcType2ArrowType("INTERVAL YEAR TO MONTH");
+        assertInstanceOf(ArrowType.Interval.class, type1);
+        assertEquals(IntervalUnit.YEAR_MONTH, ((ArrowType.Interval) type1).getUnit());
+
+        ArrowType type2 = DamengUtil.jdbcType2ArrowType("INTERVAL DAY TO SECOND");
+        assertInstanceOf(ArrowType.Interval.class, type2);
+        assertEquals(IntervalUnit.DAY_TIME, ((ArrowType.Interval) type2).getUnit());
+    }
+
+    @Test
+    public void testJdbcType2ArrowType_Null() {
+        assertEquals(ArrowType.Null.INSTANCE, DamengUtil.jdbcType2ArrowType("NULL"));
+    }
+
+    @Test
+    public void testJdbcType2ArrowType_UnsupportedType() {
+        // Unsupported types should return Utf8 as fallback
+        ArrowType type = DamengUtil.jdbcType2ArrowType("UNKNOWN_TYPE");
+        assertEquals(ArrowType.Utf8.INSTANCE, type);
+    }
+
+    @Test
+    public void testJdbcType2ArrowType_NullInput() {
+        assertThrows(IllegalArgumentException.class, () -> {
+            DamengUtil.jdbcType2ArrowType(null);
+        });
+
+        assertThrows(IllegalArgumentException.class, () -> {
+            DamengUtil.jdbcType2ArrowType("");
+        });
+    }
+
+    @Test
+    public void testBuildMultiRowInsertSql_SimpleInsert() {
+        Schema schema = new Schema(Arrays.asList(
+                new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null),
+                new Field("name", FieldType.nullable(ArrowType.Utf8.INSTANCE), null)
+        ));
+
+        List<Map<String, Object>> dataList = new ArrayList<>();
+        Map<String, Object> row1 = new HashMap<>();
+        row1.put("id", 1);
+        row1.put("name", "test1");
+        dataList.add(row1);
+
+        Map<String, String> partition = Collections.emptyMap();
+        DatabaseRecordWriter.SqlWithParams result = DamengUtil.buildMultiRowInsertSql(
+                "test_table", schema, dataList, partition);
+
+        assertNotNull(result);
+        assertNotNull(result.sql);
+        assertTrue(result.sql.contains("INSERT INTO test_table"));
+        assertTrue(result.sql.contains("id"));
+        assertTrue(result.sql.contains("name"));
+        assertNotNull(result.params);
+        assertEquals(2, result.params.size());
+    }
+
+    @Test
+    public void testBuildMultiRowInsertSql_MultipleRows() {
+        Schema schema = new Schema(Arrays.asList(
+                new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null)
+        ));
+
+        List<Map<String, Object>> dataList = new ArrayList<>();
+        for (int i = 1; i <= 3; i++) {
+            Map<String, Object> row = new HashMap<>();
+            row.put("id", i);
+            dataList.add(row);
+        }
+
+        DatabaseRecordWriter.SqlWithParams result = DamengUtil.buildMultiRowInsertSql(
+                "test_table", schema, dataList, Collections.emptyMap());
+
+        assertNotNull(result);
+        assertTrue(result.sql.contains("VALUES"));
+        assertEquals(3, result.params.size());
+    }
+
+    @Test
+    public void testBuildMultiRowInsertSql_WithPartition() {
+        Schema schema = new Schema(Arrays.asList(
+                new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null),
+                new Field("dt", FieldType.nullable(ArrowType.Utf8.INSTANCE), null)
+        ));
+
+        List<Map<String, Object>> dataList = new ArrayList<>();
+        Map<String, Object> row = new HashMap<>();
+        row.put("id", 1);
+        dataList.add(row);
+
+        Map<String, String> partition = Map.of("dt", "20240101");
+        DatabaseRecordWriter.SqlWithParams result = DamengUtil.buildMultiRowInsertSql(
+                "test_table", schema, dataList, partition);
+
+        assertNotNull(result);
+        assertTrue(result.sql.contains("dt"));
+    }
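
Note: for orientation, the shape these assertions imply for the generated statement, with parameters flattened row-major (an illustrative sketch; the exact SQL text is an assumption, not lifted from the DamengUtil source):

    // 3 rows x 1 column, as in testBuildMultiRowInsertSql_MultipleRows:
    String placeholders = String.join(", ", Collections.nCopies(3, "(?)"));
    String sql = "INSERT INTO test_table (id) VALUES " + placeholders;
    // -> "INSERT INTO test_table (id) VALUES (?), (?), (?)", params = [1, 2, 3]

+
+    @Test
+    public void testBuildMultiRowInsertSql_WithBool() {
+        Schema schema = new Schema(Arrays.asList(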
+                new Field("flag", FieldType.nullable(ArrowType.Bool.INSTANCE), null)
+        ));
+
+        List<Map<String, Object>> dataList = new ArrayList<>();
+        Map<String, Object> row = new HashMap<>();
+        row.put("flag", true);
+        dataList.add(row);
+
+        DatabaseRecordWriter.SqlWithParams result = DamengUtil.buildMultiRowInsertSql(
+                "test_table", schema, dataList, Collections.emptyMap());
+
+        assertNotNull(result);
+        assertEquals(1, result.params.size());
+        assertEquals(1, result.params.get(0)); // true -> 1
+    }
+
+    @Test
+    public void testBuildMultiRowInsertSql_WithTimestamp() {
+        Schema schema = new Schema(Arrays.asList(
+                new Field("ts", FieldType.nullable(
+                        new ArrowType.Timestamp(TimeUnit.MILLISECOND, null)), null)
+        ));
+
+        List<Map<String, Object>> dataList = new ArrayList<>();
+        Map<String, Object> row = new HashMap<>();
+        row.put("ts", System.currentTimeMillis());
+        dataList.add(row);
+
+        DatabaseRecordWriter.SqlWithParams result = DamengUtil.buildMultiRowInsertSql(
+                "test_table", schema, dataList, Collections.emptyMap());
+
+        assertNotNull(result);
+        assertEquals(1, result.params.size());
+    }
+
+    @Test
+    public void testBuildMultiRowInsertSql_WithIntervalYearMonth() {
+        Schema schema = new Schema(Arrays.asList(
+                new Field("interval_col", FieldType.nullable(
+                        new ArrowType.Interval(IntervalUnit.YEAR_MONTH)), null)
+        ));
+
+        List<Map<String, Object>> dataList = new ArrayList<>();
+        Map<String, Object> row = new HashMap<>();
+        row.put("interval_col", 14); // 14 months = 1 year 2 months
+        dataList.add(row);
+
+        DatabaseRecordWriter.SqlWithParams result = DamengUtil.buildMultiRowInsertSql(
+                "test_table", schema, dataList, Collections.emptyMap());
+
+        assertNotNull(result);
+        assertTrue(result.sql.contains("INTERVAL"));
+        assertTrue(result.sql.contains("YEAR TO MONTH"));
+    }
+
+    @Test
+    public void testBuildMultiRowInsertSql_WithIntervalDayTime() {
+        Schema schema = new Schema(Arrays.asList(
+                new Field("interval_col", FieldType.nullable(
+                        new ArrowType.Interval(IntervalUnit.DAY_TIME)), null)
+        ));
+
+        List<Map<String, Object>> dataList = new ArrayList<>();
+        Map<String, Object> row = new HashMap<>();
+        PeriodDuration pd = new PeriodDuration(Period.ZERO, Duration.ofHours(2));
+        row.put("interval_col", pd);
+        dataList.add(row);
+
+        DatabaseRecordWriter.SqlWithParams result = DamengUtil.buildMultiRowInsertSql(
+                "test_table", schema, dataList, Collections.emptyMap());
+
+        assertNotNull(result);
+        assertTrue(result.sql.contains("INTERVAL"));
+        assertTrue(result.sql.contains("DAY TO SECOND"));
+    }
+
+    @Test
+    public void testBuildMultiRowInsertSql_EmptyDataList() {
+        Schema schema = new Schema(Arrays.asList(
+                new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null)
+        ));
+
+        assertThrows(IllegalArgumentException.class, () -> {
+            DamengUtil.buildMultiRowInsertSql("test_table", schema,
+                    Collections.emptyList(), Collections.emptyMap());
+        });
+    }
+
+    @Test
+    public void testBuildMultiRowInsertSql_InvalidTableName() {
+        Schema schema = new Schema(Arrays.asList(
+                new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null)
+        ));
+
+        List<Map<String, Object>> dataList = new ArrayList<>();
+        Map<String, Object> row = new HashMap<>();
+        row.put("id", 1);
+        dataList.add(row);
+
+        assertThrows(DataproxyException.class, () -> {
+            DamengUtil.buildMultiRowInsertSql("invalid-table!", schema,
+                    dataList, Collections.emptyMap());
+        });
+    }
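
Note: the existence checks that follow only pin two properties of the generated SQL — it counts rows in USER_TABLES and compares the name through UPPER(). One statement satisfying both, assuming the table name has already passed identifier validation (the exact text is an assumption, not the DamengUtil source):

    String sql = "SELECT COUNT(*) FROM USER_TABLES WHERE UPPER(TABLE_NAME) = UPPER('" + tableName + "')";

+
+    @Test
+    public void testCheckTableExists_TableExists() throws SQLException {
+        when(mockConnection.createStatement()).thenReturn(mockStatement);
+        when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet);
+        when(mockResultSet.next()).thenReturn(true);
+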
when(mockResultSet.getInt(1)).thenReturn(1); + + boolean exists = DamengUtil.checkTableExists(mockConnection, "TEST_TABLE"); + + assertTrue(exists); + verify(mockStatement).executeQuery(contains("SELECT COUNT(*) FROM USER_TABLES")); + } + + @Test + public void testCheckTableExists_TableNotExists() throws SQLException { + when(mockConnection.createStatement()).thenReturn(mockStatement); + when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(true); + when(mockResultSet.getInt(1)).thenReturn(0); + + boolean exists = DamengUtil.checkTableExists(mockConnection, "NONEXISTENT_TABLE"); + + assertFalse(exists); + } + + @Test + public void testCheckTableExists_InvalidTableName() { + assertThrows(DataproxyException.class, () -> { + DamengUtil.checkTableExists(mockConnection, "invalid-table!"); + }); + } + + @Test + public void testCheckTableExists_SQLException() throws SQLException { + when(mockConnection.createStatement()).thenReturn(mockStatement); + when(mockStatement.executeQuery(anyString())).thenThrow(new SQLException("Database error")); + + assertThrows(RuntimeException.class, () -> { + DamengUtil.checkTableExists(mockConnection, "TEST_TABLE"); + }); + } + + @Test + public void testCheckTableExists_TableNameCaseInsensitive() throws SQLException { + when(mockConnection.createStatement()).thenReturn(mockStatement); + when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(true); + when(mockResultSet.getInt(1)).thenReturn(1); + + boolean exists1 = DamengUtil.checkTableExists(mockConnection, "test_table"); + assertTrue(exists1); + + boolean exists2 = DamengUtil.checkTableExists(mockConnection, "TEST_TABLE"); + assertTrue(exists2); + + // Verify SQL uses UPPER() + verify(mockStatement, atLeastOnce()).executeQuery(argThat(sql -> + sql.contains("UPPER(TABLE_NAME)"))); + } +} + diff --git a/dataproxy-plugins/dataproxy-plugin-database/pom.xml b/dataproxy-plugins/dataproxy-plugin-database/pom.xml index 0be9f41..9d5b867 100644 --- a/dataproxy-plugins/dataproxy-plugin-database/pom.xml +++ b/dataproxy-plugins/dataproxy-plugin-database/pom.xml @@ -20,6 +20,12 @@ lombok + + + org.secretflow + dataproxy-core + + org.junit.jupiter @@ -47,6 +53,30 @@ org.apache.arrow arrow-memory-core + + + + org.apache.arrow + arrow-vector + + + org.apache.arrow + arrow-format + + + + + org.apache.arrow + arrow-memory-netty + test + + + + + org.slf4j + slf4j-simple + test + diff --git a/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/producer/AbstractDatabaseFlightProducer.java b/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/producer/AbstractDatabaseFlightProducer.java index c71d172..d9adc4a 100644 --- a/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/producer/AbstractDatabaseFlightProducer.java +++ b/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/producer/AbstractDatabaseFlightProducer.java @@ -197,6 +197,7 @@ public Runnable acceptPut( } count += rowCount; } + writer.flush(); ackStream.onCompleted(); writer.close(); log.info("put data over! 
all count: {}", count);
diff --git a/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/reader/DatabaseDoGetContext.java b/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/reader/DatabaseDoGetContext.java
index 9ed5d1b..5865967 100644
--- a/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/reader/DatabaseDoGetContext.java
+++ b/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/reader/DatabaseDoGetContext.java
@@ -22,31 +22,59 @@
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.FieldType;
 import org.apache.arrow.vector.types.pojo.Schema;
-import org.secretflow.dataproxy.common.utils.JsonUtils;
+import org.secretflow.dataproxy.common.exceptions.DataproxyErrorCode;
+import org.secretflow.dataproxy.common.exceptions.DataproxyException;
 import org.secretflow.dataproxy.core.param.ParamWrapper;
-import org.secretflow.dataproxy.plugin.database.config.DatabaseCommandConfig;
 import org.secretflow.dataproxy.plugin.database.config.*;
-
 import org.secretflow.dataproxy.plugin.database.constant.DatabaseTypeEnum;
-import org.secretflow.dataproxy.common.exceptions.DataproxyErrorCode;
-import org.secretflow.dataproxy.common.exceptions.DataproxyException;
 import org.secretflow.v1alpha1.common.Common;

 import java.sql.*;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.locks.ReadWriteLock;
 import java.util.concurrent.locks.ReentrantReadWriteLock;
-import java.util.function.BiFunction;
 import java.util.function.Function;

 @Slf4j
 public class DatabaseDoGetContext {

+    // DECIMAL type aliases
+    private static final Set<String> DECIMAL_TYPE_NAMES = Set.of("DECIMAL", "NUMERIC", "NUMBER", "DEC");
+
+    // Time types that require precision information
+    private static final Set<String> TIME_TYPE_NAMES = Set.of("DATETIME", "TIMESTAMP", "TIME");
+
+    // Default precision and scale for DECIMAL type
+    private static final int DEFAULT_DECIMAL_PRECISION = 38;
+    private static final int DEFAULT_DECIMAL_SCALE = 10;
+
+    /**
+     * SQL integer supplier functional interface for simplified exception handling.
+     */
+    @FunctionalInterface
+    private interface SqlIntSupplier {
+        int getAsInt() throws SQLException;
+    }
+
+    /**
+     * Column information inner class for unified handling of column metadata from different sources.
+     */
+    private static class ColumnInfo {
+        final String name;
+        final String type;
+        final int precision;
+        final int scale;
+
+        ColumnInfo(String name, String type, int precision, int scale) {
+            this.name = name;
+            this.type = type;
+            this.precision = precision;
+            this.scale = scale;
+        }
+    }
+
     private final DatabaseCommandConfig dbCommandConfig;

     @Getter
@@ -64,15 +92,15 @@ public class DatabaseDoGetContext {
     private String tableName;

     private final Map<String, ParamWrapper> ticketWrapperMap = new ConcurrentHashMap<>();
-    private final ReadWriteLock tickerWrapperMapRwLock = new ReentrantReadWriteLock();
+    private final ReadWriteLock ticketWrapperMapRwLock = new ReentrantReadWriteLock();

+    private final Function<DatabaseConnectConfig, Connection> initDatabaseFunc;

     @FunctionalInterface
-    public interface BuildQuerySqlFunc<T, U, V> {
-        String apply(T t, U u, V v);
+    public interface BuildQuerySqlFunc<T, U, V, R> {
+        R apply(T t, U u, V v);
     }

-    private final BuildQuerySqlFunc<String, List<String>, String> buildQuerySqlFunc;
+    private final BuildQuerySqlFunc<String, List<String>, String, String> buildQuerySqlFunc;

     private final Function<String, ArrowType> jdbcType2ArrowType;

     private final ReadWriteLock readWriteLock = new ReentrantReadWriteLock();
@@ -90,80 +118,350 @@ public List<TaskConfig> getTaskConfigs() {
         return Collections.singletonList(new TaskConfig(this, 0));
     }

-    private void prepare(){
+    private void prepare() {
         DatabaseConnectConfig dbConnectConfig = dbCommandConfig.getDbConnectConfig();
         String querySql;
-
-        conn = this.initDatabaseFunc.apply(dbConnectConfig);
-        if (dbCommandConfig instanceof ScqlCommandJobConfig scqlReadJobConfig) {
-            querySql = scqlReadJobConfig.getCommandConfig();
-        } else if (dbCommandConfig instanceof DatabaseTableQueryConfig dbTableQueryConfig) {
-            DatabaseTableConfig tableConfig = dbTableQueryConfig.getCommandConfig();
-            this.tableName = tableConfig.tableName();
-            querySql = this.buildQuerySqlFunc.apply(this.tableName, tableConfig.columns().stream().map(Common.DataColumn::getName).toList(), tableConfig.partition());
-            this.schema = dbCommandConfig.getResultSchema();
-        } else {
-            throw DataproxyException.of(DataproxyErrorCode.PARAMS_UNRELIABLE, "Unsupported read parameter type: " + dbCommandConfig.getClass());
+        Connection localConn = null;
+
+        try {
+            localConn = this.initDatabaseFunc.apply(dbConnectConfig);
+            this.conn = localConn;
+
+            if (dbCommandConfig instanceof ScqlCommandJobConfig scqlReadJobConfig) {
+                querySql = scqlReadJobConfig.getCommandConfig();
+            } else if (dbCommandConfig instanceof DatabaseTableQueryConfig dbTableQueryConfig) {
+                DatabaseTableConfig tableConfig = dbTableQueryConfig.getCommandConfig();
+                this.tableName = tableConfig.tableName();
+                querySql = this.buildQuerySqlFunc.apply(this.tableName,
+                        tableConfig.columns().stream().map(Common.DataColumn::getName).toList(),
+                        tableConfig.partition());
+                this.schema = dbCommandConfig.getResultSchema();
+            } else {
+                throw DataproxyException.of(DataproxyErrorCode.PARAMS_UNRELIABLE,
+                        "Unsupported read parameter type: " + dbCommandConfig.getClass());
+            }
+
+            this.executeSqlTaskAndHandleResult(localConn, this.tableName, querySql);
+        } catch (Exception e) {
+            // Clean up created connection if prepare fails
+            if (localConn != null) {
+                try {
+                    localConn.close();
+                } catch (SQLException closeException) {
+                    log.warn("Failed to close connection after prepare error", closeException);
+                }
+            }
+            throw e;
         }
-        this.executeSqlTaskAndHandleResult(conn, this.tableName, querySql);
     }
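
Note: since these generic hooks are what a concrete plugin supplies, a sketch of how the DM flavor is presumably wired (the method references come from DamengUtil as exercised by the tests above; the construction site and initDameng's Connection return type are assumptions):

    BuildQuerySqlFunc<String, List<String>, String, String> buildSql = DamengUtil::buildQuerySql;
    Function<DatabaseConnectConfig, Connection> init = DamengUtil::initDameng; // assumed to return Connection
    Function<String, ArrowType> toArrow = DamengUtil::jdbcType2ArrowType;

     private void executeSqlTaskAndHandleResult(Connection connection, String tableName, String querySql) {
-        log.info("database execute sql: {}", querySql);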
-        Throwable throwable = null;
+        log.debug("Executing SQL on table {}: {}",
+                tableName != null ? tableName : "N/A",
+                querySql.length() > 200 ? querySql.substring(0, 200) + "..." : querySql);
+
+        Throwable error = null;
+        Statement stmt = null;
+        ResultSet rs = null;
+
+        readWriteLock.writeLock().lock();
         try {
-            readWriteLock.writeLock().lock();
             this.databaseMetaData = connection.getMetaData();
-            queryStmt = connection.createStatement();
-            resultSet = queryStmt.executeQuery(querySql);
+            stmt = connection.createStatement();
+            rs = stmt.executeQuery(querySql);
+
+            // Assign to instance variables after successful execution
+            this.queryStmt = stmt;
+            this.resultSet = rs;
+
+            // SQL queries get column information from ResultSetMetaData, table queries from table metadata
             if (dbCommandConfig.getDbTypeEnum() == DatabaseTypeEnum.SQL) {
+                this.initArrowSchemaFromResultSet(rs);
+            } else {
                 this.initArrowSchemaFromColumns(connection.getMetaData(), tableName);
             }
-
-        } catch (SQLException e) {
-            log.error("sql execute error", e);
-            throwable = e;
-            throw DataproxyException.of(DataproxyErrorCode.DATABASE_ERROR, e.getMessage(), e);
         } catch (Exception e) {
-            throwable = e;
-            throw DataproxyException.of(DataproxyErrorCode.DATABASE_ERROR, "database execute sql error", e);
+            error = e;
+            // Close the locals created in this call; relying on GC for JDBC handles would leak
+            // the Statement when executeQuery() throws before the fields are assigned
+            closeResourcesQuietly(rs, stmt);
+            this.resultSet = null;
+            this.queryStmt = null;
+
+            String msg = e instanceof SQLException ? e.getMessage() : "database execute sql error";
+            log.error("SQL execution failed on table {}: {}", tableName != null ? tableName : "N/A", e.getMessage(), e);
+            throw DataproxyException.of(DataproxyErrorCode.DATABASE_ERROR, msg, e);
         } finally {
             readWriteLock.writeLock().unlock();
-            loadLazyConfig(throwable);
+            loadLazyConfig(error);
         }
     }
+
+    /**
+     * Silently close resources (for cleanup in exception handling).
+     */
+    private void closeResourcesQuietly(ResultSet rs, Statement stmt) {
+        if (rs != null) {
+            try {
+                rs.close();
+            } catch (SQLException e) {
+                log.warn("Failed to close ResultSet during error handling", e);
+            }
+        }
+        if (stmt != null) {
+            try {
+                stmt.close();
+            } catch (SQLException e) {
+                log.warn("Failed to close Statement during error handling", e);
+            }
+        }
+    }

     public void close() {
-        try{
-            queryStmt.close();
-            resultSet.close();
-            conn.close();
+        SQLException firstException = null;
+
+        // Close resources in reverse order
+        if (resultSet != null) {
+            try {
+                resultSet.close();
+            } catch (SQLException e) {
+                log.error("Failed to close ResultSet: {}", e.getMessage(), e);
+                firstException = e;
+            }
+        }
+
+        if (queryStmt != null) {
+            try {
+                queryStmt.close();
+            } catch (SQLException e) {
+                log.error("Failed to close Statement: {}", e.getMessage(), e);
+                if (firstException == null) {
+                    firstException = e;
+                }
+            }
+        }
+
+        if (conn != null) {
+            try {
+                conn.close();
+            } catch (SQLException e) {
+                log.error("Failed to close Connection: {}", e.getMessage(), e);
+                if (firstException == null) {
+                    firstException = e;
+                }
+            }
+        }
+
+        if (firstException != null) {
+            throw new RuntimeException("Error closing database resources", firstException);
+        }
+    }
+
+    /**
+     * Initialize Arrow Schema from DatabaseMetaData.
+     * Used for table query scenarios to get complete column information of the table.
+     *
+     * @param metaData Database metadata
+     * @param tableName Table name
+     * @throws SQLException SQL exception
+     */
+    private void initArrowSchemaFromColumns(DatabaseMetaData metaData, String tableName) throws SQLException {
+        String schemaName = dbCommandConfig.getDbConnectConfig().database();
+        log.debug("Querying column information: schema={}, catalog=null, tableName={}", schemaName, tableName);
+
+        List<ColumnInfo> columnInfos = new ArrayList<>();
+        try (ResultSet columns = metaData.getColumns(null, schemaName, tableName, null)) {
+            while (columns.next()) {
+                String columnName = columns.getString("COLUMN_NAME");
+                String columnType = columns.getString("TYPE_NAME");
+
+                // Safely get precision and scale information
+                int precision = safeGetInt(() -> columns.getInt("COLUMN_SIZE"), columnName, "COLUMN_SIZE", 0);
+                int scale = safeGetInt(() -> columns.getInt("DECIMAL_DIGITS"), columnName, "DECIMAL_DIGITS", -1);
+
+                columnInfos.add(new ColumnInfo(columnName, columnType, precision, scale));
+            }
+        }
+
+        schema = buildSchemaFromColumnInfos(columnInfos);
+        log.debug("Built schema with {} columns for table {}", columnInfos.size(), tableName);
+    }
+
+    /**
+     * Initialize Arrow Schema from ResultSetMetaData.
+     * Used for SQL query scenarios to get actual column information of query results.
+     *
+     * @param resultSet Query result set
+     * @throws SQLException SQL exception
+     */
+    private void initArrowSchemaFromResultSet(ResultSet resultSet) throws SQLException {
+        ResultSetMetaData metaData = resultSet.getMetaData();
+        int columnCount = metaData.getColumnCount();
+        List<ColumnInfo> columnInfos = new ArrayList<>(columnCount);
+
+        for (int i = 1; i <= columnCount; i++) {
+            String columnName = metaData.getColumnName(i);
+            String columnType = metaData.getColumnTypeName(i);
+
+            final int index = i;
+            // Safely get precision and scale information
+            int precision = safeGetInt(() -> metaData.getPrecision(index), columnName, "precision", 0);
+            int scale = safeGetInt(() -> metaData.getScale(index), columnName, "scale", -1);
+
+            columnInfos.add(new ColumnInfo(columnName, columnType, precision, scale));
+        }
+
+        schema = buildSchemaFromColumnInfos(columnInfos);
+        log.debug("Built schema with {} columns from SQL query result", columnCount);
+    }
+
+    /**
+     * Build Arrow Schema from column information list (unified Schema building logic).
+     *
+     * @param columnInfos Column information list
+     * @return Arrow Schema
+     */
+    private Schema buildSchemaFromColumnInfos(List<ColumnInfo> columnInfos) {
+        List<Field> fields = new ArrayList<>(columnInfos.size());
+
+        for (ColumnInfo info : columnInfos) {
+            ArrowType arrowType = determineArrowType(
+                    info.type,
+                    () -> info.precision,
+                    () -> info.scale,
+                    info.name
+            );
+            fields.add(new Field(info.name, FieldType.nullable(arrowType), null));
+        }
+
+        return new Schema(fields);
+    }
+
+    /**
+     * Safely get integer value, catch SQLException and return default value.
+     *
+     * @param supplier SQL integer supplier
+     * @param columnName Column name (for logging)
+     * @param attributeName Attribute name (for logging)
+     * @param defaultValue Default value
+     * @return Retrieved integer value or default value
+     */
+    private int safeGetInt(SqlIntSupplier supplier, String columnName, String attributeName, int defaultValue) {
+        try {
+            return supplier.getAsInt();
         } catch (SQLException e) {
-            log.error("query jdbc close error: {}", e.getMessage());
-            throw new RuntimeException(e);
+            log.warn("Failed to get {} for column {}: {}", attributeName, columnName, e.getMessage());
+            return defaultValue;
         }
+    }

+    /**
+     * Determine Arrow type based on JDBC type name.
+     *
+     * @param columnType JDBC type name (e.g., "TIMESTAMP(6)", "DECIMAL", "TIME")
+     * @param precisionSupplier Precision supplier (COLUMN_SIZE, used for DECIMAL precision)
+     * @param scaleSupplier Scale supplier (DECIMAL_DIGITS; DECIMAL scale, or fractional-second digits for time types)
+     * @param columnName Column name (for logging)
+     * @return Arrow type
+     */
+    private ArrowType determineArrowType(String columnType,
+                                         java.util.function.Supplier<Integer> precisionSupplier,
+                                         java.util.function.Supplier<Integer> scaleSupplier,
+                                         String columnName) {
+        // Check if it's DECIMAL type
+        if (isDecimalType(columnType)) {
+            int precision = precisionSupplier.get();
+            int scale = scaleSupplier.get();
+            // If retrieval fails (returns 0 or negative), use default value
+            if (precision <= 0) {
+                precision = DEFAULT_DECIMAL_PRECISION;
+            }
+            if (scale < 0) {
+                scale = DEFAULT_DECIMAL_SCALE;
+            }
+            return new ArrowType.Decimal(precision, scale, 128);
+        }
+
+        // For time types, try to construct a type name carrying the fractional-second precision
+        String typeNameWithPrecision = addPrecisionToTypeName(columnType, scaleSupplier, columnName);
+        return this.jdbcType2ArrowType.apply(typeNameWithPrecision);
     }

-    private void initArrowSchemaFromColumns(DatabaseMetaData metaData, String tableName) throws SQLException {
-        ResultSet columns = metaData.getColumns(null, null, tableName, null);
-        List<Field> fields = new ArrayList<>();
-        while (columns.next()) {
-            String columnName = columns.getString("COLUMN_NAME");
-            String columnType = columns.getString("TYPE_NAME");

+    /**
+     * Extract base type name (remove precision information).
+     * Example: DECIMAL(38,10) -> DECIMAL, TIMESTAMP(6) -> TIMESTAMP
+     *
+     * @param columnType JDBC type name
+     * @return Base type name (uppercase, precision information removed)
+     */
+    private String extractBaseTypeName(String columnType) {
+        if (columnType == null) {
+            return null;
+        }
+        return columnType.toUpperCase().replaceAll("\\([^)]*\\)", "").trim();
+    }

-            ArrowType arrowType = this.jdbcType2ArrowType.apply(columnType);
-            Field field = new Field(columnName, FieldType.nullable(arrowType), null);
-            fields.add(field);
+    /**
+     * Check if it's DECIMAL type (including aliases).
+     * Supports type names with precision information, such as "DECIMAL(38,10)", "NUMERIC(20,5)", etc.
+     *
+     * @param columnType JDBC type name
+     * @return Whether it's DECIMAL type
+     */
+    private boolean isDecimalType(String columnType) {
+        if (columnType == null) {
+            return false;
         }
-        columns.close();
-        schema = new Schema(fields);
+        String baseType = extractBaseTypeName(columnType);
+        return DECIMAL_TYPE_NAMES.contains(baseType);
+    }
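
Note: for reference, the normalization these two helpers implement (illustrative inputs/outputs, not executed tests):

    // extractBaseTypeName: strip parenthesized precision, upper-case, trim
    //   "decimal(38,10)" -> "DECIMAL"    "TIMESTAMP(6)" -> "TIMESTAMP"    " time " -> "TIME"
    // isDecimalType then matches the base name against DECIMAL / NUMERIC / NUMBER / DEC.

+    /**
+     * Add precision information to time type (if TYPE_NAME doesn't contain precision but DECIMAL_DIGITS has value).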
+     *
+     * @param columnType JDBC type name
+     * @param scaleSupplier Precision supplier (DECIMAL_DIGITS)
+     * @param columnName Column name (for logging)
+     * @return Type name with precision (if applicable)
+     */
+    private String addPrecisionToTypeName(String columnType,
+                                          java.util.function.Supplier<Integer> scaleSupplier,
+                                          String columnName) {
+        // If type name already contains precision information, return directly
+        if (columnType != null && columnType.contains("(")) {
+            return columnType;
+        }
+
+        // Extract base type name (remove possible precision information)
+        String baseType = extractBaseTypeName(columnType);
+
+        // Check if it's a time type that requires precision information
+        if (!isTimeType(baseType)) {
+            return columnType;
+        }
+
+        // Try to get precision from DECIMAL_DIGITS
+        int decimalDigits = scaleSupplier.get();
+        if (decimalDigits >= 0) {
+            String typeNameWithPrecision = baseType + "(" + decimalDigits + ")";
+            log.debug("Constructed type name with precision: {} -> {} for column {}",
+                    columnType, typeNameWithPrecision, columnName);
+            return typeNameWithPrecision;
+        }
+
+        return columnType;
+    }
+
+    /**
+     * Check if it's a time type that requires precision information.
+     *
+     * @param baseType Base type name (uppercase, precision information removed)
+     * @return Whether it's a time type
+     */
+    private boolean isTimeType(String baseType) {
+        return TIME_TYPE_NAMES.contains(baseType);
+    }

     private void loadLazyConfig(Throwable throwable) {
-        tickerWrapperMapRwLock.writeLock().lock();
+        ticketWrapperMapRwLock.writeLock().lock();
         try {
             if (ticketWrapperMap.isEmpty()) {
                 return;
             }
@@ -173,8 +471,8 @@ private void loadLazyConfig(Throwable throwable) {
                 throw new IllegalArgumentException("#getTaskConfigs is empty");
             }

-            log.info("config list size: {}", taskConfigs.size());
-            log.info("ticketWrapperMap size: {}", ticketWrapperMap.size());
+            log.debug("Loading lazy config: taskConfigs size={}, ticketWrapperMap size={}",
+                    taskConfigs.size(), ticketWrapperMap.size());

             int index = 0;
             ParamWrapper paramWrapper;
@@ -184,10 +482,10 @@
                 if (index < taskConfigs.size()) {
                     TaskConfig taskConfig = taskConfigs.get(index);
                     taskConfig.setError(throwable);
-                    log.info("Load lazy taskConfig: {}", JsonUtils.toString(taskConfig));
+                    log.debug("Loading lazy taskConfig at index {}", index);
                     paramWrapper.setParamIfAbsent(taskConfig);
                 } else {
-                    log.info("Set the remaining ticketWrapperMap to a default TaskConfig that doesn't read data. 
index: {},", index); + log.debug("Setting default TaskConfig for remaining ticket at index {}", index); TaskConfig taskConfig = new TaskConfig(this, 0); taskConfig.setError(throwable); paramWrapper.setParamIfAbsent(taskConfig); @@ -195,7 +493,7 @@ private void loadLazyConfig(Throwable throwable) { index++; } } finally { - tickerWrapperMapRwLock.writeLock().unlock(); + ticketWrapperMapRwLock.writeLock().unlock(); } } diff --git a/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/reader/DatabaseDoGetTaskContext.java b/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/reader/DatabaseDoGetTaskContext.java index 718cc3d..3a016b8 100644 --- a/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/reader/DatabaseDoGetTaskContext.java +++ b/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/reader/DatabaseDoGetTaskContext.java @@ -58,7 +58,7 @@ public void start() { sender.putOver(); hasNext.set(false); } catch (InterruptedException e) { - log.error("read interrupted", e); + log.warn("Database read task interrupted for table {}", taskConfig.getContext().getTableName(), e); Thread.currentThread().interrupt(); } }); @@ -66,7 +66,7 @@ public void start() { public void cancel() { if (readFuture != null && !readFuture.isDone()) { - log.info("cancel read task..."); + log.debug("Cancelling read task for table {}", taskConfig.getContext().getTableName()); readFuture.cancel(true); } } diff --git a/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/reader/DatabaseReader.java b/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/reader/DatabaseReader.java index 3821c59..ed941a6 100644 --- a/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/reader/DatabaseReader.java +++ b/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/reader/DatabaseReader.java @@ -40,15 +40,15 @@ public DatabaseReader(BufferAllocator allocator, TaskConfig taskConfig) { @Override public boolean loadNextBatch() throws IOException { Optional.ofNullable(taskConfig.getError()).ifPresent(e -> { - log.error("TaskConfig is happened error: {}", e.getMessage()); - throw new RuntimeException(e); + log.error("TaskConfig has an error: {}", e.getMessage(), e); + throw new RuntimeException("TaskConfig error: " + e.getMessage(), e); }); - if(dbDoGetTaskContext == null) { + if (dbDoGetTaskContext == null) { prepare(); } - if(dbDoGetTaskContext.hasNext()) { + if (dbDoGetTaskContext.hasNext()) { dbDoGetTaskContext.putNextPatchData(); return true; } diff --git a/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/reader/DatabaseRecordSender.java b/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/reader/DatabaseRecordSender.java index eefa7dc..4793d51 100644 --- a/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/reader/DatabaseRecordSender.java +++ b/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/reader/DatabaseRecordSender.java @@ -20,16 +20,18 @@ import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorSchemaRoot; import 
org.apache.arrow.vector.types.pojo.ArrowType; -import org.secretflow.dataproxy.plugin.database.utils.Record; import org.secretflow.dataproxy.core.converter.*; import org.secretflow.dataproxy.core.reader.AbstractSender; import org.secretflow.dataproxy.core.visitor.*; +import org.secretflow.dataproxy.plugin.database.utils.Record; import javax.annotation.Nonnull; import java.sql.DatabaseMetaData; import java.sql.ResultSet; -import java.sql.SQLException; -import java.util.*; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; import java.util.concurrent.LinkedBlockingQueue; @Slf4j @@ -41,6 +43,7 @@ public class DatabaseRecordSender extends AbstractSender { private final String tableName; private final DatabaseMetaData metaData; + static { SmallIntVectorConverter smallIntVectorConverter = new SmallIntVectorConverter(new ShortValueVisitor(), null); TinyIntVectorConverter tinyIntVectorConverter = new TinyIntVectorConverter(new ByteValueVisitor(), smallIntVectorConverter); @@ -52,12 +55,22 @@ public class DatabaseRecordSender extends AbstractSender { ARROW_TYPE_ID_FIELD_CONSUMER_MAP.put(ArrowType.ArrowTypeID.Int, intVectorConverter); ARROW_TYPE_ID_FIELD_CONSUMER_MAP.put(ArrowType.ArrowTypeID.Utf8, new VarCharVectorConverter(new ByteArrayValueVisitor())); + ARROW_TYPE_ID_FIELD_CONSUMER_MAP.put(ArrowType.ArrowTypeID.LargeUtf8, new LargeUtf8VectorConverter(new ByteArrayValueVisitor())); + ARROW_TYPE_ID_FIELD_CONSUMER_MAP.put(ArrowType.ArrowTypeID.Binary, new BinaryVectorConverter(new ByteArrayValueVisitor())); + ARROW_TYPE_ID_FIELD_CONSUMER_MAP.put(ArrowType.ArrowTypeID.LargeBinary, new LargeBinaryVectorConverter(new ByteArrayValueVisitor())); ARROW_TYPE_ID_FIELD_CONSUMER_MAP.put(ArrowType.ArrowTypeID.FloatingPoint, float8VectorConverter); ARROW_TYPE_ID_FIELD_CONSUMER_MAP.put(ArrowType.ArrowTypeID.Bool, new BitVectorConverter(new BooleanValueVisitor())); ARROW_TYPE_ID_FIELD_CONSUMER_MAP.put(ArrowType.ArrowTypeID.Date, new DateDayVectorConverter(new IntegerValueVisitor(), dateMilliVectorConverter)); - ARROW_TYPE_ID_FIELD_CONSUMER_MAP.put(ArrowType.ArrowTypeID.Time, new TimeMilliVectorConvertor(new IntegerValueVisitor(), null)); - ARROW_TYPE_ID_FIELD_CONSUMER_MAP.put(ArrowType.ArrowTypeID.Timestamp, new TimeStampNanoVectorConverter(new LongValueVisitor())); + // Chain TimeMicroVectorConvertor after TimeMilliVectorConvertor to support both Time32 and Time64 + TimeMicroVectorConvertor timeMicroVectorConvertor = new TimeMicroVectorConvertor(new LongValueVisitor(), null); + ARROW_TYPE_ID_FIELD_CONSUMER_MAP.put(ArrowType.ArrowTypeID.Time, new TimeMilliVectorConvertor(new IntegerValueVisitor(), timeMicroVectorConvertor)); + // Chain TimeStampMicroVectorConverter after TimeStampMilliVectorConverter to support both millisecond and microsecond precision + TimeStampMicroVectorConverter timeStampMicroVectorConverter = new TimeStampMicroVectorConverter(new LongValueVisitor(), null); + ARROW_TYPE_ID_FIELD_CONSUMER_MAP.put(ArrowType.ArrowTypeID.Timestamp, new TimeStampMilliVectorConverter(new LongValueVisitor(), timeStampMicroVectorConverter)); + ARROW_TYPE_ID_FIELD_CONSUMER_MAP.put(ArrowType.ArrowTypeID.Decimal, new Decimal128VectorConverter(new BigDecimalValueVisitor())); + ARROW_TYPE_ID_FIELD_CONSUMER_MAP.put(ArrowType.ArrowTypeID.Interval, new IntervalVectorConverter(new ObjectValueVisitor())); } + /** * Constructor * @@ -75,41 +88,32 @@ public DatabaseRecordSender(int estimatedRecordCount, LinkedBlockingQueue filedVectorOpt; - FieldVector vector; - 
ArrowType.ArrowTypeID arrowTypeID;
-
-            Object recordColumnValue;
-
-            ResultSet columns = metaData.getColumns(null, null, tableName, null);
-
-            while (columns.next()) {
-                String name = columns.getString("COLUMN_NAME");
-
-                filedVectorOpt = Optional.ofNullable(this.fieldVectorMap.get(name));
-
-                if (filedVectorOpt.isPresent()) {
-                    vector = filedVectorOpt.get();
-                    recordColumnValue = record.get(name);
-                    arrowTypeID = vector.getField().getType().getTypeID();
-                    if (Objects.isNull(recordColumnValue)) {
-                        vector.setNull(takeRecordCount);
-                        continue;
-                    }
-                    ValueConversionStrategy converter = ARROW_TYPE_ID_FIELD_CONSUMER_MAP.get(arrowTypeID);
-                    if (converter != null) {
-                        converter.convertAndSet(vector, takeRecordCount, recordColumnValue);
-                    } else {
-                        log.warn("No converter found for ArrowTypeID: {} (column: {})", arrowTypeID, name);
-                    }
-
-                }
+        this.initRecordColumn2FieldMap();
+
+        // Directly iterate Record's column data, Record column names match Schema column names
+        Map<String, Object> data = record.getData();
+        for (Map.Entry<String, Object> entry : data.entrySet()) {
+            String columnName = entry.getKey();
+            Object recordColumnValue = entry.getValue();
+
+            FieldVector vector = this.fieldVectorMap.get(columnName);
+            if (vector == null) {
+                log.warn("Column {} not found in fieldVectorMap", columnName);
+                continue;
+            }
+
+            if (Objects.isNull(recordColumnValue)) {
+                vector.setNull(takeRecordCount);
+                continue;
+            }
+
+            ArrowType.ArrowTypeID arrowTypeID = vector.getField().getType().getTypeID();
+            ValueConversionStrategy converter = ARROW_TYPE_ID_FIELD_CONSUMER_MAP.get(arrowTypeID);
+            if (converter != null) {
+                converter.convertAndSet(vector, takeRecordCount, recordColumnValue);
+            } else {
+                log.warn("No converter found for ArrowTypeID: {} (column: {})", arrowTypeID, columnName);
             }
-            columns.close();
-        } catch (SQLException e) {
-            throw new RuntimeException(e);
         }
     }

@@ -120,14 +124,16 @@ protected boolean isOver(Record record) {

     @Override
     public void putOver() throws InterruptedException {
-        this.put(new Record());
+        Record lastRecord = new Record();
+        lastRecord.setLast(true);
+        this.put(lastRecord);
     }

     public boolean equalsIgnoreCase(String s1, String s2) {
         return s1 == null ? s2 == null : s1.equalsIgnoreCase(s2);
     }
-    private synchronized void initRecordColumn2FieldMap(DatabaseMetaData metaData, String tableName) throws SQLException {
+    private synchronized void initRecordColumn2FieldMap() {
         if (isInit) {
             return;
         }
@@ -137,25 +143,16 @@ private synchronized void initRecordColumn2FieldMap(DatabaseMetaData metaData, String tableName) throws SQLException {
         if (Objects.isNull(root)) {
             return;
         }
-        List<FieldVector> fieldVectors = root.getFieldVectors();
-
-        ResultSet columns = metaData.getColumns(null, null, tableName, null);
-
-        Optional<FieldVector> first;
-
-        while (columns.next()) {
-            String name = columns.getString("COLUMN_NAME");
-            first = fieldVectors.stream()
-                    .filter(fieldVector -> equalsIgnoreCase(fieldVector.getName(), name))
-                    .findFirst();
-            if (first.isPresent()) {
-                fieldVectorMap.put(name, first.get());
-            } else {
-                log.debug("columnName: {} not in fieldVectors", name);
-            }
+        // Directly build column name to FieldVector mapping from VectorSchemaRoot
+        // Schema already contains correct column information (from ResultSetMetaData or DatabaseMetaData)
+        List<FieldVector> fieldVectors = root.getFieldVectors();
+        for (FieldVector fieldVector : fieldVectors) {
+            String columnName = fieldVector.getName();
+            fieldVectorMap.put(columnName, fieldVector);
+            log.trace("Mapped column: {} to FieldVector", columnName);
         }
-        columns.close();
+
+        isInit = true;
     }
 }
diff --git a/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/utils/Record.java b/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/utils/Record.java
index d286a5c..45033ca 100644
--- a/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/utils/Record.java
+++ b/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/utils/Record.java
@@ -17,13 +17,18 @@
 package org.secretflow.dataproxy.plugin.database.utils;

 import lombok.Setter;
+import lombok.extern.slf4j.Slf4j;

-import java.sql.ResultSet;
-import java.sql.ResultSetMetaData;
-import java.sql.SQLException;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.Reader;
+import java.io.StringWriter;
+import java.sql.*;
 import java.util.*;

 @Setter
+@Slf4j
 public class Record {
     // isLast means that this record is the last record
     private boolean isLast = false;
@@ -62,9 +67,126 @@ private void fromResultSet(ResultSet rs) throws SQLException {
         int columnCount = metaData.getColumnCount();
         for (int i = 1; i <= columnCount; i++) {
             String columnName = metaData.getColumnName(i);
-            Object columnValue = rs.getObject(i);
+            int columnType = metaData.getColumnType(i);
+            String columnTypeName = metaData.getColumnTypeName(i);
+
+            Object columnValue;
+            try {
+                columnValue = switch (columnType) {
+                    // CLOB type - use stream reading for better performance and support for large CLOB
+                    case Types.CLOB, Types.NCLOB -> {
+                        Clob clob = rs.getClob(i);
+                        if (clob == null) {
+                            yield null;
+                        }
+                        try {
+                            // Check CLOB size, throw exception if it exceeds the String maximum length
+                            long clobLength = clob.length();
+                            if (clobLength > Integer.MAX_VALUE) {
+                                clob.free();
+                                throw new SQLException(String.format(
+                                        "CLOB size %d exceeds String maximum length (%d) for column %s. 
" + + "Cannot read CLOB larger than 2GB into a String.", + clobLength, Integer.MAX_VALUE, columnName)); + } + // Use stream reading to avoid loading entire content into memory + try (Reader reader = clob.getCharacterStream()) { + String clobContent = readCharacterStreamAsString(reader); + clob.free(); + yield clobContent; + } + } catch (SQLException e) { + throw e; + } catch (Exception e) { + log.error("Failed to read CLOB stream for column {}, will try getString(): {}", columnName, e.getMessage()); + clob.free(); + yield rs.getString(i); + } + } + + // LONGVARBINARY type - large binary data, try getBytes() first, fallback to stream reading + case Types.LONGVARBINARY -> { + try { + byte[] bytes = rs.getBytes(i); + if (bytes != null) { + yield bytes; + } + // If getBytes() returns null, try stream reading + InputStream inputStream = rs.getBinaryStream(i); + if (inputStream == null) { + yield null; + } + try (inputStream) { + yield readBinaryStreamAsBytes(inputStream); + } + } catch (Exception e) { + log.error("Failed to read LONGVARBINARY for column {}: {}", columnName, e.getMessage()); + yield null; + } + } + + // BLOB type - use stream reading for better performance and support for large BLOB + case Types.BLOB -> { + Blob blob = rs.getBlob(i); + if (blob == null) { + yield null; + } + try { + // Check BLOB size, throw exception if exceeds byte[] maximum length + long blobLength = blob.length(); + if (blobLength > Integer.MAX_VALUE) { + blob.free(); + throw new SQLException(String.format( + "BLOB size %d exceeds byte array maximum length (%d) for column %s. " + + "Cannot read BLOB larger than 2GB into a byte array.", + blobLength, Integer.MAX_VALUE, columnName)); + } + // Use stream reading to avoid loading entire content into memory + try (InputStream inputStream = blob.getBinaryStream()) { + byte[] blobContent = readBinaryStreamAsBytes(inputStream); + blob.free(); + yield blobContent; + } + } catch (SQLException e) { + throw e; + } catch (Exception e) { + log.error("Failed to read BLOB stream for column {}, will try getBytes(): {}", columnName, e.getMessage()); + blob.free(); + yield rs.getBytes(i); + } + } + + // LONGVARCHAR type - may need stream reading + case Types.LONGVARCHAR, Types.LONGNVARCHAR -> { + String strValue = rs.getString(i); + if (strValue != null) { + yield strValue; + } + + // If getString() returns null, try stream reading + Reader reader = rs.getCharacterStream(i); + if (reader == null) { + yield null; + } + try (reader) { + yield readCharacterStreamAsString(reader); + } catch (Exception e) { + log.error("Failed to read LONGVARCHAR stream for column {}: {}", columnName, e.getMessage()); + yield null; + } + } + + default -> rs.getObject(i); + }; + } catch (SQLException e) { + log.error("Failed to read column {} (type: {}, typeName: {}): {}", + columnName, columnType, columnTypeName, e.getMessage()); + columnValue = rs.getObject(i); + } + this.set(columnName, columnValue); - this.setColumnType(i, metaData.getColumnType(i)); + this.setColumnType(i, columnType); } } @@ -79,5 +201,47 @@ public Map getData() { } return temp; } + + /** + * Read content from character stream and convert to string. + * + * Note: This method is limited by String maximum length (Integer.MAX_VALUE, approximately 2GB). + * If character stream content exceeds this limit, may throw OutOfMemoryError or String length exception. 
+     *
+     * @param reader Character stream reader
+     * @return Read string content
+     * @throws IOException Thrown when reading fails
+     */
+    private String readCharacterStreamAsString(Reader reader) throws IOException {
+        try (StringWriter writer = new StringWriter()) {
+            char[] buffer = new char[8192];
+            int length;
+            while ((length = reader.read(buffer)) != -1) {
+                writer.write(buffer, 0, length);
+            }
+            return writer.toString();
+        }
+    }
+
+    /**
+     * Read content from byte stream and convert to byte array.
+     *
+     * Note: This method is limited by the byte[] maximum length (Integer.MAX_VALUE, approximately 2GB).
+     * If the byte stream content exceeds this limit, it may throw OutOfMemoryError or an array length exception.
+     *
+     * @param inputStream Byte stream
+     * @return Read byte array
+     * @throws IOException Thrown when reading fails
+     */
+    private byte[] readBinaryStreamAsBytes(InputStream inputStream) throws IOException {
+        try (ByteArrayOutputStream outputStream = new ByteArrayOutputStream()) {
+            byte[] buffer = new byte[8192];
+            int length;
+            while ((length = inputStream.read(buffer)) != -1) {
+                outputStream.write(buffer, 0, length);
+            }
+            return outputStream.toByteArray();
+        }
+    }
 }
diff --git a/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/writer/DatabaseRecordWriter.java b/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/writer/DatabaseRecordWriter.java
index 5976a1a..7bb6730 100644
--- a/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/writer/DatabaseRecordWriter.java
+++ b/dataproxy-plugins/dataproxy-plugin-database/src/main/java/org/secretflow/dataproxy/plugin/database/writer/DatabaseRecordWriter.java
@@ -19,16 +19,24 @@
 import lombok.extern.slf4j.Slf4j;
 import org.apache.arrow.vector.*;
 import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.Schema;
-import org.secretflow.dataproxy.plugin.database.config.DatabaseCommandConfig;
 import org.secretflow.dataproxy.core.writer.Writer;
+import org.secretflow.dataproxy.plugin.database.config.DatabaseCommandConfig;
 import org.secretflow.dataproxy.plugin.database.config.DatabaseConnectConfig;
 import org.secretflow.dataproxy.plugin.database.config.DatabaseTableConfig;
 import org.secretflow.dataproxy.plugin.database.config.DatabaseWriteConfig;
 import org.secretflow.dataproxy.plugin.database.utils.Record;

+import java.math.BigDecimal;
 import java.nio.charset.StandardCharsets;
-import java.sql.*;
+import java.sql.Connection;
+import java.sql.PreparedStatement;
+import java.sql.SQLException;
+import java.sql.Statement;
+import java.time.LocalDate;
+import java.time.LocalDateTime;
+import java.time.LocalTime;
 import java.util.ArrayList;
 import java.util.LinkedHashMap;
 import java.util.List;
@@ -67,10 +75,49 @@ public interface BuildMultiInsertSqlFunc {
     private Connection connection;
     private final boolean supportMultiInsert;

-    // TODO: parse partition
-    // Partitioning is not supported in the current version
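
Note: the TODO is now implemented below; a quick usage sketch of the parser's contract (values illustrative):

    Map<String, String> p = DatabaseRecordWriter.parsePartition("dt=20240101,region=us");
    // -> {dt=20240101, region=us}; blank or malformed pieces ("x=", "=y") are skipped with a warning

+    /**
+     * Parse partition string to map
+     * Format: "key1=value1,key2=value2" or "key=value"
+     * Example: "dt=20240101" or "dt=20240101,region=us"
+     *
+     * @param partition Partition string (e.g., "dt=20240101")
+     * @return Map of partition key-value pairs
+     */
     public static Map<String, String> parsePartition(String partition) {
-        return new LinkedHashMap<>();
+        Map<String, String> partitionMap = new LinkedHashMap<>();
+        if (partition == null ||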
partition.trim().isEmpty()) { + return partitionMap; + } + + String trimmed = partition.trim(); + // Split by comma for multiple partitions + String[] parts = trimmed.split(","); + + for (String part : parts) { + part = part.trim(); + if (part.isEmpty()) { + continue; + } + + // Split by '=' to get key and value + int equalIndex = part.indexOf('='); + if (equalIndex <= 0 || equalIndex >= part.length() - 1) { + log.warn("Invalid partition format: {}, expected format: key=value", part); + continue; + } + + String key = part.substring(0, equalIndex).trim(); + String value = part.substring(equalIndex + 1).trim(); + + if (key.isEmpty() || value.isEmpty()) { + log.warn("Invalid partition format: {}, key or value is empty", part); + continue; + } + + partitionMap.put(key, value); + } + + return partitionMap; } public DatabaseRecordWriter(DatabaseWriteConfig commandConfig, @@ -116,12 +163,43 @@ private Connection initDatabaseClient(DatabaseConnectConfig dbConnectConfig) { return this.initFunc.apply(dbConnectConfig); } - private void prepare(){ - - connection = initDatabaseClient(dbConnectConfig); + private void prepare() { + Connection localConnection = null; + try { + localConnection = initDatabaseClient(dbConnectConfig); + this.connection = localConnection; - preProcessing(dbTableConfig.tableName()); + // Set autoCommit based on insert mode + if (supportMultiInsert) { + // Batch insert mode: use manual transaction, batch commit (better performance) + try { + connection.setAutoCommit(false); + log.info("Batch insert mode: autoCommit=false, will commit every {} rows", BATCH_NUM); + } catch (SQLException e) { + log.warn("Failed to set autoCommit=false, using default", e); + } + } else { + // Single insert mode: use auto-commit (simple and direct, each insert committed immediately) + try { + connection.setAutoCommit(true); + log.info("Single insert mode: autoCommit=true, each insert auto-committed"); + } catch (SQLException e) { + log.warn("Failed to set autoCommit=true, using default", e); + } + } + preProcessing(dbTableConfig.tableName()); + } catch (Exception e) { + // Clean up created connection if prepare fails + if (localConnection != null) { + try { + localConnection.close(); + } catch (SQLException closeException) { + log.warn("Failed to close connection after prepare error", closeException); + } + } + throw e; + } } /** @@ -132,9 +210,14 @@ private void prepare(){ * @return value */ private Object getValue(FieldVector fieldVector, int index) { - if (fieldVector == null || index < 0 || fieldVector.getObject(index) == null) { + if (fieldVector == null || index < 0) { return null; } + + if (fieldVector.isNull(index)) { + return null; + } + ArrowType.ArrowTypeID arrowTypeID = fieldVector.getField().getType().getTypeID(); switch (arrowTypeID) { @@ -142,19 +225,204 @@ private Object getValue(FieldVector fieldVector, int index) { if (fieldVector instanceof IntVector || fieldVector instanceof BigIntVector || fieldVector instanceof SmallIntVector || fieldVector instanceof TinyIntVector) { return fieldVector.getObject(index); } - log.warn("Type INT is not IntVector or BigIntVector or SmallIntVector or TinyIntVector, value is: {}", fieldVector.getObject(index).toString()); + log.debug("Unexpected Int vector type: {}, using getObject() fallback", fieldVector.getClass().getSimpleName()); + return fieldVector.getObject(index); } case FloatingPoint -> { - if (fieldVector instanceof Float4Vector | fieldVector instanceof Float8Vector) { - return fieldVector.getObject(index); + if (fieldVector instanceof 
Float4Vector vector) { + return vector.get(index); + } else if (fieldVector instanceof Float8Vector vector) { + return vector.get(index); } - log.warn("Type FloatingPoint is not Float4Vector or Float8Vector, value is: {}", fieldVector.getObject(index).toString()); + log.debug("Unexpected FloatingPoint vector type: {}, using getObject() fallback", fieldVector.getClass().getSimpleName()); + return fieldVector.getObject(index); } case Utf8 -> { if (fieldVector instanceof VarCharVector vector) { return new String(vector.get(index), StandardCharsets.UTF_8); } - log.warn("Type Utf8 is not VarCharVector, value is: {}", fieldVector.getObject(index).toString()); + log.debug("Unexpected Utf8 vector type: {}, using getObject() fallback", fieldVector.getClass().getSimpleName()); + return fieldVector.getObject(index); + } + case LargeUtf8 -> { + if (fieldVector instanceof LargeVarCharVector vector) { + return vector.get(index); + } + log.debug("Unexpected LargeUtf8 vector type: {}, using getObject() fallback", fieldVector.getClass().getSimpleName()); + return fieldVector.getObject(index); + } + case Binary -> { + if (fieldVector instanceof VarBinaryVector vector) { + return vector.get(index); + } + log.debug("Unexpected Binary vector type: {}, using getObject() fallback", fieldVector.getClass().getSimpleName()); + return fieldVector.getObject(index); + } + case LargeBinary -> { + if (fieldVector instanceof LargeVarBinaryVector vector) { + return vector.get(index); + } + log.debug("Unexpected LargeBinary vector type: {}, using getObject() fallback", fieldVector.getClass().getSimpleName()); + return fieldVector.getObject(index); + } + case Decimal -> { + if (fieldVector instanceof DecimalVector vector) { + ArrowType.Decimal decimalType = (ArrowType.Decimal) fieldVector.getField().getType(); + BigDecimal value = vector.getObject(index); + return value; + } + log.debug("Unexpected Decimal vector type: {}, using getObject() fallback", fieldVector.getClass().getSimpleName()); + return fieldVector.getObject(index); + } + case Date -> { + if (fieldVector instanceof DateDayVector vector) { + // DateDayVector stores days since 1970-01-01 + // Convert to java.sql.Date (date only) + int days = vector.get(index); + LocalDate date = LocalDate.ofEpochDay(days); + return java.sql.Date.valueOf(date); + } else if (fieldVector instanceof DateMilliVector vector) { + // DateMilliVector stores milliseconds since 1970-01-01 00:00:00 UTC + // Convert to java.sql.Timestamp (date+time) as database fields are usually TIMESTAMP + long millis = vector.get(index); + return new java.sql.Timestamp(millis); + } + log.debug("Unexpected Date vector type: {}, using getObject() fallback", fieldVector.getClass().getSimpleName()); + return fieldVector.getObject(index); + } + case Time -> { + // Time vector getObject() may return LocalTime or LocalDateTime + // Use getObject() directly instead of manual calculation to avoid timezone and calculation errors + Object timeObj = fieldVector.getObject(index); + if (timeObj instanceof LocalTime localTime) { + return java.sql.Time.valueOf(localTime); + } else if (timeObj instanceof LocalDateTime localDateTime) { + // If LocalDateTime is returned, extract time part + LocalTime localTime = localDateTime.toLocalTime(); + return java.sql.Time.valueOf(localTime); + } else if (timeObj instanceof Integer intValue) { + // TimeSecVector may return Integer (seconds) + int seconds = intValue; + int hours = seconds / (60 * 60); + int minutes = (seconds % (60 * 60)) / 60; + int secs = seconds % 60; + LocalTime 
time = LocalTime.of(hours, minutes, secs); + return java.sql.Time.valueOf(time); + } else if (timeObj instanceof Long longValue) { + // TimeMilliVector/TimeMicroVector/TimeNanoVector may return Long + // Try to convert to LocalTime + if (fieldVector instanceof TimeMilliVector) { + // TimeMilliVector: milliseconds since midnight + int millis = longValue.intValue(); + if (millis < 0) { + log.debug("TimeMilliVector value is negative: {}, using 0", millis); + millis = 0; + } + int hours = millis / (1000 * 60 * 60); + int minutes = (millis % (1000 * 60 * 60)) / (1000 * 60); + int seconds = (millis % (1000 * 60)) / 1000; + int nanos = (millis % 1000) * 1000000; + if (hours >= 24) { + hours = hours % 24; + } + LocalTime time = LocalTime.of(hours, minutes, seconds, nanos); + return java.sql.Time.valueOf(time); + } else if (fieldVector instanceof TimeMicroVector) { + // TimeMicroVector: microseconds since midnight + long micros = longValue; + if (micros < 0) { + log.debug("TimeMicroVector value is negative: {}, using 0", micros); + micros = 0; + } + long totalSeconds = micros / 1_000_000; + long nanos = (micros % 1_000_000) * 1000; + int hours = (int) (totalSeconds / (60 * 60)); + int minutes = (int) ((totalSeconds % (60 * 60)) / 60); + int seconds = (int) (totalSeconds % 60); + if (hours >= 24) { + hours = hours % 24; + } + LocalTime time = LocalTime.of(hours, minutes, seconds, (int) nanos); + return java.sql.Time.valueOf(time); + } else if (fieldVector instanceof TimeNanoVector) { + // TimeNanoVector: nanoseconds since midnight + long nanos = longValue; + if (nanos < 0) { + log.debug("TimeNanoVector value is negative: {}, using 0", nanos); + nanos = 0; + } + long totalSeconds = nanos / 1_000_000_000; + int nanoOfSecond = (int) (nanos % 1_000_000_000); + int hours = (int) (totalSeconds / (60 * 60)); + int minutes = (int) ((totalSeconds % (60 * 60)) / 60); + int seconds = (int) (totalSeconds % 60); + if (hours >= 24) { + hours = hours % 24; + } + LocalTime time = LocalTime.of(hours, minutes, seconds, nanoOfSecond); + return java.sql.Time.valueOf(time); + } else { + log.debug("Time type conversion for Long value not fully supported, using default"); + return timeObj; + } + } + log.debug("Time conversion failed, value type: {}", + timeObj != null ? 
timeObj.getClass().getName() : "null"); + return timeObj; + } + case Timestamp -> { + // Important: TimeStampVector stores UTC time (microseconds/milliseconds since Unix epoch) + // Cannot use getObject() returned LocalDateTime, because Timestamp.valueOf(LocalDateTime) + // treats LocalDateTime as local time in system default timezone, causing timezone conversion errors + // Should directly use raw microseconds/milliseconds to create Timestamp + if (fieldVector instanceof TimeStampMilliVector vector) { + // TimeStampMilliVector stores milliseconds (since Unix epoch, UTC time) + long millis = vector.get(index); + return new java.sql.Timestamp(millis); + } else if (fieldVector instanceof TimeStampMicroVector vector) { + // TimeStampMicroVector stores microseconds (since Unix epoch, UTC time) + long micros = vector.get(index); + long millis = micros / 1000; + int microsPart = (int) (micros % 1000); // Microseconds part (0-999) + java.sql.Timestamp timestamp = new java.sql.Timestamp(millis); + // Timestamp constructor automatically sets nanoseconds for milliseconds part + // We need to preserve existing millisecond nanoseconds and add microsecond nanoseconds + // getNanos() returns nanoseconds part of seconds (0-999,999,999) + int existingNanos = timestamp.getNanos(); + int additionalNanos = microsPart * 1000; // Convert microseconds to nanoseconds + timestamp.setNanos(existingNanos + additionalNanos); + return timestamp; + } else if (fieldVector instanceof TimeStampNanoVector vector) { + // TimeStampNanoVector stores nanoseconds (since Unix epoch, UTC time) + long nanos = vector.get(index); + long millis = nanos / 1_000_000; + int nanosPart = (int) (nanos % 1_000_000); + java.sql.Timestamp timestamp = new java.sql.Timestamp(millis); + timestamp.setNanos((int) (timestamp.getNanos() / 1000000 * 1000000 + nanosPart)); + return timestamp; + } else if (fieldVector instanceof TimeStampSecVector vector) { + // TimeStampSecVector stores seconds (since Unix epoch, UTC time) + long seconds = vector.get(index); + long millis = seconds * 1000; + return new java.sql.Timestamp(millis); + } + log.debug("Unexpected Timestamp vector type: {}, using getObject() fallback", fieldVector.getClass().getSimpleName()); + return fieldVector.getObject(index); + } + case Interval -> { + // Interval type is complex, need to handle based on specific subtype + if (fieldVector instanceof IntervalYearVector vector) { + // YEAR_MONTH interval + return vector.getObject(index); + } else if (fieldVector instanceof IntervalDayVector vector) { + // DAY_TIME interval + return vector.getObject(index); + } else if (fieldVector instanceof IntervalMonthDayNanoVector vector) { + // MONTH_DAY_NANO interval + return vector.getObject(index); + } + log.debug("Unexpected Interval vector type: {}, using getObject() fallback", fieldVector.getClass().getSimpleName()); + return fieldVector.getObject(index); } case Null -> { return null; @@ -163,63 +431,143 @@ private Object getValue(FieldVector fieldVector, int index) { if (fieldVector instanceof BitVector vector) { return vector.get(index) == 1; } - log.warn("Type BOOL is not BitVector, value is: {}", fieldVector.getObject(index).toString()); + log.debug("Unexpected Bool vector type: {}, using getObject() fallback", fieldVector.getClass().getSimpleName()); + return fieldVector.getObject(index); } default -> { - log.warn("Not implemented type: {}, will use default function", arrowTypeID); + log.debug("Unsupported Arrow type: {}, using getObject() fallback", arrowTypeID); return 
fieldVector.getObject(index);
             }
         }
-        return null;
+    }
+
+    /**
+     * Build single record from VectorSchemaRoot.
+     *
+     * @param root VectorSchemaRoot
+     * @param rowIndex Row index
+     * @return Record data Map
+     */
+    private Map<String, Object> buildRecord(VectorSchemaRoot root, int rowIndex) {
+        Record record = new Record();
+        int columnCount = root.getFieldVectors().size();
+        for (int columnIndex = 0; columnIndex < columnCount; columnIndex++) {
+            String columnName = root.getVector(columnIndex).getField().getName().toLowerCase();
+            record.set(columnName, this.getValue(root.getFieldVectors().get(columnIndex), rowIndex));
+        }
+        return record.getData();
     }
 
     @Override
     public void write(VectorSchemaRoot root) {
         final int batchSize = root.getRowCount();
-        log.info("database writer batchSize: {}", batchSize);
-        int columnCount = root.getFieldVectors().size();
+        log.debug("Writing {} rows to table {}", batchSize, tableName);
 
-        String columnName;
+        if (supportMultiInsert) {
+            writeWithMultiInsert(root, batchSize);
+        } else {
+            writeWithSingleInsert(root, batchSize);
+        }
+    }
 
-        if(supportMultiInsert) {
-            List<Map<String, Object>> multiRecords = new ArrayList<>();
-            for(int rowIndex = 0; rowIndex < batchSize; rowIndex ++) {
-                Record record = new Record();
-                for(int columnIndex = 0; columnIndex < columnCount; columnIndex++) {
-                    log.info("column: {}, type: {}", columnIndex, root.getFieldVectors().get(columnIndex));
-                    columnName = root.getVector(columnIndex).getField().getName().toLowerCase();
-                    record.set(columnName, this.getValue(root.getFieldVectors().get(columnIndex), rowIndex));
-                }
-                multiRecords.add(record.getData());
-                if(multiRecords.size() == BATCH_NUM) {
-                    this.insertMultiData(commandConfig.getResultSchema(), multiRecords);
-                    multiRecords.clear();
-                }
+    /**
+     * Batch insert mode: accumulate records and batch insert.
+     */
+    private void writeWithMultiInsert(VectorSchemaRoot root, int batchSize) {
+        List<Map<String, Object>> multiRecords = new ArrayList<>();
+        for (int rowIndex = 0; rowIndex < batchSize; rowIndex++) {
+            multiRecords.add(buildRecord(root, rowIndex));
+            if (multiRecords.size() == BATCH_NUM) {
+                this.insertMultiData(commandConfig.getResultSchema(), multiRecords);
+                this.commitBatch();
+                multiRecords.clear();
             }
-        } else {
-            for(int rowIndex = 0; rowIndex < batchSize; rowIndex ++) {
-                Record record = new Record();
-                for(int columnIndex = 0; columnIndex < columnCount; columnIndex++) {
-                    log.info("column: {}, type: {}", columnIndex, root.getFieldVectors().get(columnIndex));
-                    columnName = root.getVector(columnIndex).getField().getName().toLowerCase();
-                    record.set(columnName, this.getValue(root.getFieldVectors().get(columnIndex), rowIndex));
-                }
-                this.insertData(commandConfig.getResultSchema(), record.getData());
+        }
+        if (!multiRecords.isEmpty()) {
+            this.insertMultiData(commandConfig.getResultSchema(), multiRecords);
+            this.commitBatch();
+        }
+    }
+
+    /**
+     * Single insert mode: insert row by row.
+     */
+    private void writeWithSingleInsert(VectorSchemaRoot root, int batchSize) {
+        boolean autoCommit = true;
+        try {
+            autoCommit = connection.getAutoCommit();
+        } catch (SQLException e) {
+            log.warn("Failed to get autoCommit status, assuming true", e);
+        }
+
+        for (int rowIndex = 0; rowIndex < batchSize; rowIndex++) {
+            Map<String, Object> record = buildRecord(root, rowIndex);
+            this.insertData(commandConfig.getResultSchema(), record);
+
+            if (!autoCommit && (rowIndex + 1) % BATCH_NUM == 0) {
+                this.commitBatch();
             }
         }
+        if (!autoCommit && batchSize % BATCH_NUM != 0) {
+            this.commitBatch();
+        }
     }
 
+    /**
+     * Commit transaction (if manual transaction is enabled).
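+     * <p>Illustrative caller-side flow, as a sketch (the {@code writer} variable is invented for illustration):
+     * <pre>{@code
+     * writer.write(root);  // batch mode commits every BATCH_NUM rows via commitBatch()
+     * writer.flush();      // commits any rows still pending in the manual transaction
+     * }</pre>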
+ * This is the core method that actually performs the commit. + */ + private void commitTransaction() { + try { + if (connection != null && !connection.getAutoCommit()) { + connection.commit(); + log.debug("Transaction committed"); + } + } catch (SQLException e) { + log.error("Failed to commit transaction", e); + throw new RuntimeException("Failed to commit transaction", e); + } + } + + /** + * Commit transaction per batch (internal use). + * Automatically called during batch insert, commits every 500 rows. + */ + private void commitBatch() { + commitTransaction(); + } + + /** + * Flush/commit transaction (public interface method). + * Implements Writer interface, allows external explicit call to flush uncommitted data. + */ @Override public void flush() { - // do nothing + commitTransaction(); } public void close() { - try{ - connection.close(); + if (connection == null) { + return; + } + + try { + // Rollback uncommitted transaction first if exists + if (!connection.isClosed() && !connection.getAutoCommit()) { + try { + connection.rollback(); + log.debug("Rolled back uncommitted transaction on close"); + } catch (SQLException e) { + log.warn("Failed to rollback transaction on close", e); + } + } + + if (!connection.isClosed()) { + connection.close(); + } } catch (SQLException e) { - log.error("database connection close error"); - throw new RuntimeException(e); + log.error("Database connection close error", e); + throw new RuntimeException("Failed to close database connection", e); } } @@ -227,11 +575,11 @@ private void createTable(Schema schema){ String createTableSql = this.buildCreateTableSql.apply(tableName, schema, partitionSpec); try (Statement stmt = connection.createStatement()){ stmt.executeUpdate(createTableSql); + log.info("Successfully created table: {}", tableName); } catch (SQLException e) { - log.error("create table sql:{} error: {}", createTableSql, e.getMessage()); - throw new RuntimeException(e); + log.error("Failed to create table {}: {}", tableName, e.getMessage(), e); + throw new RuntimeException("Failed to create table " + tableName + ": " + e.getMessage(), e); } - } private void validateTableName(String tableName) { @@ -246,28 +594,70 @@ private void validateTableName(String tableName) { private void dropTable() throws SQLException { validateTableName(tableName); + if (checkTableExists.apply(connection, tableName)) { + String sql = "DROP TABLE " + tableName; + try (Statement stmt = connection.createStatement()) { + stmt.execute(sql); + log.info("Successfully dropped table: {}", tableName); + } catch (SQLException e) { + log.error("Failed to drop table {}: {}", tableName, e.getMessage(), e); + throw e; + } + } + } - String sql = "DROP TABLE IF EXISTS " + tableName; + private void deleteAllRowOfTable() throws SQLException { + validateTableName(tableName); + String sql = "DELETE FROM " + tableName; + try (PreparedStatement preparedStatement = connection.prepareStatement(sql)) { - preparedStatement.execute(); - log.info("Table {} dropped successfully.", tableName); + int deletedRows = preparedStatement.executeUpdate(); + log.info("Deleted {} rows from table {}", deletedRows, tableName); } catch (SQLException e) { - log.error("Failed to drop table {}: {}", tableName, e.getMessage()); + log.error("Failed to delete data from table {}: {}", tableName, e.getMessage(), e); throw e; } } - private void deleteAllRowOfTable() throws SQLException { + /** + * Delete rows for specific partition (partition overwrite) + * + * @param partitionSpec Partition specification map (e.g., 
{"dt": "20240101"}) + * @throws SQLException If deletion fails + */ + private void deletePartitionData(Map partitionSpec) throws SQLException { validateTableName(tableName); - - String sql = "DELETE FROM " + tableName; - + if (partitionSpec == null || partitionSpec.isEmpty()) { + log.warn("Partition spec is empty, skipping partition deletion"); + return; + } + + // Build WHERE clause for partition deletion + List conditions = new ArrayList<>(); + for (Map.Entry entry : partitionSpec.entrySet()) { + String key = entry.getKey(); + String value = entry.getValue(); + // Validate partition key name (basic validation) + if (!key.matches("^[a-zA-Z][a-zA-Z0-9_]*$")) { + throw new IllegalArgumentException("Invalid partition key: " + key); + } + conditions.add(key + " = ?"); + } + + String whereClause = String.join(" AND ", conditions); + String sql = "DELETE FROM " + tableName + " WHERE " + whereClause; + try (PreparedStatement preparedStatement = connection.prepareStatement(sql)) { - int rowsDeleted = preparedStatement.executeUpdate(); - log.info("Number of rows deleted: {}", rowsDeleted); + int paramIndex = 1; + for (Map.Entry entry : partitionSpec.entrySet()) { + preparedStatement.setString(paramIndex++, entry.getValue()); + } + int deletedRows = preparedStatement.executeUpdate(); + log.info("Deleted {} rows from partition {} in table {}", deletedRows, partitionSpec, tableName); } catch (SQLException e) { - log.info("Failed to delete data from table {} : {}", tableName, e.getMessage()); + log.error("Failed to delete partition data from table {}: partition={}, error: {}", + tableName, partitionSpec, e.getMessage()); throw e; } } @@ -277,45 +667,86 @@ public void insertData(Schema arrowSchema, Map data) { try (PreparedStatement stmt = connection.prepareStatement(sql)){ stmt.executeUpdate(); } catch (SQLException e) { - log.error("insert data error: sql:\"{}\" error:\"{}\"", sql, e.getMessage()); - throw new RuntimeException(e); + log.error("Failed to insert data into table {}: {}", tableName, e.getMessage(), e); + throw new RuntimeException("Failed to insert data into table " + tableName + ": " + e.getMessage(), e); } } public void insertMultiData(Schema arrowSchema, List> multiData){ - SqlWithParams sp = this.buildMultiInsertSql.apply(tableName, arrowSchema, multiData, partitionSpec); + + // Build parameter type mapping: determine each parameter's type based on Arrow Schema + // Note: INTERVAL types skip parameter list (directly embedded in SQL), so parameter index and column index may not match + List paramTypes = new ArrayList<>(); + for (Field field : arrowSchema.getFields()) { + ArrowType fieldType = field.getType(); + if (!(fieldType instanceof ArrowType.Interval)) { + paramTypes.add(fieldType); + } + } + try (PreparedStatement ps = connection.prepareStatement(sp.sql);){ for (int i = 0; i < sp.params.size(); i++) { - ps.setObject(i + 1, sp.params.get(i)); + Object param = sp.params.get(i); + // For multi-row insert, parameter list contains all rows' parameters, use modulo to get corresponding type + ArrowType paramType = paramTypes.isEmpty() ? 
null : paramTypes.get(i % paramTypes.size());
+
+                if (param instanceof Float floatValue) {
+                    ps.setFloat(i + 1, floatValue);
+                } else if (param instanceof Double doubleValue) {
+                    ps.setDouble(i + 1, doubleValue);
+                } else if (param instanceof BigDecimal decimalValue && paramType instanceof ArrowType.Decimal) {
+                    // DECIMAL type: use setBigDecimal() to ensure precision and scale are correctly passed
+                    ps.setBigDecimal(i + 1, decimalValue);
+                } else if (paramType instanceof ArrowType.LargeUtf8 && param instanceof byte[] bytes) {
+                    // CLOB type (LargeUtf8): must convert byte[] to String (UTF-8 decode)
+                    ps.setObject(i + 1, new String(bytes, StandardCharsets.UTF_8));
+                } else {
+                    // The else chain matters: an unconditional setObject() here would overwrite the typed setters above
+                    ps.setObject(i + 1, param);
+                }
             }
-            ps.executeUpdate();
+            int rowsAffected = ps.executeUpdate();
+            log.debug("Inserted {} rows into table {} ({} rows affected)", multiData.size(), tableName, rowsAffected);
         } catch (SQLException e) {
-            log.error("insert data error: sql:\"{}\" error:\"{}\"", sp, e.getMessage());
-            throw new RuntimeException(e);
+            log.error("Failed to insert {} rows into table {}: {}", multiData.size(), tableName, e.getMessage(), e);
+            throw new RuntimeException("Failed to insert data into table " + tableName + ": " + e.getMessage(), e);
         }
     }
 
-
-    // create table when the table not exist
     private void preProcessing(String tableName){
-        if(checkTableExists.apply(connection, tableName)) {
-            log.info("database table is exists, table name: {}", tableName);
-            log.info("trying dropping table {}", tableName);
+        boolean tableExists = checkTableExists.apply(connection, tableName);
+
+        if (tableExists) {
+            // If a partition spec exists, delete only that partition's data (partition overwrite)
+            if (partitionSpec != null && !partitionSpec.isEmpty()) {
+                try {
+                    this.deletePartitionData(partitionSpec);
+                    return; // Partition overwrite completed, no need to create table
+                } catch (SQLException e) {
+                    log.error("Failed to delete partition data from table {}: partition={}, error: {}",
+                            tableName, partitionSpec, e.getMessage(), e);
+                    throw new RuntimeException("Failed to delete partition data from table " + tableName
+                            + ", partition=" + partitionSpec + ": " + e.getMessage(), e);
+                }
+            }
+
+            // No partition spec: drop the whole table (full table overwrite)
             try {
                 this.dropTable();
             } catch (SQLException e) {
-                try {
-                    this.deleteAllRowOfTable();
-                } catch (SQLException ex) {
-                    throw new RuntimeException(ex);
-                }
+                log.error("Failed to drop table {}: {}", tableName, e.getMessage(), e);
             }
-
-        } else {
-            log.info("table {} no exists", tableName);
         }
-        createTable(commandConfig.getResultSchema());
+
+        if (!checkTableExists.apply(connection, tableName)) {
+            createTable(commandConfig.getResultSchema());
+        } else if (partitionSpec == null || partitionSpec.isEmpty()) {
+            throw new RuntimeException("Cannot create table " + tableName + ": table still exists after drop attempt");
+        }
     }
 
     public static class SqlWithParams {
@@ -327,6 +758,14 @@ public SqlWithParams(String sql, List<Object> params) {
             this.sql = sql;
             this.params = params;
         }
+
+        @Override
+        public String toString() {
+            return "SqlWithParams{" +
+                    "sql='" + (sql.length() > 500 ? sql.substring(0, 500) + "..."
: sql) + '\'' + + ", params=" + params.size() + " parameters" + + '}'; + } } } diff --git a/dataproxy-plugins/dataproxy-plugin-database/src/test/java/org/secretflow/dataproxy/plugin/database/reader/DatabaseDoGetContextTest.java b/dataproxy-plugins/dataproxy-plugin-database/src/test/java/org/secretflow/dataproxy/plugin/database/reader/DatabaseDoGetContextTest.java new file mode 100644 index 0000000..b390855 --- /dev/null +++ b/dataproxy-plugins/dataproxy-plugin-database/src/test/java/org/secretflow/dataproxy/plugin/database/reader/DatabaseDoGetContextTest.java @@ -0,0 +1,469 @@ +/* + * Copyright 2025 Ant Group Co., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.secretflow.dataproxy.plugin.database.reader; + +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.secretflow.dataproxy.common.exceptions.DataproxyException; +import org.secretflow.dataproxy.plugin.database.config.*; +import org.secretflow.dataproxy.plugin.database.constant.DatabaseTypeEnum; +import org.secretflow.v1alpha1.common.Common; + +import java.sql.*; +import java.util.Arrays; +import java.util.List; +import java.util.function.Function; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +/** + * Unit test class for DatabaseDoGetContext. 
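+ * <p>All JDBC objects are Mockito mocks, so no real database is required; the statement chain is
+ * stubbed in the style of {@code when(mockConnection.createStatement()).thenReturn(mockStatement)}.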
+ * + * @author kongxiaoran + * @date 2025/11/07 + */ +@ExtendWith(MockitoExtension.class) +public class DatabaseDoGetContextTest { + + @Mock + private Connection mockConnection; + + @Mock + private Statement mockStatement; + + @Mock + private ResultSet mockResultSet; + + @Mock + private DatabaseMetaData mockDatabaseMetaData; + + @Mock + private ResultSetMetaData mockResultSetMetaData; + + @Mock + private ResultSet mockColumnsResultSet; + + private DatabaseConnectConfig dbConnectConfig; + private Function initDatabaseFunc; + private DatabaseDoGetContext.BuildQuerySqlFunc, String, String> buildQuerySqlFunc; + private Function jdbcType2ArrowType; + + @BeforeEach + public void setUp() throws SQLException { + dbConnectConfig = new DatabaseConnectConfig("user", "pass", "localhost:3306", "testdb"); + + // Set initDatabaseFunc + initDatabaseFunc = config -> mockConnection; + + // Set buildQuerySqlFunc + buildQuerySqlFunc = (tableName, columns, partition) -> { + StringBuilder sql = new StringBuilder("SELECT "); + sql.append(String.join(", ", columns)); + sql.append(" FROM ").append(tableName); + if (partition != null && !partition.isEmpty()) { + sql.append(" WHERE ").append(partition); + } + return sql.toString(); + }; + + // Set jdbcType2ArrowType + jdbcType2ArrowType = typeName -> { + if (typeName == null || typeName.isEmpty()) { + return ArrowType.Utf8.INSTANCE; + } + String upperType = typeName.toUpperCase(); + if (upperType.contains("INT")) { + return new ArrowType.Int(32, true); + } else if (upperType.contains("VARCHAR") || upperType.contains("CHAR")) { + return ArrowType.Utf8.INSTANCE; + } else if (upperType.contains("DECIMAL") || upperType.contains("NUMERIC")) { + return new ArrowType.Decimal(10, 2, 128); + } else { + return ArrowType.Utf8.INSTANCE; + } + }; + } + + @Test + public void testConstructor_WithTableQueryConfig() throws SQLException { + // Prepare test data + DatabaseTableConfig tableConfig = new DatabaseTableConfig( + "test_table", + "dt=20240101", + Arrays.asList( + Common.DataColumn.newBuilder().setName("id").setType("int32").build(), + Common.DataColumn.newBuilder().setName("name").setType("string").build() + ) + ); + + DatabaseTableQueryConfig queryConfig = new DatabaseTableQueryConfig(dbConnectConfig, tableConfig); + + // Set mock behavior + when(mockConnection.getMetaData()).thenReturn(mockDatabaseMetaData); + when(mockConnection.createStatement()).thenReturn(mockStatement); + when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); + when(mockDatabaseMetaData.getColumns(any(), any(), eq("test_table"), any())) + .thenReturn(mockColumnsResultSet); + when(mockColumnsResultSet.next()).thenReturn(true, true, false); + when(mockColumnsResultSet.getString("COLUMN_NAME")).thenReturn("id", "name"); + when(mockColumnsResultSet.getString("TYPE_NAME")).thenReturn("INT", "VARCHAR"); + when(mockColumnsResultSet.getInt("COLUMN_SIZE")).thenReturn(0, 0); + when(mockColumnsResultSet.getInt("DECIMAL_DIGITS")).thenReturn(-1, -1); + + DatabaseDoGetContext context = new DatabaseDoGetContext( + queryConfig, + initDatabaseFunc, + buildQuerySqlFunc, + jdbcType2ArrowType + ); + + assertNotNull(context); + assertNotNull(context.getSchema()); + assertEquals("test_table", context.getTableName()); + assertNotNull(context.getResultSet()); + assertNotNull(context.getDatabaseMetaData()); + } + + @Test + public void testConstructor_WithScqlCommandJobConfig() throws SQLException { + ScqlCommandJobConfig sqlConfig = new ScqlCommandJobConfig( + dbConnectConfig, + "SELECT * FROM test_table" + 
); + + // Set mock behavior - SQL query scenario + when(mockConnection.getMetaData()).thenReturn(mockDatabaseMetaData); + when(mockConnection.createStatement()).thenReturn(mockStatement); + when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetaData); + when(mockResultSetMetaData.getColumnCount()).thenReturn(2); + when(mockResultSetMetaData.getColumnName(1)).thenReturn("id"); + when(mockResultSetMetaData.getColumnName(2)).thenReturn("name"); + when(mockResultSetMetaData.getColumnTypeName(1)).thenReturn("INT"); + when(mockResultSetMetaData.getColumnTypeName(2)).thenReturn("VARCHAR"); + when(mockResultSetMetaData.getPrecision(1)).thenReturn(0); + when(mockResultSetMetaData.getPrecision(2)).thenReturn(0); + when(mockResultSetMetaData.getScale(1)).thenReturn(-1); + when(mockResultSetMetaData.getScale(2)).thenReturn(-1); + + DatabaseDoGetContext context = new DatabaseDoGetContext( + sqlConfig, + initDatabaseFunc, + buildQuerySqlFunc, + jdbcType2ArrowType + ); + + assertNotNull(context); + assertNotNull(context.getSchema()); + assertNotNull(context.getResultSet()); + } + + @Test + public void testConstructor_WithUnsupportedConfig() { + // Create an unsupported config type + DatabaseCommandConfig unsupportedConfig = new DatabaseCommandConfig( + dbConnectConfig, DatabaseTypeEnum.TABLE, new Object() + ) { + @Override + public String taskRunSQL() { + return ""; + } + + @Override + public Schema getResultSchema() { + return null; + } + }; + + assertThrows(DataproxyException.class, () -> { + new DatabaseDoGetContext( + unsupportedConfig, + initDatabaseFunc, + buildQuerySqlFunc, + jdbcType2ArrowType + ); + }); + } + + @Test + public void testConstructor_WithSQLException() throws SQLException { + DatabaseTableConfig tableConfig = new DatabaseTableConfig( + "test_table", + null, + Arrays.asList( + Common.DataColumn.newBuilder().setName("id").setType("int32").build() + ) + ); + + DatabaseTableQueryConfig queryConfig = new DatabaseTableQueryConfig(dbConnectConfig, tableConfig); + + // Mock SQLException + when(mockConnection.getMetaData()).thenReturn(mockDatabaseMetaData); + when(mockConnection.createStatement()).thenThrow(new SQLException("Database error")); + + assertThrows(DataproxyException.class, () -> { + new DatabaseDoGetContext( + queryConfig, + initDatabaseFunc, + buildQuerySqlFunc, + jdbcType2ArrowType + ); + }); + } + + @Test + public void testGetTaskConfigs() throws SQLException { + DatabaseTableConfig tableConfig = new DatabaseTableConfig( + "test_table", + null, + Arrays.asList( + Common.DataColumn.newBuilder().setName("id").setType("int32").build() + ) + ); + + DatabaseTableQueryConfig queryConfig = new DatabaseTableQueryConfig(dbConnectConfig, tableConfig); + + when(mockConnection.getMetaData()).thenReturn(mockDatabaseMetaData); + when(mockConnection.createStatement()).thenReturn(mockStatement); + when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); + when(mockDatabaseMetaData.getColumns(any(), any(), eq("test_table"), any())) + .thenReturn(mockColumnsResultSet); + when(mockColumnsResultSet.next()).thenReturn(true, false); + when(mockColumnsResultSet.getString("COLUMN_NAME")).thenReturn("id"); + when(mockColumnsResultSet.getString("TYPE_NAME")).thenReturn("INT"); + when(mockColumnsResultSet.getInt("COLUMN_SIZE")).thenReturn(0); + when(mockColumnsResultSet.getInt("DECIMAL_DIGITS")).thenReturn(-1); + + DatabaseDoGetContext context = new DatabaseDoGetContext( + queryConfig, + initDatabaseFunc, 
+ buildQuerySqlFunc, + jdbcType2ArrowType + ); + + List taskConfigs = context.getTaskConfigs(); + assertNotNull(taskConfigs); + assertEquals(1, taskConfigs.size()); + assertNotNull(taskConfigs.get(0)); + } + + @Test + public void testClose() throws SQLException { + DatabaseTableConfig tableConfig = new DatabaseTableConfig( + "test_table", + null, + Arrays.asList( + Common.DataColumn.newBuilder().setName("id").setType("int32").build() + ) + ); + + DatabaseTableQueryConfig queryConfig = new DatabaseTableQueryConfig(dbConnectConfig, tableConfig); + + when(mockConnection.getMetaData()).thenReturn(mockDatabaseMetaData); + when(mockConnection.createStatement()).thenReturn(mockStatement); + when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); + when(mockDatabaseMetaData.getColumns(any(), any(), eq("test_table"), any())) + .thenReturn(mockColumnsResultSet); + when(mockColumnsResultSet.next()).thenReturn(true, false); + when(mockColumnsResultSet.getString("COLUMN_NAME")).thenReturn("id"); + when(mockColumnsResultSet.getString("TYPE_NAME")).thenReturn("INT"); + when(mockColumnsResultSet.getInt("COLUMN_SIZE")).thenReturn(0); + when(mockColumnsResultSet.getInt("DECIMAL_DIGITS")).thenReturn(-1); + + doNothing().when(mockResultSet).close(); + doNothing().when(mockStatement).close(); + doNothing().when(mockConnection).close(); + + DatabaseDoGetContext context = new DatabaseDoGetContext( + queryConfig, + initDatabaseFunc, + buildQuerySqlFunc, + jdbcType2ArrowType + ); + + assertDoesNotThrow(() -> { + context.close(); + }); + + verify(mockResultSet, times(1)).close(); + verify(mockStatement, times(1)).close(); + verify(mockConnection, times(1)).close(); + } + + @Test + public void testClose_WithSQLException() throws SQLException { + DatabaseTableConfig tableConfig = new DatabaseTableConfig( + "test_table", + null, + Arrays.asList( + Common.DataColumn.newBuilder().setName("id").setType("int32").build() + ) + ); + + DatabaseTableQueryConfig queryConfig = new DatabaseTableQueryConfig(dbConnectConfig, tableConfig); + + when(mockConnection.getMetaData()).thenReturn(mockDatabaseMetaData); + when(mockConnection.createStatement()).thenReturn(mockStatement); + when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); + when(mockDatabaseMetaData.getColumns(any(), any(), eq("test_table"), any())) + .thenReturn(mockColumnsResultSet); + when(mockColumnsResultSet.next()).thenReturn(true, false); + when(mockColumnsResultSet.getString("COLUMN_NAME")).thenReturn("id"); + when(mockColumnsResultSet.getString("TYPE_NAME")).thenReturn("INT"); + when(mockColumnsResultSet.getInt("COLUMN_SIZE")).thenReturn(0); + when(mockColumnsResultSet.getInt("DECIMAL_DIGITS")).thenReturn(-1); + + doThrow(new SQLException("Close error")).when(mockResultSet).close(); + + DatabaseDoGetContext context = new DatabaseDoGetContext( + queryConfig, + initDatabaseFunc, + buildQuerySqlFunc, + jdbcType2ArrowType + ); + + assertThrows(RuntimeException.class, () -> { + context.close(); + }); + } + + @Test + public void testClose_WithConnectionFailure() throws SQLException { + DatabaseTableConfig tableConfig = new DatabaseTableConfig( + "test_table", + null, + Arrays.asList( + Common.DataColumn.newBuilder().setName("id").setType("int32").build() + ) + ); + + DatabaseTableQueryConfig queryConfig = new DatabaseTableQueryConfig(dbConnectConfig, tableConfig); + + when(mockConnection.getMetaData()).thenReturn(mockDatabaseMetaData); + when(mockConnection.createStatement()).thenReturn(mockStatement); + 
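// Stub the query and metadata chain so the context constructor succeeds; only connection.close() is made to fail below +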
when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); + when(mockDatabaseMetaData.getColumns(any(), any(), eq("test_table"), any())) + .thenReturn(mockColumnsResultSet); + when(mockColumnsResultSet.next()).thenReturn(true, false); + when(mockColumnsResultSet.getString("COLUMN_NAME")).thenReturn("id"); + when(mockColumnsResultSet.getString("TYPE_NAME")).thenReturn("INT"); + when(mockColumnsResultSet.getInt("COLUMN_SIZE")).thenReturn(0); + when(mockColumnsResultSet.getInt("DECIMAL_DIGITS")).thenReturn(-1); + + doNothing().when(mockResultSet).close(); + doNothing().when(mockStatement).close(); + doThrow(new SQLException("Connection close error")).when(mockConnection).close(); + + DatabaseDoGetContext context = new DatabaseDoGetContext( + queryConfig, + initDatabaseFunc, + buildQuerySqlFunc, + jdbcType2ArrowType + ); + + assertThrows(RuntimeException.class, () -> { + context.close(); + }); + } + + @Test + public void testInitArrowSchemaFromResultSet_WithDecimalType() throws SQLException { + ScqlCommandJobConfig sqlConfig = new ScqlCommandJobConfig( + dbConnectConfig, + "SELECT * FROM test_table" + ); + + when(mockConnection.getMetaData()).thenReturn(mockDatabaseMetaData); + when(mockConnection.createStatement()).thenReturn(mockStatement); + when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetaData); + when(mockResultSetMetaData.getColumnCount()).thenReturn(1); + when(mockResultSetMetaData.getColumnName(1)).thenReturn("price"); + when(mockResultSetMetaData.getColumnTypeName(1)).thenReturn("DECIMAL"); + when(mockResultSetMetaData.getPrecision(1)).thenReturn(10); + when(mockResultSetMetaData.getScale(1)).thenReturn(2); + + // Update jdbcType2ArrowType to handle DECIMAL + Function decimalTypeFunc = typeName -> { + if (typeName != null && typeName.toUpperCase().contains("DECIMAL")) { + return new ArrowType.Decimal(10, 2, 128); + } + return ArrowType.Utf8.INSTANCE; + }; + + DatabaseDoGetContext context = new DatabaseDoGetContext( + sqlConfig, + initDatabaseFunc, + buildQuerySqlFunc, + decimalTypeFunc + ); + + assertNotNull(context.getSchema()); + assertEquals(1, context.getSchema().getFields().size()); + assertTrue(context.getSchema().getFields().get(0).getType() instanceof ArrowType.Decimal); + } + + @Test + public void testInitArrowSchemaFromColumns_WithTimeType() throws SQLException { + DatabaseTableConfig tableConfig = new DatabaseTableConfig( + "test_table", + null, + Arrays.asList( + Common.DataColumn.newBuilder().setName("created_at").setType("timestamp").build() + ) + ); + + DatabaseTableQueryConfig queryConfig = new DatabaseTableQueryConfig(dbConnectConfig, tableConfig); + + when(mockConnection.getMetaData()).thenReturn(mockDatabaseMetaData); + when(mockConnection.createStatement()).thenReturn(mockStatement); + when(mockStatement.executeQuery(anyString())).thenReturn(mockResultSet); + when(mockDatabaseMetaData.getColumns(any(), any(), eq("test_table"), any())) + .thenReturn(mockColumnsResultSet); + when(mockColumnsResultSet.next()).thenReturn(true, false); + when(mockColumnsResultSet.getString("COLUMN_NAME")).thenReturn("created_at"); + when(mockColumnsResultSet.getString("TYPE_NAME")).thenReturn("TIMESTAMP"); + when(mockColumnsResultSet.getInt("COLUMN_SIZE")).thenReturn(0); + when(mockColumnsResultSet.getInt("DECIMAL_DIGITS")).thenReturn(3); // Millisecond precision + + // Update jdbcType2ArrowType to handle TIMESTAMP + Function timestampTypeFunc = typeName -> { + if (typeName != null && 
typeName.toUpperCase().contains("TIMESTAMP")) { + return new ArrowType.Timestamp(org.apache.arrow.vector.types.TimeUnit.MILLISECOND, null); + } + return ArrowType.Utf8.INSTANCE; + }; + + DatabaseDoGetContext context = new DatabaseDoGetContext( + queryConfig, + initDatabaseFunc, + buildQuerySqlFunc, + timestampTypeFunc + ); + + assertNotNull(context.getSchema()); + assertEquals(1, context.getSchema().getFields().size()); + } +} + diff --git a/dataproxy-plugins/dataproxy-plugin-database/src/test/java/org/secretflow/dataproxy/plugin/database/reader/DatabaseReaderTest.java b/dataproxy-plugins/dataproxy-plugin-database/src/test/java/org/secretflow/dataproxy/plugin/database/reader/DatabaseReaderTest.java new file mode 100644 index 0000000..27be94f --- /dev/null +++ b/dataproxy-plugins/dataproxy-plugin-database/src/test/java/org/secretflow/dataproxy/plugin/database/reader/DatabaseReaderTest.java @@ -0,0 +1,271 @@ +/* + * Copyright 2025 Ant Group Co., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.secretflow.dataproxy.plugin.database.reader; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.secretflow.dataproxy.plugin.database.config.TaskConfig; + +import java.sql.DatabaseMetaData; +import java.sql.ResultSet; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +/** + * Unit test class for DatabaseReader. 
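+ * <p>Readers are registered for cleanup and the RootAllocator is closed in tearDown(), so any
+ * leaked Arrow buffers fail the run.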
+ * + * @author kongxiaoran + * @date 2025/11/07 + */ +@ExtendWith(MockitoExtension.class) +public class DatabaseReaderTest { + + @Mock + private DatabaseDoGetContext mockContext; + + @Mock + private DatabaseDoGetTaskContext mockTaskContext; + + @Mock + private ResultSet mockResultSet; + + @Mock + private DatabaseMetaData mockDatabaseMetaData; + + private BufferAllocator allocator; + private TaskConfig taskConfig; + private Schema testSchema; + private List readersToCleanup = new ArrayList<>(); + + @BeforeEach + public void setUp() { + allocator = new RootAllocator(); + testSchema = new Schema(java.util.Arrays.asList( + new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null), + new Field("name", FieldType.nullable(ArrowType.Utf8.INSTANCE), null) + )); + + // Use lenient() to avoid UnnecessaryStubbingException + lenient().when(mockContext.getSchema()).thenReturn(testSchema); + lenient().when(mockContext.getResultSet()).thenReturn(mockResultSet); + lenient().when(mockContext.getDatabaseMetaData()).thenReturn(mockDatabaseMetaData); + lenient().when(mockContext.getTableName()).thenReturn("test_table"); + taskConfig = new TaskConfig(mockContext, 0); + readersToCleanup.clear(); + } + + @AfterEach + public void tearDown() { + // Clean up all created DatabaseReader instances, ensure background threads are properly closed + for (DatabaseReader reader : readersToCleanup) { + try { + reader.closeReadSource(); + // Close the reader completely to release VectorSchemaRoot + reader.close(); + } catch (Exception e) { + // Ignore exceptions during cleanup + } + } + readersToCleanup.clear(); + + // Close allocator - this will detect any memory leaks + if (allocator != null) { + allocator.close(); + } + } + + @Test + public void testConstructor() { + DatabaseReader reader = new DatabaseReader(allocator, taskConfig); + readersToCleanup.add(reader); + assertNotNull(reader); + } + + @Test + public void testReadSchema() { + DatabaseReader reader = new DatabaseReader(allocator, taskConfig); + readersToCleanup.add(reader); + Schema schema = reader.readSchema(); + assertNotNull(schema); + assertEquals(testSchema, schema); + } + + @Test + public void testBytesRead() { + DatabaseReader reader = new DatabaseReader(allocator, taskConfig); + readersToCleanup.add(reader); + assertEquals(0, reader.bytesRead()); + } + + @Test + public void testLoadNextBatch_WithError() { + DatabaseReader reader = new DatabaseReader(allocator, taskConfig); + readersToCleanup.add(reader); + RuntimeException error = new RuntimeException("Test error"); + taskConfig.setError(error); + + assertThrows(RuntimeException.class, () -> { + reader.loadNextBatch(); + }); + } + + @Test + @Timeout(value = 5, unit = TimeUnit.SECONDS) + public void testLoadNextBatch_NoError() throws Exception { + // Use mock DatabaseDoGetTaskContext to avoid starting real background thread + DatabaseReader reader = new DatabaseReader(allocator, taskConfig); + readersToCleanup.add(reader); + taskConfig.setError(null); + + // Use reflection to set mock dbDoGetTaskContext + java.lang.reflect.Field field = DatabaseReader.class.getDeclaredField("dbDoGetTaskContext"); + field.setAccessible(true); + field.set(reader, mockTaskContext); + + // Mock hasNext returns false, indicating no more data + when(mockTaskContext.hasNext()).thenReturn(false); + + boolean result = reader.loadNextBatch(); + assertFalse(result); + } + + @Test + public void testCloseReadSource_WithTaskContext() throws Exception { + DatabaseReader reader = new DatabaseReader(allocator, taskConfig); + 
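// Register with the cleanup list so tearDown() releases the reader even if an assertion fails +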
readersToCleanup.add(reader); + // Use reflection to set dbDoGetTaskContext + java.lang.reflect.Field field = DatabaseReader.class.getDeclaredField("dbDoGetTaskContext"); + field.setAccessible(true); + field.set(reader, mockTaskContext); + + doNothing().when(mockTaskContext).close(); + + reader.closeReadSource(); + + verify(mockTaskContext, times(1)).close(); + } + + @Test + public void testCloseReadSource_WithoutTaskContext() { + DatabaseReader reader = new DatabaseReader(allocator, taskConfig); + assertDoesNotThrow(() -> { + reader.closeReadSource(); + }); + } + + @Test + public void testCloseReadSource_WithInterruptedException() throws Exception { + DatabaseReader reader = new DatabaseReader(allocator, taskConfig); + java.lang.reflect.Field field = DatabaseReader.class.getDeclaredField("dbDoGetTaskContext"); + field.setAccessible(true); + field.set(reader, mockTaskContext); + + doThrow(new InterruptedException("Test interrupt")).when(mockTaskContext).close(); + + reader.closeReadSource(); + + // Verify thread interrupt status is set + assertTrue(Thread.currentThread().isInterrupted()); + // Clear interrupt status to avoid affecting subsequent tests + Thread.interrupted(); + } + + @Test + public void testCloseReadSource_WithException() throws Exception { + DatabaseReader reader = new DatabaseReader(allocator, taskConfig); + java.lang.reflect.Field field = DatabaseReader.class.getDeclaredField("dbDoGetTaskContext"); + field.setAccessible(true); + field.set(reader, mockTaskContext); + + Exception testException = new RuntimeException("Test exception"); + doThrow(testException).when(mockTaskContext).close(); + + assertThrows(RuntimeException.class, () -> { + reader.closeReadSource(); + }); + } + + @Test + @Timeout(value = 5, unit = TimeUnit.SECONDS) + public void testPrepare() throws Exception { + DatabaseReader reader = new DatabaseReader(allocator, taskConfig); + readersToCleanup.add(reader); + + // Verify initial state is null + java.lang.reflect.Field field = DatabaseReader.class.getDeclaredField("dbDoGetTaskContext"); + field.setAccessible(true); + assertNull(field.get(reader), "dbDoGetTaskContext should be null initially"); + + // Mock ResultSet behavior for DatabaseRecordReader + when(mockResultSet.next()).thenReturn(false); // No rows, will immediately finish + + // Call prepare() - it should create and start DatabaseDoGetTaskContext + // Even though it starts a background thread, we can verify the object was created + reader.prepare(); + + // Verify DatabaseDoGetTaskContext was created + Object dbDoGetTaskContext = field.get(reader); + assertNotNull(dbDoGetTaskContext, "dbDoGetTaskContext should be created after prepare()"); + assertTrue(dbDoGetTaskContext instanceof DatabaseDoGetTaskContext, + "dbDoGetTaskContext should be an instance of DatabaseDoGetTaskContext"); + + // Wait a bit for the background thread to finish (since ResultSet.next() returns false) + Thread.sleep(100); + + // Verify the background thread has started (hasNext should eventually become false) + DatabaseDoGetTaskContext context = (DatabaseDoGetTaskContext) dbDoGetTaskContext; + // The hasNext flag should eventually become false after the reader finishes + // We wait a bit more to ensure the thread has processed + int attempts = 0; + while (context.hasNext() && attempts < 50) { + Thread.sleep(50); + attempts++; + } + // Context is managed by reader, no need to close separately + + // Clean up: close the reader to stop background thread and release resources + reader.closeReadSource(); + + // Wait for executor 
service to shutdown completely + // DatabaseDoGetTaskContext.close() calls executorService.shutdown() but doesn't wait + // We need to wait a bit to ensure the background thread has finished and released resources + Thread.sleep(500); + + // Close the reader completely to release VectorSchemaRoot + // This is necessary because VectorSchemaRoot is managed by ArrowReader + reader.close(); + + // Remove from cleanup list since we've already closed it + readersToCleanup.remove(reader); + } +} + diff --git a/dataproxy-plugins/dataproxy-plugin-database/src/test/java/org/secretflow/dataproxy/plugin/database/reader/DatabaseRecordSenderTest.java b/dataproxy-plugins/dataproxy-plugin-database/src/test/java/org/secretflow/dataproxy/plugin/database/reader/DatabaseRecordSenderTest.java new file mode 100644 index 0000000..edd38c1 --- /dev/null +++ b/dataproxy-plugins/dataproxy-plugin-database/src/test/java/org/secretflow/dataproxy/plugin/database/reader/DatabaseRecordSenderTest.java @@ -0,0 +1,387 @@ +/* + * Copyright 2025 Ant Group Co., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.secretflow.dataproxy.plugin.database.reader; + +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.*; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.secretflow.dataproxy.plugin.database.utils.Record; + +import java.sql.DatabaseMetaData; +import java.sql.ResultSet; +import java.util.concurrent.LinkedBlockingQueue; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit test class for DatabaseRecordSender. 
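+ * <p>Uses a four-column schema (id: Int32, name: Utf8, price: Decimal(10,2), is_active: Bool) to
+ * exercise the vector conversion paths.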
+ * + * @author kongxiaoran + * @date 2025/11/07 + */ +@ExtendWith(MockitoExtension.class) +public class DatabaseRecordSenderTest { + + @Mock + private DatabaseMetaData mockMetaData; + + @Mock + private ResultSet mockResultSet; + + private RootAllocator allocator; + private VectorSchemaRoot root; + private Schema schema; + private LinkedBlockingQueue recordQueue; + + @BeforeEach + public void setUp() { + allocator = new RootAllocator(); + schema = new Schema(java.util.Arrays.asList( + new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null), + new Field("name", FieldType.nullable(ArrowType.Utf8.INSTANCE), null), + new Field("price", FieldType.nullable(new ArrowType.Decimal(10, 2, 128)), null), + new Field("is_active", FieldType.nullable(ArrowType.Bool.INSTANCE), null) + )); + + root = VectorSchemaRoot.create(schema, allocator); + recordQueue = new LinkedBlockingQueue<>(1000); + } + + @Test + public void testConstructor() { + DatabaseRecordSender sender = new DatabaseRecordSender( + 100, + recordQueue, + root, + "test_table", + mockMetaData, + mockResultSet + ); + + assertNotNull(sender); + } + + @Test + public void testToArrowVector_WithIntAndString() throws InterruptedException { + DatabaseRecordSender sender = new DatabaseRecordSender( + 100, + recordQueue, + root, + "test_table", + mockMetaData, + mockResultSet + ); + + Record record = new Record(); + record.set("id", 1); + record.set("name", "test"); + + sender.toArrowVector(record, root, 0); + + // Verify data has been set + IntVector idVector = (IntVector) root.getVector("id"); + VarCharVector nameVector = (VarCharVector) root.getVector("name"); + + assertEquals(1, idVector.get(0)); + // VarCharVector.get() returns byte[], need to convert to String + assertEquals("test", new String(nameVector.get(0))); + } + + @Test + public void testToArrowVector_WithNullValue() throws InterruptedException { + DatabaseRecordSender sender = new DatabaseRecordSender( + 100, + recordQueue, + root, + "test_table", + mockMetaData, + mockResultSet + ); + + Record record = new Record(); + record.set("id", null); + record.set("name", "test"); + + sender.toArrowVector(record, root, 0); + + // Verify null value has been correctly set + IntVector idVector = (IntVector) root.getVector("id"); + VarCharVector nameVector = (VarCharVector) root.getVector("name"); + + assertTrue(idVector.isNull(0)); + // VarCharVector.get() returns byte[], need to convert to String + assertEquals("test", new String(nameVector.get(0))); + } + + @Test + public void testToArrowVector_WithMissingColumn() throws InterruptedException { + DatabaseRecordSender sender = new DatabaseRecordSender( + 100, + recordQueue, + root, + "test_table", + mockMetaData, + mockResultSet + ); + + Record record = new Record(); + record.set("id", 1); + // name column missing, should be ignored + + assertDoesNotThrow(() -> { + sender.toArrowVector(record, root, 0); + }); + + IntVector idVector = (IntVector) root.getVector("id"); + assertEquals(1, idVector.get(0)); + } + + @Test + public void testToArrowVector_WithBoolean() throws InterruptedException { + DatabaseRecordSender sender = new DatabaseRecordSender( + 100, + recordQueue, + root, + "test_table", + mockMetaData, + mockResultSet + ); + + Record record = new Record(); + record.set("is_active", true); + + sender.toArrowVector(record, root, 0); + + BitVector boolVector = (BitVector) root.getVector("is_active"); + assertEquals(1, boolVector.get(0)); // true = 1 + } + + @Test + public void testToArrowVector_WithDecimal() throws 
+
+    @Test
+    public void testIsOver_WithLastLine() {
+        DatabaseRecordSender sender = new DatabaseRecordSender(
+                100,
+                recordQueue,
+                root,
+                "test_table",
+                mockMetaData,
+                mockResultSet
+        );
+
+        Record record = new Record();
+        record.setLast(true);
+
+        assertTrue(sender.isOver(record));
+    }
+
+    @Test
+    public void testIsOver_WithoutLastLine() {
+        DatabaseRecordSender sender = new DatabaseRecordSender(
+                100,
+                recordQueue,
+                root,
+                "test_table",
+                mockMetaData,
+                mockResultSet
+        );
+
+        Record record = new Record();
+        record.setLast(false);
+
+        assertFalse(sender.isOver(record));
+    }
+
+    @Test
+    public void testPutOver() throws InterruptedException {
+        DatabaseRecordSender sender = new DatabaseRecordSender(
+                100,
+                recordQueue,
+                root,
+                "test_table",
+                mockMetaData,
+                mockResultSet
+        );
+
+        sender.putOver();
+
+        // Poll the sentinel record from the queue and verify it
+        Record lastRecord = recordQueue.poll();
+        assertNotNull(lastRecord, "Queue should contain a record after putOver()");
+        assertTrue(lastRecord.isLastLine(), "Last record should have its last flag set to true");
+        assertNotNull(lastRecord.getData());
+        assertTrue(lastRecord.getData().isEmpty(), "Last record should carry no data");
+    }
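putOver() ends the stream by enqueuing an empty Record whose last flag is set, so a consumer typically drains the queue until that sentinel arrives. A minimal sketch against the same queue, sender, and root used in these tests (the row-index handling is illustrative):

    // Illustrative consumer loop: drain records until the sentinel from putOver() arrives.
    int rowIndex = 0;
    Record next;
    while (!(next = recordQueue.take()).isLastLine()) {  // take() blocks; may throw InterruptedException
        sender.toArrowVector(next, root, rowIndex++);    // copy the record into the current Arrow batch
    }
    root.setRowCount(rowIndex);                          // finalize the batch row count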
+
+    @Test
+    public void testEqualsIgnoreCase() {
+        DatabaseRecordSender sender = new DatabaseRecordSender(
+                100,
+                recordQueue,
+                root,
+                "test_table",
+                mockMetaData,
+                mockResultSet
+        );
+
+        assertTrue(sender.equalsIgnoreCase("TEST", "test"));
+        assertTrue(sender.equalsIgnoreCase("test", "TEST"));
+        assertTrue(sender.equalsIgnoreCase("Test", "test"));
+        assertTrue(sender.equalsIgnoreCase(null, null));
+        assertFalse(sender.equalsIgnoreCase("test", "other"));
+        assertFalse(sender.equalsIgnoreCase(null, "test"));
+        assertFalse(sender.equalsIgnoreCase("test", null));
+    }
+
+    @Test
+    public void testInitRecordColumn2FieldMap() throws InterruptedException {
+        DatabaseRecordSender sender = new DatabaseRecordSender(
+                100,
+                recordQueue,
+                root,
+                "test_table",
+                mockMetaData,
+                mockResultSet
+        );
+
+        Record record = new Record();
+        record.set("id", 1);
+        record.set("name", "test");
+
+        // The first call initializes the column-to-field mapping
+        sender.toArrowVector(record, root, 0);
+
+        // The second call should reuse the already initialized mapping
+        Record record2 = new Record();
+        record2.set("id", 2);
+        record2.set("name", "test2");
+
+        assertDoesNotThrow(() -> {
+            sender.toArrowVector(record2, root, 1);
+        });
+    }
+
+    @Test
+    public void testInitRecordColumn2FieldMap_WithNullRoot() {
+        // The AbstractSender constructor calls preAllocate(), which throws an NPE when root is null
+        // This is the expected behavior: a null root must not be allowed
+        assertThrows(NullPointerException.class, () -> {
+            new DatabaseRecordSender(
+                    100,
+                    recordQueue,
+                    null,
+                    "test_table",
+                    mockMetaData,
+                    mockResultSet
+            );
+        });
+    }
+
+    @Test
+    public void testToArrowVector_WithUnsupportedType() throws InterruptedException {
+        // Create a schema containing an unsupported type
+        Schema unsupportedSchema = new Schema(java.util.Arrays.asList(
+                new Field("unknown", FieldType.nullable(ArrowType.Null.INSTANCE), null)
+        ));
+
+        VectorSchemaRoot unsupportedRoot = VectorSchemaRoot.create(unsupportedSchema, allocator);
+        LinkedBlockingQueue<Record> queue = new LinkedBlockingQueue<>(100);
+
+        DatabaseRecordSender sender = new DatabaseRecordSender(
+                100,
+                queue,
+                unsupportedRoot,
+                "test_table",
+                mockMetaData,
+                mockResultSet
+        );
+
+        Record record = new Record();
+        record.set("unknown", "value");
+
+        // Should not throw; a warning is logged instead
+        assertDoesNotThrow(() -> {
+            sender.toArrowVector(record, unsupportedRoot, 0);
+        });
+    }
+
+    @Test
+    public void testToArrowVector_MultipleRows() throws InterruptedException {
+        DatabaseRecordSender sender = new DatabaseRecordSender(
+                100,
+                recordQueue,
+                root,
+                "test_table",
+                mockMetaData,
+                mockResultSet
+        );
+
+        // Write several rows of data
+        for (int i = 0; i < 3; i++) {
+            Record record = new Record();
+            record.set("id", i);
+            record.set("name", "test" + i);
+
+            sender.toArrowVector(record, root, i);
+        }
+
+        // rowCount must be set before the data can be read back
+        root.setRowCount(3);
+
+        // Verify all rows were written
+        IntVector idVector = (IntVector) root.getVector("id");
+        VarCharVector nameVector = (VarCharVector) root.getVector("name");
+
+        assertEquals(3, root.getRowCount());
+        assertEquals(0, idVector.get(0));
+        assertEquals(1, idVector.get(1));
+        assertEquals(2, idVector.get(2));
+        assertEquals("test2", new String(nameVector.get(2), StandardCharsets.UTF_8));
+    }
+}
+
diff --git a/dataproxy-plugins/pom.xml b/dataproxy-plugins/pom.xml
index 3a420e1..cfad689 100644
--- a/dataproxy-plugins/pom.xml
+++ b/dataproxy-plugins/pom.xml
@@ -14,6 +14,7 @@
         <module>dataproxy-plugin-odps</module>
         <module>dataproxy-plugin-database</module>
         <module>dataproxy-plugin-hive</module>
+        <module>dataproxy-plugin-dameng</module>
     </modules>
     <artifactId>dataproxy-plugins</artifactId>
diff --git a/dataproxy-server/pom.xml b/dataproxy-server/pom.xml
index c233c51..7d8c8b5 100644
--- a/dataproxy-server/pom.xml
+++ b/dataproxy-server/pom.xml
@@ -55,6 +55,11 @@
             <artifactId>dataproxy-plugin-hive</artifactId>
         </dependency>
+        <dependency>
+            <groupId>org.secretflow</groupId>
+            <artifactId>dataproxy-plugin-dameng</artifactId>
+        </dependency>
+
         <dependency>
             <groupId>org.slf4j</groupId>
             <artifactId>slf4j-api</artifactId>
diff --git a/pom.xml b/pom.xml
index 1dfdfc2..7d6cead 100644
--- a/pom.xml
+++ b/pom.xml
@@ -81,11 +81,20 @@
         2.1.7
         5.11.4
         1.14.5
-
+        <testcontainers.version>1.19.7</testcontainers.version>
     </properties>

    <dependencyManagement>
        <dependencies>
+
+            <dependency>
+                <groupId>org.testcontainers</groupId>
+                <artifactId>testcontainers-bom</artifactId>
+                <version>${testcontainers.version}</version>
+                <type>pom</type>
+                <scope>import</scope>
+            </dependency>
+
            <dependency>
                <groupId>org.apache.arrow</groupId>
@@ -422,6 +431,11 @@
                <artifactId>dataproxy-plugin-hive</artifactId>
                <version>${project.version}</version>
            </dependency>
+            <dependency>
+                <groupId>org.secretflow</groupId>
+                <artifactId>dataproxy-plugin-dameng</artifactId>
+                <version>${project.version}</version>
+            </dependency>
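Importing the testcontainers-bom in dependencyManagement lets individual modules (such as the new dameng plugin) declare Testcontainers artifacts without repeating the version. For illustration only — no such test appears in this diff, and the image name and port are placeholders — a typical JUnit 5 wiring looks like:

    import org.junit.jupiter.api.Assertions;
    import org.junit.jupiter.api.Test;
    import org.testcontainers.containers.GenericContainer;
    import org.testcontainers.junit.jupiter.Container;
    import org.testcontainers.junit.jupiter.Testcontainers;

    // Placeholder image: a real test would point at an actual Dameng image and port.
    @Testcontainers
    class DamengContainerSmokeTest {

        @Container
        private static final GenericContainer<?> DB =
                new GenericContainer<>("example/dameng:latest")
                        .withExposedPorts(5236);

        @Test
        void containerStarts() {
            // @Testcontainers starts the container before the test and stops it afterwards
            Assertions.assertTrue(DB.isRunning());
        }
    }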