diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java index f6f17770f..f81233ec3 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java @@ -20,6 +20,7 @@ import io.netty.util.concurrent.DefaultThreadFactory; import java.sql.SQLException; +import java.util.ArrayList; import java.util.Properties; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -180,19 +181,39 @@ public Properties getClientInfo() { @Override public void close() throws SQLException { - clientHandler.close(); - if (executorService != null) { - executorService.shutdown(); + Exception topLevelException = null; + try { + if (executorService != null) { + executorService.shutdown(); + } + } catch (final Exception e) { + topLevelException = e; + } + ArrayList closeables = new ArrayList<>(statementMap.values()); + closeables.add(clientHandler); + closeables.addAll(allocator.getChildAllocators()); + closeables.add(allocator); + try { + AutoCloseables.close(closeables); + } catch (final Exception e) { + if (topLevelException == null) { + topLevelException = e; + } else { + topLevelException.addSuppressed(e); + } } - try { - AutoCloseables.close(clientHandler); - allocator.getChildAllocators().forEach(AutoCloseables::closeNoChecked); - AutoCloseables.close(allocator); - super.close(); } catch (final Exception e) { - throw AvaticaConnection.HELPER.createException(e.getMessage(), e); + if (topLevelException == null) { + topLevelException = e; + } else { + topLevelException.addSuppressed(e); + } + } + if (topLevelException != null) { + throw AvaticaConnection.HELPER.createException( + topLevelException.getMessage(), topLevelException); } } diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTest.java index 46762f331..dbedbe9d3 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTest.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTest.java @@ -17,6 +17,7 @@ package org.apache.arrow.driver.jdbc; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -27,6 +28,7 @@ import java.sql.Driver; import java.sql.DriverManager; import java.sql.SQLException; +import java.sql.Statement; import java.util.Map; import java.util.Properties; import org.apache.arrow.driver.jdbc.authentication.UserPasswordAuthentication; @@ -660,4 +662,40 @@ public String visit(String value) { assertEquals(catalog, actualCatalog); } } + + @Test + public void testStatementsClosedOnConnectionClose() throws Exception { + // create a connection + final Properties properties = new Properties(); + properties.put(ArrowFlightConnectionProperty.HOST.camelName(), "localhost"); + properties.put( + ArrowFlightConnectionProperty.PORT.camelName(), FLIGHT_SERVER_TEST_EXTENSION.getPort()); + properties.put(ArrowFlightConnectionProperty.USER.camelName(), userTest); + properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(), passTest); + properties.put("useEncryption", false); + + Connection connection = + DriverManager.getConnection( + "jdbc:arrow-flight-sql://" + + FLIGHT_SERVER_TEST_EXTENSION.getHost() + + ":" + + FLIGHT_SERVER_TEST_EXTENSION.getPort(), + properties); + + // create some statements + int numStatements = 3; + Statement[] statements = new Statement[numStatements]; + for (int i = 0; i < numStatements; i++) { + statements[i] = connection.createStatement(); + assertFalse(statements[i].isClosed()); + } + + // close the connection + connection.close(); + + // assert the statements are closed + for (int i = 0; i < numStatements; i++) { + assertTrue(statements[i].isClosed()); + } + } }