diff --git a/src/main/java/org/duckdb/DuckDBPreparedStatement.java b/src/main/java/org/duckdb/DuckDBPreparedStatement.java index 69f154848..95b3e4b8f 100644 --- a/src/main/java/org/duckdb/DuckDBPreparedStatement.java +++ b/src/main/java/org/duckdb/DuckDBPreparedStatement.java @@ -645,11 +645,13 @@ public long[] executeLargeBatch() throws SQLException { private long[] executeBatchedPreparedStatement() throws SQLException { stmtRefLock.lock(); + boolean tranStarted = false; + DuckDBConnection conn = this.conn; try { checkOpen(); checkPrepared(); - boolean tranStarted = startTransaction(); + tranStarted = startTransaction(); long[] updateCounts = new long[this.batchedParams.size()]; for (int i = 0; i < this.batchedParams.size(); i++) { @@ -664,6 +666,12 @@ private long[] executeBatchedPreparedStatement() throws SQLException { } return updateCounts; + + } catch (SQLException e) { + if (tranStarted && conn.getAutoCommit()) { + conn.rollback(); + } + throw e; } finally { stmtRefLock.unlock(); } @@ -671,10 +679,12 @@ private long[] executeBatchedPreparedStatement() throws SQLException { private long[] executeBatchedStatements() throws SQLException { stmtRefLock.lock(); + boolean tranStarted = false; + DuckDBConnection conn = this.conn; try { checkOpen(); - boolean tranStarted = startTransaction(); + tranStarted = startTransaction(); long[] updateCounts = new long[this.batchedStatements.size()]; for (int i = 0; i < this.batchedStatements.size(); i++) { @@ -689,6 +699,12 @@ private long[] executeBatchedStatements() throws SQLException { } return updateCounts; + + } catch (SQLException e) { + if (tranStarted && conn.getAutoCommit()) { + conn.rollback(); + } + throw e; } finally { stmtRefLock.unlock(); } diff --git a/src/test/java/org/duckdb/TestBatch.java b/src/test/java/org/duckdb/TestBatch.java index ed57e612d..51e1d781e 100644 --- a/src/test/java/org/duckdb/TestBatch.java +++ b/src/test/java/org/duckdb/TestBatch.java @@ -4,6 +4,7 @@ import static org.duckdb.test.Assertions.*; import java.sql.*; +import java.util.Properties; public class TestBatch { @@ -206,4 +207,102 @@ public static void test_statement_batch_rollback() throws Exception { } } } + + public static void test_statement_batch_autocommit_constraint_violation() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL)) { + assertTrue(conn.getAutoCommit()); + try (Statement stmt = conn.createStatement()) { + stmt.execute("CREATE TABLE tab1 (col1 VARCHAR NOT NULL)"); + } + try (Statement stmt = conn.createStatement()) { + stmt.addBatch("INSERT INTO tab1 VALUES('foo')"); + stmt.addBatch("INSERT INTO tab1 VALUES(NULL)"); + assertThrows(stmt::executeBatch, SQLException.class); + } + try (Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT count(*) FROM tab1")) { + rs.next(); + assertEquals(rs.getLong(1), 0L); + } + } + } + + public static void test_prepared_statement_batch_autocommit_constraint_violation() throws Exception { + try (Connection conn = DriverManager.getConnection(JDBC_URL)) { + assertTrue(conn.getAutoCommit()); + try (Statement stmt = conn.createStatement()) { + stmt.execute("CREATE TABLE tab1 (col1 VARCHAR NOT NULL)"); + } + try (PreparedStatement ps = conn.prepareStatement("INSERT INTO tab1 VALUES(?)")) { + ps.setString(1, "foo"); + ps.addBatch(); + ps.setString(1, null); + ps.addBatch(); + assertThrows(ps::executeBatch, SQLException.class); + } + try (Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT count(*) FROM tab1")) { + rs.next(); + assertEquals(rs.getLong(1), 0L); + } + } + } + + public static void test_statement_batch_constraint_violation() throws Exception { + Properties config = new Properties(); + config.put(DuckDBDriver.JDBC_AUTO_COMMIT, false); + try (Connection conn = DriverManager.getConnection(JDBC_URL, config)) { + assertFalse(conn.getAutoCommit()); + try (Statement stmt = conn.createStatement()) { + stmt.execute("CREATE TABLE tab1 (col1 VARCHAR NOT NULL)"); + conn.commit(); + } + boolean thrown = false; + try (Statement stmt = conn.createStatement()) { + stmt.addBatch("INSERT INTO tab1 VALUES('foo')"); + stmt.addBatch("INSERT INTO tab1 VALUES(NULL)"); + stmt.executeBatch(); + conn.commit(); + } catch (SQLException e) { + thrown = true; + conn.rollback(); + } + assertTrue(thrown); + try (Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT count(*) FROM tab1")) { + rs.next(); + assertEquals(rs.getLong(1), 0L); + } + } + } + + public static void test_prepared_statement_batch_constraint_violation() throws Exception { + Properties config = new Properties(); + config.put(DuckDBDriver.JDBC_AUTO_COMMIT, false); + try (Connection conn = DriverManager.getConnection(JDBC_URL, config)) { + assertFalse(conn.getAutoCommit()); + try (Statement stmt = conn.createStatement()) { + stmt.execute("CREATE TABLE tab1 (col1 VARCHAR NOT NULL)"); + conn.commit(); + } + boolean thrown = false; + try (PreparedStatement ps = conn.prepareStatement("INSERT INTO tab1 VALUES(?)")) { + ps.setString(1, "foo"); + ps.addBatch(); + ps.setString(1, null); + ps.addBatch(); + ps.executeBatch(); + conn.commit(); + } catch (SQLException e) { + thrown = true; + conn.rollback(); + } + assertTrue(thrown); + try (Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT count(*) FROM tab1")) { + rs.next(); + assertEquals(rs.getLong(1), 0L); + } + } + } }