From e860720568a55dcaaf3378b28b1a1a83d14aa609 Mon Sep 17 00:00:00 2001 From: Alex Kasko Date: Tue, 3 Jun 2025 13:01:13 +0100 Subject: [PATCH] Add support for query timeouts (1.3) This is a backport of the PR #247 to `v1.3-ossivalis` stable branch. This change implements `Statement#setQueryTimeout()` method. It is implemented by scheduling a background task and calling `Statement#cancel()` when timeout expires. Timeouted statement has the same behaviour as it would be if cancelled manually - `SQLException` is thrown and the statement is closed. Timeout is applied for all `execute*` calls. For `executeBatch()` it is applied separately for every single query in a batch. Testing: new test added. Fixes: #212 --- src/jni/duckdb_java.cpp | 6 ++- src/jni/refs.cpp | 2 + src/jni/refs.hpp | 1 + src/main/java/org/duckdb/DuckDBDriver.java | 9 ++++- .../org/duckdb/DuckDBPreparedStatement.java | 39 ++++++++++++++++++- src/test/java/org/duckdb/TestClosure.java | 31 +++++++++++++-- src/test/java/org/duckdb/TestDuckDBJDBC.java | 2 +- 7 files changed, 82 insertions(+), 8 deletions(-) diff --git a/src/jni/duckdb_java.cpp b/src/jni/duckdb_java.cpp index c4362883d..dd2364fcf 100644 --- a/src/jni/duckdb_java.cpp +++ b/src/jni/duckdb_java.cpp @@ -240,9 +240,11 @@ jobject _duckdb_jdbc_execute(JNIEnv *env, jclass, jobject stmt_ref_buf, jobjectA res_ref->res = stmt_ref->stmt->Execute(duckdb_params, stream_results); if (res_ref->res->HasError()) { - string error_msg = string(res_ref->res->GetError()); + std::string error_msg = std::string(res_ref->res->GetError()); + duckdb::ExceptionType error_type = res_ref->res->GetErrorType(); res_ref->res = nullptr; - ThrowJNI(env, error_msg.c_str()); + jclass exc_type = duckdb::ExceptionType::INTERRUPT == error_type ? J_SQLTimeoutException : J_SQLException; + env->ThrowNew(exc_type, error_msg.c_str()); return nullptr; } return env->NewDirectByteBuffer(res_ref.release(), 0); diff --git a/src/jni/refs.cpp b/src/jni/refs.cpp index 4c05515e4..f6ae8146d 100644 --- a/src/jni/refs.cpp +++ b/src/jni/refs.cpp @@ -18,6 +18,7 @@ jmethodID J_String_getBytes; jclass J_Throwable; jmethodID J_Throwable_getMessage; jclass J_SQLException; +jclass J_SQLTimeoutException; jclass J_Bool; jclass J_Byte; @@ -178,6 +179,7 @@ void create_refs(JNIEnv *env) { J_Throwable = make_class_ref(env, "java/lang/Throwable"); J_Throwable_getMessage = get_method_id(env, J_Throwable, "getMessage", "()Ljava/lang/String;"); J_SQLException = make_class_ref(env, "java/sql/SQLException"); + J_SQLTimeoutException = make_class_ref(env, "java/sql/SQLTimeoutException"); J_Bool = make_class_ref(env, "java/lang/Boolean"); J_Byte = make_class_ref(env, "java/lang/Byte"); diff --git a/src/jni/refs.hpp b/src/jni/refs.hpp index bd17ee282..5bfeb2639 100644 --- a/src/jni/refs.hpp +++ b/src/jni/refs.hpp @@ -15,6 +15,7 @@ extern jmethodID J_String_getBytes; extern jclass J_Throwable; extern jmethodID J_Throwable_getMessage; extern jclass J_SQLException; +extern jclass J_SQLTimeoutException; extern jclass J_Bool; extern jclass J_Byte; diff --git a/src/main/java/org/duckdb/DuckDBDriver.java b/src/main/java/org/duckdb/DuckDBDriver.java index 694c050a7..a027bc119 100644 --- a/src/main/java/org/duckdb/DuckDBDriver.java +++ b/src/main/java/org/duckdb/DuckDBDriver.java @@ -6,6 +6,8 @@ import java.sql.SQLException; import java.sql.SQLFeatureNotSupportedException; import java.util.Properties; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.ThreadFactory; import java.util.logging.Logger; public class DuckDBDriver implements java.sql.Driver { @@ -14,11 +16,16 @@ public class DuckDBDriver implements java.sql.Driver { public static final String DUCKDB_USER_AGENT_PROPERTY = "custom_user_agent"; public static final String JDBC_STREAM_RESULTS = "jdbc_stream_results"; + static final ScheduledThreadPoolExecutor scheduler; + static { try { DriverManager.registerDriver(new DuckDBDriver()); + ThreadFactory tf = r -> new Thread(r, "duckdb-query-cancel-scheduler-thread"); + scheduler = new ScheduledThreadPoolExecutor(1, tf); + scheduler.setRemoveOnCancelPolicy(true); } catch (SQLException e) { - e.printStackTrace(); + throw new RuntimeException(e); } } diff --git a/src/main/java/org/duckdb/DuckDBPreparedStatement.java b/src/main/java/org/duckdb/DuckDBPreparedStatement.java index eff4ca980..69f154848 100644 --- a/src/main/java/org/duckdb/DuckDBPreparedStatement.java +++ b/src/main/java/org/duckdb/DuckDBPreparedStatement.java @@ -2,6 +2,7 @@ import static java.nio.charset.StandardCharsets.US_ASCII; import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.concurrent.TimeUnit.SECONDS; import static org.duckdb.StatementReturnType.*; import static org.duckdb.io.IOUtils.*; @@ -37,6 +38,7 @@ import java.util.ArrayList; import java.util.Calendar; import java.util.List; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; @@ -59,6 +61,8 @@ public class DuckDBPreparedStatement implements PreparedStatement { private final List batchedStatements = new ArrayList<>(); private Boolean isBatch = false; private Boolean isPreparedStatement = false; + private int queryTimeoutSeconds = 0; + private ScheduledFuture cancelQueryFuture = null; public DuckDBPreparedStatement(DuckDBConnection conn) throws SQLException { if (conn == null) { @@ -180,7 +184,14 @@ private boolean execute(boolean startTransaction) throws SQLException { startTransaction(); } + if (queryTimeoutSeconds > 0) { + cleanupCancelQueryTask(); + cancelQueryFuture = + DuckDBDriver.scheduler.schedule(new CancelQueryTask(), queryTimeoutSeconds, SECONDS); + } + resultRef = DuckDBNative.duckdb_jdbc_execute(stmtRef, params); + cleanupCancelQueryTask(); DuckDBResultSetMetaData resultMeta = DuckDBNative.duckdb_jdbc_query_result_meta(resultRef); selectResult = new DuckDBResultSet(conn, this, resultMeta, resultRef); returnsResultSet = resultMeta.return_type.equals(QUERY_RESULT); @@ -356,6 +367,7 @@ public void close() throws SQLException { if (isClosed()) { return; } + cleanupCancelQueryTask(); if (selectResult != null) { selectResult.close(); selectResult = null; @@ -436,12 +448,16 @@ public void setEscapeProcessing(boolean enable) throws SQLException { @Override public int getQueryTimeout() throws SQLException { checkOpen(); - return 0; + return queryTimeoutSeconds; } @Override public void setQueryTimeout(int seconds) throws SQLException { checkOpen(); + if (seconds < 0) { + throw new SQLException("Invalid negative timeout value: " + seconds); + } + this.queryTimeoutSeconds = seconds; } @Override @@ -1244,4 +1260,25 @@ private Lock getConnRefLock() throws SQLException { throw new SQLException(e); } } + + private void cleanupCancelQueryTask() { + if (cancelQueryFuture != null) { + cancelQueryFuture.cancel(false); + cancelQueryFuture = null; + } + } + + private class CancelQueryTask implements Runnable { + @Override + public void run() { + try { + if (DuckDBPreparedStatement.this.isClosed()) { + return; + } + DuckDBPreparedStatement.this.cancel(); + } catch (SQLException e) { + // suppress + } + } + } } diff --git a/src/test/java/org/duckdb/TestClosure.java b/src/test/java/org/duckdb/TestClosure.java index fa25b414b..558f81ce1 100644 --- a/src/test/java/org/duckdb/TestClosure.java +++ b/src/test/java/org/duckdb/TestClosure.java @@ -77,7 +77,6 @@ public static void test_statement_auto_closed_on_completion() throws Exception { public static void test_long_query_conn_close() throws Exception { Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement(); - stmt.execute("DROP TABLE IF EXISTS test_fib1"); stmt.execute("CREATE TABLE test_fib1(i bigint, p double, f double)"); stmt.execute("INSERT INTO test_fib1 values(1, 0, 1)"); long start = System.currentTimeMillis(); @@ -108,7 +107,6 @@ public static void test_long_query_conn_close() throws Exception { public static void test_long_query_stmt_close() throws Exception { try (Connection conn = DriverManager.getConnection(JDBC_URL)) { Statement stmt = conn.createStatement(); - stmt.execute("DROP TABLE IF EXISTS test_fib1"); stmt.execute("CREATE TABLE test_fib1(i bigint, p double, f double)"); stmt.execute("INSERT INTO test_fib1 values(1, 0, 1)"); long start = System.currentTimeMillis(); @@ -272,7 +270,7 @@ public static void test_stmt_can_only_cancel_self() throws Exception { ResultSet rs = stmt2.executeQuery( "WITH RECURSIVE cte AS (" + - "SELECT * from test_fib1 UNION ALL SELECT cte.i + 1, cte.f, cte.p + cte.f from cte WHERE cte.i < 40000) " + "SELECT * from test_fib1 UNION ALL SELECT cte.i + 1, cte.f, cte.p + cte.f from cte WHERE cte.i < 50000) " + "SELECT avg(f) FROM cte")) { rs.next(); assertTrue(rs.getDouble(1) > 0); @@ -285,4 +283,31 @@ public static void test_stmt_can_only_cancel_self() throws Exception { assertFalse(stmt2.isClosed()); } } + + public static void test_stmt_query_timeout() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement()) { + stmt.setQueryTimeout(1); + stmt.execute("CREATE TABLE test_fib1(i bigint, p double, f double)"); + stmt.execute("INSERT INTO test_fib1 values(1, 0, 1)"); + long start = System.currentTimeMillis(); + assertThrows( + () + -> stmt.executeQuery( + "WITH RECURSIVE cte AS (" + + + "SELECT * from test_fib1 UNION ALL SELECT cte.i + 1, cte.f, cte.p + cte.f from cte WHERE cte.i < 150000) " + + "SELECT avg(f) FROM cte"), + SQLTimeoutException.class); + long elapsed = System.currentTimeMillis() - start; + assertTrue(elapsed < 1500); + assertFalse(conn.isClosed()); + assertTrue(stmt.isClosed()); + assertEquals(DuckDBDriver.scheduler.getQueue().size(), 0); + } + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement()) { + stmt.setQueryTimeout(1); + assertThrows(() -> { stmt.execute("FAIL"); }, SQLException.class); + assertEquals(DuckDBDriver.scheduler.getQueue().size(), 0); + } + } } diff --git a/src/test/java/org/duckdb/TestDuckDBJDBC.java b/src/test/java/org/duckdb/TestDuckDBJDBC.java index 8b0fa538f..08908cb79 100644 --- a/src/test/java/org/duckdb/TestDuckDBJDBC.java +++ b/src/test/java/org/duckdb/TestDuckDBJDBC.java @@ -3455,7 +3455,7 @@ public static void test_query_progress() throws Exception { @Override public QueryProgress call() throws Exception { try { - Thread.sleep(1500); + Thread.sleep(2500); QueryProgress qp = stmt.getQueryProgress(); stmt.cancel(); return qp;