diff --git a/src/main/java/org/duckdb/DuckDBConnection.java b/src/main/java/org/duckdb/DuckDBConnection.java index 7c0c23dc0..4cc3e81be 100644 --- a/src/main/java/org/duckdb/DuckDBConnection.java +++ b/src/main/java/org/duckdb/DuckDBConnection.java @@ -43,21 +43,29 @@ public final class DuckDBConnection implements java.sql.Connection { volatile boolean transactionRunning; final String url; private final boolean readOnly; + private final String sessionInitSQL; - public static DuckDBConnection newConnection(String url, boolean readOnly, Properties properties) - throws SQLException { + public static DuckDBConnection newConnection(String url, boolean readOnly, Properties properties) throws Exception { + return newConnection(url, readOnly, null, properties); + } + + public static DuckDBConnection newConnection(String url, boolean readOnly, String sessionInitSQL, + Properties properties) throws SQLException { if (null == properties) { properties = new Properties(); } String dbName = dbNameFromUrl(url); ByteBuffer nativeReference = DuckDBNative.duckdb_jdbc_startup(dbName.getBytes(UTF_8), readOnly, properties); - return new DuckDBConnection(nativeReference, url, readOnly); + return new DuckDBConnection(nativeReference, url, readOnly, sessionInitSQL); } - private DuckDBConnection(ByteBuffer connectionReference, String url, boolean readOnly) throws SQLException { + private DuckDBConnection(ByteBuffer connectionReference, String url, boolean readOnly, String sessionInitSQL) + throws SQLException { this.connRef = connectionReference; this.url = url; this.readOnly = readOnly; + this.sessionInitSQL = sessionInitSQL; + // Hardcoded 'true' here is intentional, autocommit is handled in stmt#execute() DuckDBNative.duckdb_jdbc_set_auto_commit(connectionReference, true); } @@ -88,7 +96,7 @@ public Connection duplicate() throws SQLException { connRefLock.lock(); try { checkOpen(); - return new DuckDBConnection(DuckDBNative.duckdb_jdbc_connect(connRef), url, readOnly); + return new DuckDBConnection(DuckDBNative.duckdb_jdbc_connect(connRef), url, readOnly, null); } finally { connRefLock.unlock(); } @@ -471,6 +479,10 @@ public DuckDBHugeInt createHugeInt(long lower, long upper) throws SQLException { return new DuckDBHugeInt(lower, upper); } + public String getSessionInitSQL() throws SQLException { + return sessionInitSQL; + } + void checkOpen() throws SQLException { if (isClosed()) { throw new SQLException("Connection was closed"); diff --git a/src/main/java/org/duckdb/DuckDBDriver.java b/src/main/java/org/duckdb/DuckDBDriver.java index 616629736..dc6d04b69 100644 --- a/src/main/java/org/duckdb/DuckDBDriver.java +++ b/src/main/java/org/duckdb/DuckDBDriver.java @@ -1,14 +1,26 @@ package org.duckdb; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.nio.file.StandardOpenOption.READ; +import static org.duckdb.JdbcUtils.*; +import static org.duckdb.io.IOUtils.readToString; + +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.Reader; import java.nio.ByteBuffer; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.DigestInputStream; +import java.security.MessageDigest; import java.sql.*; import java.util.*; import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.ThreadFactory; -import static org.duckdb.JdbcUtils.*; - import java.util.concurrent.locks.ReentrantLock; import java.util.logging.Logger; +import org.duckdb.io.LimitedInputStream; public class DuckDBDriver implements java.sql.Driver { @@ -31,6 +43,16 @@ public class DuckDBDriver implements java.sql.Driver { private static final Set supportedOptions = new LinkedHashSet<>(); private static final ReentrantLock supportedOptionsLock = new ReentrantLock(); + private static final String SESSION_INIT_SQL_FILE_OPTION = "session_init_sql_file"; + private static final String SESSION_INIT_SQL_FILE_SHA256_OPTION = "session_init_sql_file_sha256"; + private static final long SESSION_INIT_SQL_FILE_MAX_SIZE_BYTES = 1 << 20; // 1MB + private static final String SESSION_INIT_SQL_FILE_URL_EXAMPLE = + "jdbc:duckdb:/path/to/db1.db;session_init_sql_file=/path/to/init.sql;session_init_sql_file_sha256=..."; + private static final String SESSION_INIT_SQL_CONN_INIT_MARKER = + "/\\*\\s*DUCKDB_CONNECTION_INIT_BELOW_MARKER\\s*\\*/"; + private static final LinkedHashSet sessionInitSQLFileDbNames = new LinkedHashSet<>(); + private static final ReentrantLock sessionInitSQLFileLock = new ReentrantLock(); + static { try { DriverManager.registerDriver(new DuckDBDriver()); @@ -56,6 +78,9 @@ public Connection connect(String url, Properties info) throws SQLException { // URL options ParsedProps pp = parsePropsFromUrl(url); + // Read session init file + SessionInitSQLFile sf = readSessionInitSQLFile(pp); + // Options in URL take preference for (Map.Entry en : pp.props.entrySet()) { props.put(en.getKey(), en.getValue()); @@ -82,11 +107,12 @@ public Connection connect(String url, Properties info) throws SQLException { boolean pinDBOpt = isStringTruish(pinDbOptStr, false); // Create connection - DuckDBConnection conn = DuckDBConnection.newConnection(pp.shortUrl, readOnly, props); + DuckDBConnection conn = DuckDBConnection.newConnection(pp.shortUrl, readOnly, sf.origFileText, props); // Run post-init try { pinDB(pinDBOpt, pp.shortUrl, conn); + runSessionInitSQLFile(conn, pp.shortUrl, sf); } catch (SQLException e) { closeQuietly(conn); throw e; @@ -143,6 +169,7 @@ private static ParsedProps parsePropsFromUrl(String url) throws SQLException { } String[] parts = url.split(";"); LinkedHashMap props = new LinkedHashMap<>(); + List origPropNames = new ArrayList<>(); for (int i = 1; i < parts.length; i++) { String entry = parts[i].trim(); if (entry.isEmpty()) { @@ -154,10 +181,11 @@ private static ParsedProps parsePropsFromUrl(String url) throws SQLException { } String key = kv[0].trim(); String value = kv[1].trim(); + origPropNames.add(key); props.put(key, value); } String shortUrl = parts[0].trim(); - return new ParsedProps(shortUrl, props); + return new ParsedProps(shortUrl, props, origPropNames); } private static void pinDB(boolean pinnedDbOpt, String url, DuckDBConnection conn) throws SQLException { @@ -247,17 +275,124 @@ private static void removeUnsupportedOptions(Properties props) throws SQLExcepti } } + private static SessionInitSQLFile readSessionInitSQLFile(ParsedProps pp) throws SQLException { + if (!pp.props.containsKey(SESSION_INIT_SQL_FILE_OPTION)) { + return new SessionInitSQLFile(); + } + + List urlOptsList = new ArrayList<>(pp.props.keySet()); + + if (!SESSION_INIT_SQL_FILE_OPTION.equals(urlOptsList.get(0))) { + throw new SQLException( + "'session_init_sql_file' can only be specified as the first parameter in connection string," + + " example: '" + SESSION_INIT_SQL_FILE_URL_EXAMPLE + "'"); + } + for (int i = 1; i < pp.origPropNames.size(); i++) { + if (SESSION_INIT_SQL_FILE_OPTION.equalsIgnoreCase(pp.origPropNames.get(i))) { + throw new SQLException("'session_init_sql_file' option cannot be specified more than once"); + } + } + String filePathStr = pp.props.remove(SESSION_INIT_SQL_FILE_OPTION); + + final String expectedSha256; + if (pp.props.containsKey(SESSION_INIT_SQL_FILE_SHA256_OPTION)) { + if (!SESSION_INIT_SQL_FILE_SHA256_OPTION.equals(urlOptsList.get(1))) { + throw new SQLException( + "'session_init_sql_file_sha256' can only be specified as the second parameter in connection string," + + " example: '" + SESSION_INIT_SQL_FILE_URL_EXAMPLE + "'"); + } + for (int i = 2; i < pp.origPropNames.size(); i++) { + if (SESSION_INIT_SQL_FILE_SHA256_OPTION.equalsIgnoreCase(pp.origPropNames.get(i))) { + throw new SQLException("'session_init_sql_file_sha256' option cannot be specified more than once"); + } + } + expectedSha256 = pp.props.remove(SESSION_INIT_SQL_FILE_SHA256_OPTION); + } else { + expectedSha256 = ""; + } + + Path filePath = Paths.get(filePathStr); + if (!Files.exists(filePath)) { + throw new SQLException("Specified session init SQL file not found, path: " + filePath); + } + + final String origFileText; + final String actualSha256; + try { + long fileSize = Files.size(filePath); + if (fileSize > SESSION_INIT_SQL_FILE_MAX_SIZE_BYTES) { + throw new SQLException("Specified session init SQL file size: " + fileSize + + " exceeds max allowed size: " + SESSION_INIT_SQL_FILE_MAX_SIZE_BYTES); + } + MessageDigest md = MessageDigest.getInstance("SHA-256"); + try (InputStream is = new DigestInputStream( + new LimitedInputStream(Files.newInputStream(filePath, READ), fileSize), md)) { + Reader reader = new InputStreamReader(is, UTF_8); + origFileText = readToString(reader); + actualSha256 = bytesToHex(md.digest()); + } + } catch (Exception e) { + throw new SQLException(e); + } + + if (!expectedSha256.isEmpty() && !expectedSha256.toLowerCase().equals(actualSha256)) { + throw new SQLException("Session init SQL file SHA-256 mismatch, expected: " + expectedSha256 + + ", actual: " + actualSha256); + } + + String[] parts = origFileText.split(SESSION_INIT_SQL_CONN_INIT_MARKER); + if (parts.length > 2) { + throw new SQLException("Connection init marker: '" + SESSION_INIT_SQL_CONN_INIT_MARKER + + "' can only be specified once"); + } + if (1 == parts.length) { + return new SessionInitSQLFile(origFileText, parts[0].trim()); + } else { + return new SessionInitSQLFile(origFileText, parts[0].trim(), parts[1].trim()); + } + } + + private static void runSessionInitSQLFile(Connection conn, String url, SessionInitSQLFile sf) throws SQLException { + if (sf.isEmpty()) { + return; + } + sessionInitSQLFileLock.lock(); + try { + + if (!sf.dbInitSQL.isEmpty()) { + String dbName = dbNameFromUrl(url); + if (MEMORY_DB.equals(dbName) || !sessionInitSQLFileDbNames.contains(dbName)) { + try (Statement stmt = conn.createStatement()) { + stmt.execute(sf.dbInitSQL); + } + } + sessionInitSQLFileDbNames.add(dbName); + } + + if (!sf.connInitSQL.isEmpty()) { + try (Statement stmt = conn.createStatement()) { + stmt.execute(sf.connInitSQL); + } + } + + } finally { + sessionInitSQLFileLock.unlock(); + } + } + private static class ParsedProps { final String shortUrl; final LinkedHashMap props; + final List origPropNames; private ParsedProps(String url) { - this(url, new LinkedHashMap<>()); + this(url, new LinkedHashMap<>(), new ArrayList<>()); } - private ParsedProps(String shortUrl, LinkedHashMap props) { + private ParsedProps(String shortUrl, LinkedHashMap props, List origPropNames) { this.shortUrl = shortUrl; this.props = props; + this.origPropNames = origPropNames; } } @@ -279,4 +414,28 @@ public void run() { } } } + + private static class SessionInitSQLFile { + final String dbInitSQL; + final String connInitSQL; + final String origFileText; + + private SessionInitSQLFile() { + this(null, null, null); + } + + private SessionInitSQLFile(String origFileText, String dbInitSQL) { + this(origFileText, dbInitSQL, ""); + } + + private SessionInitSQLFile(String origFileText, String dbInitSQL, String connInitSQL) { + this.origFileText = origFileText; + this.dbInitSQL = dbInitSQL; + this.connInitSQL = connInitSQL; + } + + boolean isEmpty() { + return null == dbInitSQL && null == connInitSQL && null == origFileText; + } + } } diff --git a/src/main/java/org/duckdb/JdbcUtils.java b/src/main/java/org/duckdb/JdbcUtils.java index c46b9052f..2f396f42e 100644 --- a/src/main/java/org/duckdb/JdbcUtils.java +++ b/src/main/java/org/duckdb/JdbcUtils.java @@ -19,7 +19,6 @@ static T unwrap(Object obj, Class iface) throws SQLException { return (T) obj; } - static String removeOption(Properties props, String opt) { return removeOption(props, opt, null); } diff --git a/src/test/java/org/duckdb/TestDuckDBJDBC.java b/src/test/java/org/duckdb/TestDuckDBJDBC.java index 22ca08ad0..7f46b28f8 100644 --- a/src/test/java/org/duckdb/TestDuckDBJDBC.java +++ b/src/test/java/org/duckdb/TestDuckDBJDBC.java @@ -3638,7 +3638,7 @@ public static void main(String[] args) throws Exception { } else { statusCode = runTests(args, TestDuckDBJDBC.class, TestBatch.class, TestClosure.class, TestExtensionTypes.class, TestSpatial.class, TestParameterMetadata.class, - TestPrepare.class, TestResults.class, TestTimestamp.class); + TestPrepare.class, TestResults.class, TestSessionInit.class, TestTimestamp.class); } System.exit(statusCode); } diff --git a/src/test/java/org/duckdb/TestSessionInit.java b/src/test/java/org/duckdb/TestSessionInit.java new file mode 100644 index 000000000..961435883 --- /dev/null +++ b/src/test/java/org/duckdb/TestSessionInit.java @@ -0,0 +1,161 @@ +package org.duckdb; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.duckdb.JdbcUtils.bytesToHex; +import static org.duckdb.test.Assertions.*; + +import java.io.OutputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.security.DigestOutputStream; +import java.security.MessageDigest; +import java.sql.*; +import java.util.Properties; +import org.duckdb.test.TempDirectory; + +public class TestSessionInit { + + public static void test_session_init_db_only() throws Exception { + try (TempDirectory td = new TempDirectory()) { + Path initSqlFile = td.path().resolve("init.sql"); + Files.write(initSqlFile, "CREATE TABLE tab1(col1 int);".getBytes()); + try (Connection conn = DriverManager.getConnection("jdbc:duckdb:;session_init_sql_file=" + initSqlFile); + Statement stmt = conn.createStatement()) { + stmt.execute("DROP TABLE tab1"); + } + } + } + + public static void test_session_init_db_and_connection() throws Exception { + try (TempDirectory td = new TempDirectory()) { + Path initSqlFile = td.path().resolve("init.sql"); + Files.write(initSqlFile, ("CREATE TABLE tab1(col1 int);\n" + + " /* DUCKDB_CONNECTION_INIT_BELOW_MARKER */ \n" + + "INSERT INTO tab1 VALUES(42);") + .getBytes()); + String url = "jdbc:duckdb:memory:test1;session_init_sql_file=" + initSqlFile; + try (Connection conn1 = DriverManager.getConnection(url); + Connection conn2 = DriverManager.getConnection(url); Statement stmt = conn2.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM tab1")) { + rs.next(); + assertEquals(rs.getInt(1), 42); + rs.next(); + assertEquals(rs.getInt(1), 42); + } + } + } + + public static void test_session_init_connection_only() throws Exception { + try (TempDirectory td = new TempDirectory()) { + Path initSqlFile = td.path().resolve("init.sql"); + Files.write(initSqlFile, (" /* DUCKDB_CONNECTION_INIT_BELOW_MARKER */ \n" + + "CREATE TABLE tab1(col1 int)") + .getBytes()); + try (Connection conn = DriverManager.getConnection("jdbc:duckdb:;session_init_sql_file=" + initSqlFile); + Statement stmt = conn.createStatement()) { + stmt.execute("DROP TABLE tab1"); + } + } + } + + public static void test_session_init_sha256() throws Exception { + try (TempDirectory td = new TempDirectory()) { + MessageDigest md = MessageDigest.getInstance("SHA-256"); + Path initSqlFile = td.path().resolve("init.sql"); + + final DigestOutputStream dos; + try (OutputStream os = Files.newOutputStream(initSqlFile)) { + dos = new DigestOutputStream(os, md); + dos.write("CREATE TABLE tab1(col1 int)".getBytes(UTF_8)); + dos.flush(); + } + + String sha256 = bytesToHex(md.digest()); + try (Connection conn = DriverManager.getConnection("jdbc:duckdb:;" + + "session_init_sql_file=" + initSqlFile + ";" + + "session_init_sql_file_sha256=" + sha256); + Statement stmt = conn.createStatement()) { + stmt.execute("DROP TABLE tab1"); + } + + assertThrows(() -> { + DriverManager.getConnection("jdbc:duckdb:;" + + "session_init_sql_file=" + initSqlFile + ";" + + "session_init_sql_file_sha25=fail"); + }, SQLException.class); + + assertThrows(() -> { + DriverManager.getConnection("jdbc:duckdb:" + + "session_init_sql_file=" + initSqlFile + ";" + + "threads=1;" + + "session_init_sql_file_sha256=" + sha256); + }, SQLException.class); + + assertThrows(() -> { + DriverManager.getConnection("jdbc:duckdb:;" + + "session_init_sql_file=" + initSqlFile + ";" + + "session_init_sql_file_sha256=" + sha256 + ";" + + "session_init_sql_file_sha256=" + sha256); + }, SQLException.class); + + assertThrows(() -> { + Properties config = new Properties(); + config.put("session_init_sql_file_sha256", sha256); + DriverManager.getConnection("jdbc:duckdb:;" + + "session_init_sql_file=" + initSqlFile, + config); + }, SQLException.class); + } + } + + public static void test_session_init_tracing() throws Exception { + String sql = " CREATE TABLE tab1(col1 int)\n\n"; + try (TempDirectory td = new TempDirectory()) { + Path initSqlFile = td.path().resolve("init.sql"); + Files.write(initSqlFile, sql.getBytes()); + try (DuckDBConnection conn = + DriverManager.getConnection("jdbc:duckdb:;session_init_sql_file=" + initSqlFile) + .unwrap(DuckDBConnection.class)) { + assertEquals(conn.getSessionInitSQL(), sql); + } + } + } + + public static void test_session_init_invalid_params() throws Exception { + try (TempDirectory td = new TempDirectory()) { + Path initSqlFile = td.path().resolve("init.sql"); + Files.write(initSqlFile, "CREATE TABLE tab1(col1 int);".getBytes()); + DriverManager.getConnection("jdbc:duckdb:;session_init_sql_file=" + initSqlFile).close(); + assertThrows(() -> { + DriverManager.getConnection("jdbc:duckdb:;" + + "threads=1;" + + "session_init_sql_file=" + initSqlFile); + }, SQLException.class); + assertThrows(() -> { + DriverManager.getConnection("jdbc:duckdb:;" + + "session_init_sql_file=" + initSqlFile + ";" + + "session_init_sql_file=" + initSqlFile); + }, SQLException.class); + assertThrows(() -> { + Properties config = new Properties(); + config.put("session_init_sql_file", "initSqlFile"); + DriverManager.getConnection("jdbc:duckdb:;", config); + }, SQLException.class); + } + } + + public static void test_session_init_invalid_file() throws Exception { + try (TempDirectory td = new TempDirectory()) { + Path initSqlFile = td.path().resolve("init.sql"); + Files.write(initSqlFile, ("CREATE TABLE tab1(col1 int);\n" + + " /* DUCKDB_CONNECTION_INIT_BELOW_MARKER */ \n" + + "INSERT INTO tab1 VALUES(42);\n" + + " /* DUCKDB_CONNECTION_INIT_BELOW_MARKER */ \n" + + "INSERT INTO tab1 VALUES(43);\n") + .getBytes()); + assertThrows(() -> { + DriverManager.getConnection("jdbc:duckdb:;session_init_sql_file=" + initSqlFile); + }, SQLException.class); + } + } +} diff --git a/src/test/java/org/duckdb/test/TempDirectory.java b/src/test/java/org/duckdb/test/TempDirectory.java new file mode 100644 index 000000000..9c4d60ca2 --- /dev/null +++ b/src/test/java/org/duckdb/test/TempDirectory.java @@ -0,0 +1,32 @@ +package org.duckdb.test; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Comparator; + +public class TempDirectory implements AutoCloseable { + private final Path tempDir; + + public TempDirectory() throws IOException { + this.tempDir = Files.createTempDirectory("duckdb_tempdir_"); + } + + public Path path() { + return tempDir; + } + + @Override + public void close() throws IOException { + // Recursively delete the directory and its contents + if (Files.exists(tempDir)) { + Files.walk(tempDir).sorted(Comparator.reverseOrder()).forEach(p -> { + try { + Files.delete(p); + } catch (IOException e) { + throw new RuntimeException("Failed to delete " + p, e); + } + }); + } + } +}