diff --git a/duckdb_java.def b/duckdb_java.def index 170448f23..68ff3031b 100644 --- a/duckdb_java.def +++ b/duckdb_java.def @@ -29,6 +29,7 @@ Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1destroy_1db_1ref Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1create_1extension_1type Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1disconnect Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1execute +Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1execute_1pending Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1fetch Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1cast_1result_1to_1strings Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1fetch_1size @@ -38,11 +39,13 @@ Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1get_1catalog Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1get_1profiling_1information Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1get_1schema Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1interrupt +Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1pending_1query Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1prepare Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1prepared_1statement_1meta Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1query_1result_1meta Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1query_1progress Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1release +Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1release_1pending Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1set_1auto_1commit Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1set_1catalog Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1set_1schema diff --git a/duckdb_java.exp b/duckdb_java.exp index 71158e9ec..6b6cb687d 100644 --- a/duckdb_java.exp +++ b/duckdb_java.exp @@ -26,6 +26,7 @@ _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1destroy_1db_1ref _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1create_1extension_1type _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1disconnect _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1execute +_Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1execute_1pending _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1fetch _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1cast_1result_1to_1strings _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1fetch_1size @@ -35,11 +36,13 @@ _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1get_1catalog _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1get_1profiling_1information _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1get_1schema _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1interrupt +_Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1pending_1query _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1prepare _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1prepared_1statement_1meta _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1query_1result_1meta _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1query_1progress _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1release +_Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1release_1pending _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1set_1auto_1commit _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1set_1catalog _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1set_1schema diff --git a/duckdb_java.map b/duckdb_java.map index fd4da82e9..7ed2d7233 100644 --- a/duckdb_java.map +++ b/duckdb_java.map @@ -28,6 +28,7 @@ DUCKDB_JAVA { Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1create_1extension_1type; Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1disconnect; Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1execute; + Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1execute_1pending; Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1fetch; Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1cast_1result_1to_1strings; Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1fetch_1size; @@ -37,11 +38,13 @@ DUCKDB_JAVA { Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1get_1profiling_1information; Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1get_1schema; Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1interrupt; + Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1pending_1query; Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1prepare; Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1prepared_1statement_1meta; Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1query_1result_1meta; Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1query_1progress; Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1release; + Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1release_1pending; Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1set_1auto_1commit; Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1set_1catalog; Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1set_1schema; diff --git a/src/jni/duckdb_java.cpp b/src/jni/duckdb_java.cpp index ae1811a95..0a5cbcce4 100644 --- a/src/jni/duckdb_java.cpp +++ b/src/jni/duckdb_java.cpp @@ -212,24 +212,54 @@ jobject _duckdb_jdbc_prepare(JNIEnv *env, jclass, jobject conn_ref_buf, jbyteArr } } - auto stmt_ref = new StatementHolder(); + auto stmt_ref = make_uniq(); stmt_ref->stmt = conn_ref->Prepare(std::move(statements.back())); if (stmt_ref->stmt->HasError()) { string error_msg = string(stmt_ref->stmt->GetError()); stmt_ref->stmt = nullptr; - - // No success, so it must be deleted - delete stmt_ref; ThrowJNI(env, error_msg.c_str()); + return nullptr; + } + return env->NewDirectByteBuffer(stmt_ref.release(), 0); +} - // Just return control flow back to JVM, as an Exception is pending anyway +jobject _duckdb_jdbc_pending_query(JNIEnv *env, jclass, jobject conn_ref_buf, jbyteArray query_j) { + auto conn_ref = get_connection(env, conn_ref_buf); + if (!conn_ref) { return nullptr; } - return env->NewDirectByteBuffer(stmt_ref, 0); + + auto query = jbyteArray_to_string(env, query_j); + + auto statements = conn_ref->ExtractStatements(query.c_str()); + if (statements.empty()) { + throw InvalidInputException("No statements to execute."); + } + + // if there are multiple statements, we directly execute the statements besides the last one + // we only return the result of the last statement to the user, unless one of the previous statements fails + for (idx_t i = 0; i + 1 < statements.size(); i++) { + auto res = conn_ref->Query(std::move(statements[i])); + if (res->HasError()) { + res->ThrowError(); + } + } + + Value result; + bool stream_results = + conn_ref->context->TryGetCurrentSetting("jdbc_stream_results", result) ? result.GetValue() : false; + QueryParameters query_parameters; + query_parameters.output_type = + stream_results ? QueryResultOutputType::ALLOW_STREAMING : QueryResultOutputType::FORCE_MATERIALIZED; + + auto pending_ref = make_uniq(); + pending_ref->pending = conn_ref->PendingQuery(std::move(statements.back()), query_parameters); + + return env->NewDirectByteBuffer(pending_ref.release(), 0); } jobject _duckdb_jdbc_execute(JNIEnv *env, jclass, jobject stmt_ref_buf, jobjectArray params) { - auto stmt_ref = (StatementHolder *)env->GetDirectBufferAddress(stmt_ref_buf); + auto stmt_ref = reinterpret_cast(env->GetDirectBufferAddress(stmt_ref_buf)); if (!stmt_ref) { throw InvalidInputException("Invalid statement"); } @@ -269,21 +299,50 @@ jobject _duckdb_jdbc_execute(JNIEnv *env, jclass, jobject stmt_ref_buf, jobjectA return env->NewDirectByteBuffer(res_ref.release(), 0); } +jobject _duckdb_jdbc_execute_pending(JNIEnv *env, jclass, jobject pending_ref_buf) { + auto pending_ref = reinterpret_cast(env->GetDirectBufferAddress(pending_ref_buf)); + if (!pending_ref) { + throw InvalidInputException("Invalid pending query"); + } + + auto res_ref = make_uniq(); + res_ref->res = pending_ref->pending->Execute(); + if (res_ref->res->HasError()) { + std::string error_msg = std::string(res_ref->res->GetError()); + duckdb::ExceptionType error_type = res_ref->res->GetErrorType(); + res_ref->res = nullptr; + jclass exc_type = duckdb::ExceptionType::INTERRUPT == error_type ? J_SQLTimeoutException : J_SQLException; + env->ThrowNew(exc_type, error_msg.c_str()); + return nullptr; + } + return env->NewDirectByteBuffer(res_ref.release(), 0); +} + void _duckdb_jdbc_release(JNIEnv *env, jclass, jobject stmt_ref_buf) { if (nullptr == stmt_ref_buf) { return; } - auto stmt_ref = (StatementHolder *)env->GetDirectBufferAddress(stmt_ref_buf); + auto stmt_ref = reinterpret_cast(env->GetDirectBufferAddress(stmt_ref_buf)); if (stmt_ref) { delete stmt_ref; } } +void _duckdb_jdbc_release_pending(JNIEnv *env, jclass, jobject pending_ref_buf) { + if (nullptr == pending_ref_buf) { + return; + } + auto pending_ref = reinterpret_cast(env->GetDirectBufferAddress(pending_ref_buf)); + if (pending_ref) { + delete pending_ref; + } +} + void _duckdb_jdbc_free_result(JNIEnv *env, jclass, jobject res_ref_buf) { if (nullptr == res_ref_buf) { return; } - auto res_ref = (ResultHolder *)env->GetDirectBufferAddress(res_ref_buf); + auto res_ref = reinterpret_cast(env->GetDirectBufferAddress(res_ref_buf)); if (res_ref) { delete res_ref; } diff --git a/src/jni/functions.cpp b/src/jni/functions.cpp index 0f31cd4b3..2bff3ee86 100644 --- a/src/jni/functions.cpp +++ b/src/jni/functions.cpp @@ -131,6 +131,17 @@ JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1prepare(JNI } } +JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1pending_1query(JNIEnv * env, jclass param0, jobject param1, jbyteArray param2) { + try { + return _duckdb_jdbc_pending_query(env, param0, param1, param2); + } catch (const std::exception &e) { + duckdb::ErrorData error(e); + ThrowJNI(env, error.Message().c_str()); + + return nullptr; + } +} + JNIEXPORT void JNICALL Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1release(JNIEnv * env, jclass param0, jobject param1) { try { return _duckdb_jdbc_release(env, param0, param1); @@ -141,6 +152,16 @@ JNIEXPORT void JNICALL Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1release(JNIEnv } } +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1release_1pending(JNIEnv * env, jclass param0, jobject param1) { + try { + _duckdb_jdbc_release_pending(env, param0, param1); + } catch (const std::exception &e) { + duckdb::ErrorData error(e); + ThrowJNI(env, error.Message().c_str()); + + } +} + JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1query_1result_1meta(JNIEnv * env, jclass param0, jobject param1) { try { return _duckdb_jdbc_query_result_meta(env, param0, param1); @@ -174,6 +195,17 @@ JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1execute(JNI } } +JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1execute_1pending(JNIEnv * env, jclass param0, jobject param1) { + try { + return _duckdb_jdbc_execute_pending(env, param0, param1); + } catch (const std::exception &e) { + duckdb::ErrorData error(e); + ThrowJNI(env, error.Message().c_str()); + + return nullptr; + } +} + JNIEXPORT void JNICALL Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1free_1result(JNIEnv * env, jclass param0, jobject param1) { try { return _duckdb_jdbc_free_result(env, param0, param1); diff --git a/src/jni/functions.hpp b/src/jni/functions.hpp index 53924f660..e92e92bfc 100644 --- a/src/jni/functions.hpp +++ b/src/jni/functions.hpp @@ -57,10 +57,18 @@ jobject _duckdb_jdbc_prepare(JNIEnv * env, jclass param0, jobject param1, jbyteA JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1prepare(JNIEnv * env, jclass param0, jobject param1, jbyteArray param2); +jobject _duckdb_jdbc_pending_query(JNIEnv * env, jclass param0, jobject param1, jbyteArray param2); + +JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1pending_1query(JNIEnv * env, jclass param0, jobject param1, jbyteArray param2); + void _duckdb_jdbc_release(JNIEnv * env, jclass param0, jobject param1); JNIEXPORT void JNICALL Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1release(JNIEnv * env, jclass param0, jobject param1); +void _duckdb_jdbc_release_pending(JNIEnv * env, jclass param0, jobject param1); + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1release_1pending(JNIEnv * env, jclass param0, jobject param1); + jobject _duckdb_jdbc_query_result_meta(JNIEnv * env, jclass param0, jobject param1); JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1query_1result_1meta(JNIEnv * env, jclass param0, jobject param1); @@ -73,6 +81,10 @@ jobject _duckdb_jdbc_execute(JNIEnv * env, jclass param0, jobject param1, jobjec JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1execute(JNIEnv * env, jclass param0, jobject param1, jobjectArray param2); +jobject _duckdb_jdbc_execute_pending(JNIEnv * env, jclass param0, jobject param1); + +JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1execute_1pending(JNIEnv * env, jclass param0, jobject param1); + void _duckdb_jdbc_free_result(JNIEnv * env, jclass param0, jobject param1); JNIEXPORT void JNICALL Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1free_1result(JNIEnv * env, jclass param0, jobject param1); diff --git a/src/jni/holders.hpp b/src/jni/holders.hpp index d099dabb1..74e0744cf 100644 --- a/src/jni/holders.hpp +++ b/src/jni/holders.hpp @@ -46,6 +46,10 @@ struct StatementHolder { duckdb::unique_ptr stmt; }; +struct PendingHolder { + duckdb::unique_ptr pending; +}; + struct ResultHolder { duckdb::unique_ptr res; duckdb::unique_ptr chunk; diff --git a/src/main/java/org/duckdb/DuckDBConnection.java b/src/main/java/org/duckdb/DuckDBConnection.java index 238f8f9f3..280f7d58f 100644 --- a/src/main/java/org/duckdb/DuckDBConnection.java +++ b/src/main/java/org/duckdb/DuckDBConnection.java @@ -37,6 +37,7 @@ public final class DuckDBConnection implements java.sql.Connection { ByteBuffer connRef; final ReentrantLock connRefLock = new ReentrantLock(); + final LinkedHashSet pendingQueries = new LinkedHashSet<>(); final LinkedHashSet preparedStatements = new LinkedHashSet<>(); final LinkedHashSet appenders = new LinkedHashSet<>(); volatile boolean closing; @@ -145,6 +146,14 @@ public void close() throws SQLException { // suppress } + // Last pending query created is first deleted + List pendingList = new ArrayList<>(pendingQueries); + Collections.reverse(pendingList); + for (DuckDBPendingQuery pending : pendingList) { + pending.close(); + } + pendingQueries.clear(); + // Last statement created is first deleted List psList = new ArrayList<>(preparedStatements); Collections.reverse(psList); diff --git a/src/main/java/org/duckdb/DuckDBNative.java b/src/main/java/org/duckdb/DuckDBNative.java index c334526bb..9570c2c4d 100644 --- a/src/main/java/org/duckdb/DuckDBNative.java +++ b/src/main/java/org/duckdb/DuckDBNative.java @@ -173,6 +173,12 @@ private static void loadFromCurrentJarDir(String libName) throws Exception { // returns res_ref result reference object static native ByteBuffer duckdb_jdbc_execute(ByteBuffer stmt_ref, Object[] params) throws SQLException; + static native ByteBuffer duckdb_jdbc_pending_query(ByteBuffer conn_ref, byte[] query) throws SQLException; + + static native ByteBuffer duckdb_jdbc_execute_pending(ByteBuffer pending_ref) throws SQLException; + + static native void duckdb_jdbc_release_pending(ByteBuffer pending_ref) throws SQLException; + static native void duckdb_jdbc_free_result(ByteBuffer res_ref); static native DuckDBVector[] duckdb_jdbc_fetch(ByteBuffer res_ref, ByteBuffer conn_ref) throws SQLException; diff --git a/src/main/java/org/duckdb/DuckDBPendingQuery.java b/src/main/java/org/duckdb/DuckDBPendingQuery.java new file mode 100644 index 000000000..146754b69 --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBPendingQuery.java @@ -0,0 +1,50 @@ +package org.duckdb; + +import java.nio.ByteBuffer; +import java.sql.SQLException; +import java.util.concurrent.locks.ReentrantLock; + +class DuckDBPendingQuery { + private DuckDBConnection conn; + ByteBuffer pendingRef = null; + final ReentrantLock pendingRefLock = new ReentrantLock(); + + DuckDBPendingQuery(DuckDBConnection conn, ByteBuffer pendingRef) { + this.conn = conn; + this.pendingRef = pendingRef; + this.conn.connRefLock.lock(); + try { + this.conn.pendingQueries.add(this); + } finally { + this.conn.connRefLock.unlock(); + } + } + + void close() throws SQLException { + if (pendingRef == null) { + return; + } + pendingRefLock.lock(); + try { + if (pendingRef == null) { + return; + } + DuckDBNative.duckdb_jdbc_release_pending(pendingRef); + pendingRef = null; + } finally { + pendingRefLock.unlock(); + } + + // Untrack pending query from parent connection, + // if 'closing' flag is set it means that the parent connection itself + // is being closed and we don't need to untrack this instance + if (!conn.closing) { + conn.connRefLock.lock(); + try { + conn.pendingQueries.remove(this); + } finally { + conn.connRefLock.unlock(); + } + } + } +} diff --git a/src/main/java/org/duckdb/DuckDBPreparedStatement.java b/src/main/java/org/duckdb/DuckDBPreparedStatement.java index dc5e45330..0df54c35e 100644 --- a/src/main/java/org/duckdb/DuckDBPreparedStatement.java +++ b/src/main/java/org/duckdb/DuckDBPreparedStatement.java @@ -47,7 +47,8 @@ public class DuckDBPreparedStatement implements PreparedStatement { private DuckDBConnection conn; private ByteBuffer stmtRef = null; - final ReentrantLock stmtRefLock = new ReentrantLock(); + private final ReentrantLock stmtRefLock = new ReentrantLock(); + private String query = null; volatile boolean closeOnCompletion = false; private DuckDBResultSet selectResult = null; @@ -61,7 +62,7 @@ public class DuckDBPreparedStatement implements PreparedStatement { private final List batchedParams = new ArrayList<>(); private final List batchedStatements = new ArrayList<>(); private Boolean isBatch = false; - private Boolean isPreparedStatement = false; + private final Boolean isPreparedStatement; private int queryTimeoutSeconds = 0; private ScheduledFuture cancelQueryFuture = null; @@ -70,6 +71,7 @@ public DuckDBPreparedStatement(DuckDBConnection conn) throws SQLException { throw new SQLException("connection parameter cannot be null"); } this.conn = conn; + this.isPreparedStatement = false; } public DuckDBPreparedStatement(DuckDBConnection conn, String sql) throws SQLException { @@ -116,6 +118,11 @@ private void prepare(String sql) throws SQLException { throw new SQLException("sql query parameter cannot be null"); } + if (!this.isPreparedStatement) { + this.query = sql; + return; + } + stmtRefLock.lock(); try { checkOpen(); @@ -160,6 +167,33 @@ private void prepare(String sql) throws SQLException { } } + private DirectQueryResult executeDirect() throws SQLException { + DuckDBPendingQuery pending = null; + + // stmtRef lock is being held + conn.connRefLock.lock(); + try { + conn.checkOpen(); + ByteBuffer pendingRef = DuckDBNative.duckdb_jdbc_pending_query(conn.connRef, query.getBytes(UTF_8)); + pending = new DuckDBPendingQuery(conn, pendingRef); + // need to track the statement too to release the results + conn.preparedStatements.add(this); + } finally { + conn.connRefLock.unlock(); + } + + pending.pendingRefLock.lock(); + try { + if (pending.pendingRef == null) { + throw new SQLException("Connection was closed"); + } + ByteBuffer resultRef = DuckDBNative.duckdb_jdbc_execute_pending(pending.pendingRef); + return new DirectQueryResult(resultRef, pending); + } finally { + pending.pendingRefLock.unlock(); + } + } + @Override public boolean execute() throws SQLException { checkOpen(); @@ -171,6 +205,8 @@ public boolean execute() throws SQLException { connLock.unlock(); ByteBuffer resultRef = null; + DuckDBPendingQuery pendingQuery = null; + boolean queryFailed = false; stmtRefLock.lock(); try { @@ -186,19 +222,16 @@ public boolean execute() throws SQLException { startTransaction(); } - if (queryTimeoutSeconds > 0) { - cleanupCancelQueryTask(); - try { - if (!DuckDBDriver.scheduler.isShutdown()) { - cancelQueryFuture = - DuckDBDriver.scheduler.schedule(new CancelQueryTask(), queryTimeoutSeconds, SECONDS); - } - } catch (RejectedExecutionException e) { - // no-op, scheduler was shut down concurrently - } + scheduleCancelTask(); + + if (isPreparedStatement) { + resultRef = DuckDBNative.duckdb_jdbc_execute(stmtRef, params); + } else { + DirectQueryResult dqr = executeDirect(); + resultRef = dqr.resultRef; + pendingQuery = dqr.pendingQuery; } - resultRef = DuckDBNative.duckdb_jdbc_execute(stmtRef, params); cleanupCancelQueryTask(); DuckDBResultSetMetaData resultMeta = DuckDBNative.duckdb_jdbc_query_result_meta(resultRef); selectResult = new DuckDBResultSet(conn, this, resultMeta, resultRef); @@ -207,18 +240,23 @@ public boolean execute() throws SQLException { returnsNothing = resultMeta.return_type.equals(NOTHING); } catch (SQLException e) { - // Delete result set that might have been allocated - if (selectResult != null) { - selectResult.close(); - } else if (resultRef != null) { - DuckDBNative.duckdb_jdbc_free_result(resultRef); - resultRef = null; - } - close(); + queryFailed = true; throw e; } finally { stmtRefLock.unlock(); + this.query = null; + if (null != pendingQuery) { + pendingQuery.close(); + } + if (queryFailed) { + if (null != selectResult) { + selectResult.close(); + } else if (null != resultRef) { + DuckDBNative.duckdb_jdbc_free_result(resultRef); + } + close(); + } } if (returnsChangedRows) { @@ -370,12 +408,17 @@ public void close() throws SQLException { if (isClosed()) { return; } + + DuckDBConnection connLocal = conn; + stmtRefLock.lock(); try { if (isClosed()) { return; } + cleanupCancelQueryTask(); + if (selectResult != null) { selectResult.close(); selectResult = null; @@ -383,25 +426,25 @@ public void close() throws SQLException { if (stmtRef != null) { // Delete prepared statement DuckDBNative.duckdb_jdbc_release(stmtRef); - - // Untrack prepared statement from parent connection, - // if 'closing' flag is set it means that the parent connection itself - // is being closed and we don't need to untrack this instance from the statement. - if (!conn.closing) { - conn.connRefLock.lock(); - try { - conn.preparedStatements.remove(this); - } finally { - conn.connRefLock.unlock(); - } - } - stmtRef = null; } + conn = null; // we use this as a check for closed-ness } finally { stmtRefLock.unlock(); } + + // Untrack prepared statement from parent connection, + // if 'closing' flag is set it means that the parent connection itself + // is being closed and we don't need to untrack this instance + if (!connLocal.closing) { + connLocal.connRefLock.lock(); + try { + connLocal.preparedStatements.remove(this); + } finally { + connLocal.connRefLock.unlock(); + } + } } @Override @@ -526,9 +569,6 @@ public ResultSet getResultSet() throws SQLException { if (isClosed()) { throw new SQLException("Statement was closed"); } - if (stmtRef == null) { - throw new SQLException("Prepare something first"); - } if (!returnsResultSet) { return null; @@ -544,7 +584,7 @@ private long getUpdateCountInternal() throws SQLException { if (isClosed()) { throw new SQLException("Statement was closed"); } - if (stmtRef == null) { + if (selectResult == null) { // It is not required by JDBC spec to return anything in this case, // but clients can call this method before preparing/executing the query return -1; @@ -1236,8 +1276,12 @@ private void checkOpen() throws SQLException { } private void checkPrepared() throws SQLException { - if (stmtRef == null) { - throw new SQLException("Prepare something first"); + if (isPreparedStatement) { + if (stmtRef == null) { + throw new SQLException("Prepare something first"); + } + } else if (query == null) { + throw new SQLException("Query to execute was not specified"); } } @@ -1289,6 +1333,19 @@ private void cleanupCancelQueryTask() { } } + private void scheduleCancelTask() { + if (queryTimeoutSeconds <= 0 || DuckDBDriver.scheduler.isShutdown()) { + return; + } + cleanupCancelQueryTask(); + try { + this.cancelQueryFuture = + DuckDBDriver.scheduler.schedule(new CancelQueryTask(), queryTimeoutSeconds, SECONDS); + } catch (RejectedExecutionException e) { + // no-op, scheduler was shut down concurrently + } + } + private class CancelQueryTask implements Runnable { @Override public void run() { @@ -1302,4 +1359,14 @@ public void run() { } } } + + private static class DirectQueryResult { + final ByteBuffer resultRef; + final DuckDBPendingQuery pendingQuery; + + private DirectQueryResult(ByteBuffer resultRef, DuckDBPendingQuery pendingQuery) { + this.resultRef = resultRef; + this.pendingQuery = pendingQuery; + } + } } diff --git a/src/test/java/org/duckdb/TestClosure.java b/src/test/java/org/duckdb/TestClosure.java index c701c7afe..af6fd83df 100644 --- a/src/test/java/org/duckdb/TestClosure.java +++ b/src/test/java/org/duckdb/TestClosure.java @@ -9,6 +9,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.concurrent.ThreadLocalRandom; public class TestClosure { @@ -27,6 +28,20 @@ public static void test_unclosed_statement_does_not_hang() throws Exception { assertTrue(new File(dbName).delete()); } + public static void test_unclosed_prepred_statement_does_not_hang() throws Exception { + String dbName = "test_issue_101.db"; + String url = JDBC_URL + dbName; + Connection conn = DriverManager.getConnection(url); + PreparedStatement ps = conn.prepareStatement("select 42"); + ps.execute(); + // statement not closed explicitly + conn.close(); + assertTrue(ps.isClosed()); + Connection connOther = DriverManager.getConnection(url); + connOther.close(); + assertTrue(new File(dbName).delete()); + } + public static void test_result_set_auto_closed() throws Exception { try (Connection conn = DriverManager.getConnection(JDBC_URL)) { Statement stmt = conn.createStatement(); @@ -38,6 +53,17 @@ public static void test_result_set_auto_closed() throws Exception { } } + public static void test_result_set_auto_closed_prepared() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL)) { + PreparedStatement ps = conn.prepareStatement("select 42"); + ResultSet rs1 = ps.executeQuery(); + ResultSet rs2 = ps.executeQuery(); + assertTrue(rs1.isClosed()); + ps.close(); + assertTrue(rs2.isClosed()); + } + } + public static void test_statements_auto_closed_on_conn_close() throws Exception { Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt1 = conn.createStatement(); @@ -73,6 +99,16 @@ public static void test_results_auto_closed_on_conn_close() throws Exception { assertTrue(stmt.isClosed()); } + public static void test_results_auto_closed_on_conn_close_prepared() throws Exception { + Connection conn = DriverManager.getConnection(JDBC_URL); + PreparedStatement ps = conn.prepareStatement("select 42"); + ResultSet rs = ps.executeQuery(); + rs.next(); + conn.close(); + assertTrue(rs.isClosed()); + assertTrue(ps.isClosed()); + } + public static void test_statement_auto_closed_on_completion() throws Exception { try (Connection conn = DriverManager.getConnection(JDBC_URL)) { Statement stmt = conn.createStatement(); @@ -85,6 +121,18 @@ public static void test_statement_auto_closed_on_completion() throws Exception { } } + public static void test_prepared_statement_auto_closed_on_completion() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL)) { + PreparedStatement ps = conn.prepareStatement("select 42"); + ps.closeOnCompletion(); + assertTrue(ps.isCloseOnCompletion()); + try (ResultSet rs = ps.executeQuery()) { + rs.next(); + } + assertTrue(ps.isClosed()); + } + } + public static void test_long_query_conn_close() throws Exception { Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement(); @@ -115,6 +163,34 @@ public static void test_long_query_conn_close() throws Exception { assertTrue(conn.isClosed()); } + public static void test_long_query_conn_close_prepared() throws Exception { + Connection conn = DriverManager.getConnection(JDBC_URL); + Statement stmt = conn.createStatement(); + stmt.execute("CREATE TABLE test_fib1(i bigint, p double, f double)"); + stmt.execute("INSERT INTO test_fib1 values(1, 0, 1)"); + PreparedStatement ps = conn.prepareStatement( + "WITH RECURSIVE cte AS (" + + "SELECT * from test_fib1 UNION ALL SELECT cte.i + 1, cte.f, cte.p + cte.f from cte WHERE cte.i < 150000) " + + "SELECT avg(f) FROM cte"); + long start = System.currentTimeMillis(); + Thread th = new Thread(() -> { + try { + Thread.sleep(1000); + conn.close(); + } catch (Exception e) { + e.printStackTrace(); + } + }); + th.start(); + assertThrows(ps::executeQuery, SQLException.class); + th.join(); + long elapsed = System.currentTimeMillis() - start; + assertTrue(elapsed < 2000); + assertTrue(stmt.isClosed()); + assertTrue(ps.isClosed()); + assertTrue(conn.isClosed()); + } + public static void test_long_query_stmt_close() throws Exception { try (Connection conn = DriverManager.getConnection(JDBC_URL)) { Statement stmt = conn.createStatement(); @@ -147,6 +223,37 @@ public static void test_long_query_stmt_close() throws Exception { } } + public static void test_long_query_prepared_stmt_close() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement();) { + stmt.execute("CREATE TABLE test_fib1(i bigint, p double, f double)"); + stmt.execute("INSERT INTO test_fib1 values(1, 0, 1)"); + PreparedStatement ps = conn.prepareStatement( + "WITH RECURSIVE cte AS (" + + + "SELECT * from test_fib1 UNION ALL SELECT cte.i + 1, cte.f, cte.p + cte.f from cte WHERE cte.i < 150000) " + + "SELECT avg(f) FROM cte" + + ); + long start = System.currentTimeMillis(); + Thread th = new Thread(() -> { + try { + Thread.sleep(1000); + ps.cancel(); + ps.close(); + } catch (Exception e) { + e.printStackTrace(); + } + }); + th.start(); + assertThrows(ps::executeQuery, SQLException.class); + th.join(); + long elapsed = System.currentTimeMillis() - start; + assertTrue(elapsed < 2000); + assertTrue(ps.isClosed()); + assertFalse(conn.isClosed()); + } + } + public static void test_conn_close_no_crash() throws Exception { ExecutorService executor = Executors.newSingleThreadExecutor(); for (int i = 0; i < 1 << 7; i++) { @@ -154,14 +261,52 @@ public static void test_conn_close_no_crash() throws Exception { Statement stmt = conn.createStatement(); Future future = executor.submit(() -> { try { + long millis = ThreadLocalRandom.current().nextInt(1, 20); + Thread.sleep(millis); + conn.close(); + } catch (SQLException e) { + fail(); + } catch (InterruptedException e) { + throw new RuntimeException(); + } + }); + try { + long millis = ThreadLocalRandom.current().nextInt(1, 20); + Thread.sleep(millis); + stmt.executeQuery("SELECT 42"); + } catch (SQLException e) { + // suppress + } finally { + stmt.close(); + } + future.get(); + } + } + + public static void test_conn_close_no_crash_prepared() throws Exception { + ExecutorService executor = Executors.newSingleThreadExecutor(); + for (int i = 0; i < 1 << 7; i++) { + Connection conn = DriverManager.getConnection(JDBC_URL); + PreparedStatement ps = conn.prepareStatement("SELECT 42"); + Future future = executor.submit(() -> { + try { + long millis = ThreadLocalRandom.current().nextInt(1, 20); + Thread.sleep(millis); conn.close(); } catch (SQLException e) { fail(); + } catch (InterruptedException e) { + throw new RuntimeException(); } }); try { - stmt.executeQuery("select 42"); + long millis = ThreadLocalRandom.current().nextInt(1, 20); + Thread.sleep(millis); + ps.executeQuery(); } catch (SQLException e) { + // suppress + } finally { + ps.close(); } future.get(); } @@ -188,6 +333,27 @@ public static void test_stmt_close_no_crash() throws Exception { } } + public static void test_prepared_stmt_close_no_crash() throws Exception { + ExecutorService executor = Executors.newSingleThreadExecutor(); + try (Connection conn = DriverManager.getConnection(JDBC_URL)) { + for (int i = 0; i < 1 << 10; i++) { + PreparedStatement ps = conn.prepareStatement("select 42"); + Future future = executor.submit(() -> { + try { + ps.close(); + } catch (SQLException e) { + fail(); + } + }); + try { + ps.executeQuery(); + } catch (SQLException e) { + } + future.get(); + } + } + } + public static void test_results_close_no_crash() throws Exception { ExecutorService executor = Executors.newSingleThreadExecutor(); try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement()) { @@ -262,6 +428,38 @@ public static void test_results_fetch_no_hang() throws Exception { } } + @SuppressWarnings("try") + public static void test_results_fetch_no_hang_prepared() throws Exception { + ExecutorService executor = Executors.newSingleThreadExecutor(); + Properties config = new Properties(); + config.put(DuckDBDriver.JDBC_STREAM_RESULTS, true); + long rowsCount = 1 << 24; + int iterations = 1; + for (int i = 0; i < iterations; i++) { + try (Connection conn = DriverManager.getConnection(JDBC_URL, config); + PreparedStatement ps = + conn.prepareStatement("SELECT i, i::VARCHAR FROM range(0, " + rowsCount + ") AS t(i)"); + ResultSet rs = ps.executeQuery()) { + executor.submit(() -> { + try { + Thread.sleep(100); + conn.close(); + } catch (Exception e) { + e.printStackTrace(); + } + }); + long[] resultsCount = new long[1]; + assertThrows(() -> { + while (rs.next()) { + resultsCount[0]++; + } + }, SQLException.class); + assertTrue(resultsCount[0] > 0); + assertTrue(resultsCount[0] < rowsCount); + } + } + } + public static void test_stmt_can_only_cancel_self() throws Exception { try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt1 = conn.createStatement(); Statement stmt2 = conn.createStatement()) { @@ -296,6 +494,41 @@ public static void test_stmt_can_only_cancel_self() throws Exception { } } + public static void test_prepared_stmt_can_only_cancel_self() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt1 = conn.createStatement();) { + stmt1.execute("CREATE TABLE test_fib1(i bigint, p double, f double)"); + stmt1.execute("INSERT INTO test_fib1 values(1, 0, 1)"); + try ( + PreparedStatement ps1 = conn.prepareStatement("SELECT 42"); + PreparedStatement ps2 = conn.prepareStatement( + "WITH RECURSIVE cte AS (" + + + "SELECT * from test_fib1 UNION ALL SELECT cte.i + 1, cte.f, cte.p + cte.f from cte WHERE cte.i < 50000) " + + "SELECT avg(f) FROM cte")) { + long start = System.currentTimeMillis(); + Thread th = new Thread(() -> { + try { + Thread.sleep(200); + ps1.cancel(); + } catch (Exception e) { + e.printStackTrace(); + } + }); + th.start(); + try (ResultSet rs = ps2.executeQuery()) { + rs.next(); + assertTrue(rs.getDouble(1) > 0); + } + th.join(); + long elapsed = System.currentTimeMillis() - start; + assertTrue(elapsed > 1000); + assertFalse(conn.isClosed()); + assertFalse(ps1.isClosed()); + assertFalse(ps2.isClosed()); + } + } + } + public static void test_stmt_query_timeout() throws Exception { try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement()) { stmt.setQueryTimeout(1); @@ -323,6 +556,33 @@ public static void test_stmt_query_timeout() throws Exception { } } + public static void test_prepared_stmt_query_timeout() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement();) { + stmt.execute("CREATE TABLE test_fib1(i bigint, p double, f double)"); + stmt.execute("INSERT INTO test_fib1 values(1, 0, 1)"); + try ( + PreparedStatement ps = conn.prepareStatement( + "WITH RECURSIVE cte AS (" + + + "SELECT * from test_fib1 UNION ALL SELECT cte.i + 1, cte.f, cte.p + cte.f from cte WHERE cte.i < 150000) " + + "SELECT avg(f) FROM cte")) { + ps.setQueryTimeout(1); + long start = System.currentTimeMillis(); + assertThrows(ps::executeQuery, SQLTimeoutException.class); + long elapsed = System.currentTimeMillis() - start; + assertTrue(elapsed < 1500); + assertFalse(conn.isClosed()); + assertTrue(ps.isClosed()); + assertEquals(DuckDBDriver.scheduler.getQueue().size(), 0); + } + } + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement()) { + stmt.setQueryTimeout(1); + assertThrows(() -> { stmt.execute("FAIL"); }, SQLException.class); + assertEquals(DuckDBDriver.scheduler.getQueue().size(), 0); + } + } + public static void manual_test_set_query_timeout_wo_scheduler() throws Exception { assertTrue(DuckDBDriver.shutdownQueryCancelScheduler()); assertFalse(DuckDBDriver.shutdownQueryCancelScheduler());