diff --git a/src/main/java/org/duckdb/DuckDBConnection.java b/src/main/java/org/duckdb/DuckDBConnection.java index 38c01d2fb..50d62d32c 100644 --- a/src/main/java/org/duckdb/DuckDBConnection.java +++ b/src/main/java/org/duckdb/DuckDBConnection.java @@ -40,8 +40,8 @@ public final class DuckDBConnection implements java.sql.Connection { final LinkedHashSet preparedStatements = new LinkedHashSet<>(); volatile boolean closing = false; - boolean autoCommit = true; - boolean transactionRunning; + volatile boolean autoCommit = true; + volatile boolean transactionRunning; final String url; private final boolean readOnly; diff --git a/src/main/java/org/duckdb/DuckDBPreparedStatement.java b/src/main/java/org/duckdb/DuckDBPreparedStatement.java index 12a2095cf..3e1f61144 100644 --- a/src/main/java/org/duckdb/DuckDBPreparedStatement.java +++ b/src/main/java/org/duckdb/DuckDBPreparedStatement.java @@ -79,16 +79,26 @@ public DuckDBPreparedStatement(DuckDBConnection conn, String sql) throws SQLExce prepare(sql); } - private void startTransaction() throws SQLException { + private boolean isConnAutoCommit() throws SQLException { checkOpen(); try { - if (this.conn.autoCommit || this.conn.transactionRunning) { - return; + return this.conn.autoCommit; + } catch (NullPointerException e) { + throw new SQLException(e); + } + } + + private boolean startTransaction() throws SQLException { + checkOpen(); + try { + if (this.conn.transactionRunning) { + return false; } this.conn.transactionRunning = true; // Start transaction via Statement try (Statement s = conn.createStatement()) { s.execute("BEGIN TRANSACTION;"); + return true; } } catch (NullPointerException e) { throw new SQLException(e); @@ -161,7 +171,7 @@ private boolean execute(boolean startTransaction) throws SQLException { } selectResult = null; - if (startTransaction) { + if (startTransaction && !isConnAutoCommit()) { startTransaction(); } @@ -576,41 +586,62 @@ public int[] executeBatch() throws SQLException { @Override public long[] executeLargeBatch() throws SQLException { checkOpen(); - try { - if (this.isPreparedStatement) { - return executeBatchedPreparedStatement(); - } else { - return executeBatchedStatements(); - } - } finally { - if (!isClosed()) { - clearBatch(); - } + if (this.isPreparedStatement) { + return executeBatchedPreparedStatement(); + } else { + return executeBatchedStatements(); } } private long[] executeBatchedPreparedStatement() throws SQLException { - long[] updateCounts = new long[this.batchedParams.size()]; + stmtRefLock.lock(); + try { + checkOpen(); + checkPrepared(); + + boolean tranStarted = startTransaction(); - startTransaction(); - for (int i = 0; i < this.batchedParams.size(); i++) { - params = this.batchedParams.get(i); - execute(false); - updateCounts[i] = getUpdateCountInternal(); + long[] updateCounts = new long[this.batchedParams.size()]; + for (int i = 0; i < this.batchedParams.size(); i++) { + params = this.batchedParams.get(i); + execute(false); + updateCounts[i] = getUpdateCountInternal(); + } + clearBatch(); + + if (tranStarted && isConnAutoCommit()) { + this.conn.commit(); + } + + return updateCounts; + } finally { + stmtRefLock.unlock(); } - return updateCounts; } private long[] executeBatchedStatements() throws SQLException { - long[] updateCounts = new long[this.batchedStatements.size()]; + stmtRefLock.lock(); + try { + checkOpen(); + + boolean tranStarted = startTransaction(); + + long[] updateCounts = new long[this.batchedStatements.size()]; + for (int i = 0; i < this.batchedStatements.size(); i++) { + prepare(this.batchedStatements.get(i)); + execute(false); + updateCounts[i] = getUpdateCountInternal(); + } + clearBatch(); - startTransaction(); - for (int i = 0; i < this.batchedStatements.size(); i++) { - prepare(this.batchedStatements.get(i)); - execute(false); - updateCounts[i] = getUpdateCountInternal(); + if (tranStarted && isConnAutoCommit()) { + this.conn.commit(); + } + + return updateCounts; + } finally { + stmtRefLock.unlock(); } - return updateCounts; } @Override diff --git a/src/test/java/org/duckdb/TestBatch.java b/src/test/java/org/duckdb/TestBatch.java index 9f7bb2611..ed57e612d 100644 --- a/src/test/java/org/duckdb/TestBatch.java +++ b/src/test/java/org/duckdb/TestBatch.java @@ -124,4 +124,86 @@ public static void test_prepared_statement_batch_exception() throws Exception { } } } + + public static void test_prepared_statement_batch_autocommit() throws Exception { + long count = 1 << 10; + try (Connection conn = DriverManager.getConnection(JDBC_URL)) { + assertTrue(conn.getAutoCommit()); + try (Statement stmt = conn.createStatement()) { + stmt.execute("CREATE TABLE tab1 (col1 BIGINT, col2 VARCHAR)"); + } + try (PreparedStatement ps = conn.prepareStatement("INSERT INTO tab1 VALUES(?, ?)")) { + for (long i = 0; i < count; i++) { + ps.setLong(1, i); + ps.setString(2, i + "foo"); + ps.addBatch(); + } + ps.executeBatch(); + } + try (Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT count(*) FROM tab1")) { + rs.next(); + assertEquals(rs.getLong(1), count); + } + } + } + + public static void test_statement_batch_autocommit() throws Exception { + long count = 1 << 10; + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement()) { + assertTrue(conn.getAutoCommit()); + stmt.execute("CREATE TABLE tab1 (col1 BIGINT, col2 VARCHAR)"); + for (long i = 0; i < count; i++) { + stmt.addBatch("INSERT INTO tab1 VALUES(" + i + ", '" + i + "foo')"); + } + stmt.executeBatch(); + try (ResultSet rs = stmt.executeQuery("SELECT count(*) FROM tab1")) { + rs.next(); + assertEquals(rs.getLong(1), count); + } + } + } + + public static void test_prepared_statement_batch_rollback() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL)) { + try (Statement stmt = conn.createStatement()) { + stmt.execute("CREATE TABLE tab1 (col1 BIGINT, col2 VARCHAR)"); + } + conn.setAutoCommit(false); + try (Statement stmt = conn.createStatement()) { + stmt.execute("INSERT INTO tab1 VALUES(-1, 'bar')"); + } + try (PreparedStatement ps = conn.prepareStatement("INSERT INTO tab1 VALUES(?, ?)")) { + for (long i = 0; i < 1 << 10; i++) { + ps.setLong(1, i); + ps.setString(2, i + "foo"); + ps.addBatch(); + } + ps.executeBatch(); + } + conn.rollback(); + try (Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT count(*) FROM tab1")) { + rs.next(); + assertEquals(rs.getLong(1), 0L); + } + } + } + + public static void test_statement_batch_rollback() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement()) { + stmt.execute("CREATE TABLE tab1 (col1 BIGINT, col2 VARCHAR)"); + conn.setAutoCommit(false); + stmt.execute("INSERT INTO tab1 VALUES(-1, 'bar')"); + for (long i = 0; i < 1 << 10; i++) { + stmt.addBatch("INSERT INTO tab1 VALUES(" + i + ", '" + i + "foo')"); + } + stmt.executeBatch(); + conn.rollback(); + try (ResultSet rs = stmt.executeQuery("SELECT count(*) FROM tab1")) { + rs.next(); + assertEquals(rs.getLong(1), 0L); + } + } + } }