Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions src/main/java/org/duckdb/DuckDBConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,25 +45,31 @@ public final class DuckDBConnection implements java.sql.Connection {
volatile boolean transactionRunning;
final String url;
private final boolean readOnly;
private final String sessionInitSQL;

public static DuckDBConnection newConnection(String url, boolean readOnly, Properties properties)
throws SQLException {
public static DuckDBConnection newConnection(String url, boolean readOnly, Properties properties) throws Exception {
return newConnection(url, readOnly, null, properties);
}

public static DuckDBConnection newConnection(String url, boolean readOnly, String sessionInitSQL,
Properties properties) throws SQLException {
if (null == properties) {
properties = new Properties();
}
String dbName = dbNameFromUrl(url);
String autoCommitStr = removeOption(properties, JDBC_AUTO_COMMIT);
boolean autoCommit = isStringTruish(autoCommitStr, true);
ByteBuffer nativeReference = DuckDBNative.duckdb_jdbc_startup(dbName.getBytes(UTF_8), readOnly, properties);
return new DuckDBConnection(nativeReference, url, readOnly, autoCommit);
return new DuckDBConnection(nativeReference, url, readOnly, sessionInitSQL, autoCommit);
}

private DuckDBConnection(ByteBuffer connectionReference, String url, boolean readOnly, boolean autoCommit)
throws SQLException {
private DuckDBConnection(ByteBuffer connectionReference, String url, boolean readOnly, String sessionInitSQL,
boolean autoCommit) throws SQLException {
this.connRef = connectionReference;
this.url = url;
this.readOnly = readOnly;
this.autoCommit = autoCommit;
this.sessionInitSQL = sessionInitSQL;
// Hardcoded 'true' here is intentional, autocommit is handled in stmt#execute()
DuckDBNative.duckdb_jdbc_set_auto_commit(connectionReference, true);
}
Expand Down Expand Up @@ -95,7 +101,8 @@ public Connection duplicate() throws SQLException {
connRefLock.lock();
try {
checkOpen();
return new DuckDBConnection(DuckDBNative.duckdb_jdbc_connect(connRef), url, readOnly, autoCommit);
return new DuckDBConnection(DuckDBNative.duckdb_jdbc_connect(connRef), url, readOnly, sessionInitSQL,
autoCommit);
} finally {
connRefLock.unlock();
}
Expand Down Expand Up @@ -478,6 +485,10 @@ public DuckDBHugeInt createHugeInt(long lower, long upper) throws SQLException {
return new DuckDBHugeInt(lower, upper);
}

public String getSessionInitSQL() throws SQLException {
return sessionInitSQL;
}

void checkOpen() throws SQLException {
if (isClosed()) {
throw new SQLException("Connection was closed");
Expand Down
176 changes: 169 additions & 7 deletions src/main/java/org/duckdb/DuckDBDriver.java
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
package org.duckdb;

import static java.nio.charset.StandardCharsets.UTF_8;
import static java.nio.file.StandardOpenOption.READ;
import static org.duckdb.JdbcUtils.*;
import static org.duckdb.io.IOUtils.readToString;

import java.io.*;
import java.nio.ByteBuffer;
import java.nio.file.*;
import java.security.DigestInputStream;
import java.security.MessageDigest;
import java.sql.*;
import java.util.*;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Logger;
import java.util.regex.Pattern;
import org.duckdb.io.LimitedInputStream;

public class DuckDBDriver implements java.sql.Driver {

Expand Down Expand Up @@ -41,6 +49,16 @@ public class DuckDBDriver implements java.sql.Driver {
private static final Set<String> supportedOptions = new LinkedHashSet<>();
private static final ReentrantLock supportedOptionsLock = new ReentrantLock();

private static final String SESSION_INIT_SQL_FILE_OPTION = "session_init_sql_file";
private static final String SESSION_INIT_SQL_FILE_SHA256_OPTION = "session_init_sql_file_sha256";
private static final long SESSION_INIT_SQL_FILE_MAX_SIZE_BYTES = 1 << 20; // 1MB
private static final String SESSION_INIT_SQL_FILE_URL_EXAMPLE =
"jdbc:duckdb:/path/to/db1.db;session_init_sql_file=/path/to/init.sql;session_init_sql_file_sha256=...";
private static final String SESSION_INIT_SQL_CONN_INIT_MARKER =
"/\\*\\s*DUCKDB_CONNECTION_INIT_BELOW_MARKER\\s*\\*/";
private static final LinkedHashSet<String> sessionInitSQLFileDbNames = new LinkedHashSet<>();
private static final ReentrantLock sessionInitSQLFileLock = new ReentrantLock();

static {
try {
DriverManager.registerDriver(new DuckDBDriver());
Expand All @@ -65,6 +83,11 @@ public Connection connect(String url, Properties info) throws SQLException {

// URL options
ParsedProps pp = parsePropsFromUrl(url);

// Read session init file
SessionInitSQLFile sf = readSessionInitSQLFile(pp);

// Options in URL take preference
for (Map.Entry<String, String> en : pp.props.entrySet()) {
props.put(en.getKey(), en.getValue());
}
Expand Down Expand Up @@ -107,11 +130,17 @@ public Connection connect(String url, Properties info) throws SQLException {
boolean pinDBOpt = isStringTruish(pinDbOptStr, false);

// Create connection
DuckDBConnection conn = DuckDBConnection.newConnection(shortUrl, readOnly, props);

pinDB(pinDBOpt, shortUrl, conn);
DuckDBConnection conn = DuckDBConnection.newConnection(shortUrl, readOnly, sf.origFileText, props);

initDucklake(conn, shortUrl, ducklake, ducklakeAlias);
// Run post-init
try {
pinDB(pinDBOpt, shortUrl, conn);
runSessionInitSQLFile(conn, url, sf);
initDucklake(conn, shortUrl, ducklake, ducklakeAlias);
} catch (SQLException e) {
closeQuietly(conn);
throw e;
}

return conn;
}
Expand Down Expand Up @@ -202,6 +231,7 @@ private static ParsedProps parsePropsFromUrl(String url) throws SQLException {
}
String[] parts = url.split(";");
LinkedHashMap<String, String> props = new LinkedHashMap<>();
List<String> origPropNames = new ArrayList<>();
for (int i = 1; i < parts.length; i++) {
String entry = parts[i].trim();
if (entry.isEmpty()) {
Expand All @@ -213,10 +243,11 @@ private static ParsedProps parsePropsFromUrl(String url) throws SQLException {
}
String key = kv[0].trim();
String value = kv[1].trim();
origPropNames.add(key);
props.put(key, value);
}
String shortUrl = parts[0].trim();
return new ParsedProps(shortUrl, props);
return new ParsedProps(shortUrl, props, origPropNames);
}

private static void pinDB(boolean pinnedDbOpt, String url, DuckDBConnection conn) throws SQLException {
Expand Down Expand Up @@ -306,17 +337,124 @@ private static void removeUnsupportedOptions(Properties props) throws SQLExcepti
}
}

private static SessionInitSQLFile readSessionInitSQLFile(ParsedProps pp) throws SQLException {
if (!pp.props.containsKey(SESSION_INIT_SQL_FILE_OPTION)) {
return new SessionInitSQLFile();
}

List<String> urlOptsList = new ArrayList<>(pp.props.keySet());

if (!SESSION_INIT_SQL_FILE_OPTION.equals(urlOptsList.get(0))) {
throw new SQLException(
"'session_init_sql_file' can only be specified as the first parameter in connection string,"
+ " example: '" + SESSION_INIT_SQL_FILE_URL_EXAMPLE + "'");
}
for (int i = 1; i < pp.origPropNames.size(); i++) {
if (SESSION_INIT_SQL_FILE_OPTION.equalsIgnoreCase(pp.origPropNames.get(i))) {
throw new SQLException("'session_init_sql_file' option cannot be specified more than once");
}
}
String filePathStr = pp.props.remove(SESSION_INIT_SQL_FILE_OPTION);

final String expectedSha256;
if (pp.props.containsKey(SESSION_INIT_SQL_FILE_SHA256_OPTION)) {
if (!SESSION_INIT_SQL_FILE_SHA256_OPTION.equals(urlOptsList.get(1))) {
throw new SQLException(
"'session_init_sql_file_sha256' can only be specified as the second parameter in connection string,"
+ " example: '" + SESSION_INIT_SQL_FILE_URL_EXAMPLE + "'");
}
for (int i = 2; i < pp.origPropNames.size(); i++) {
if (SESSION_INIT_SQL_FILE_SHA256_OPTION.equalsIgnoreCase(pp.origPropNames.get(i))) {
throw new SQLException("'session_init_sql_file_sha256' option cannot be specified more than once");
}
}
expectedSha256 = pp.props.remove(SESSION_INIT_SQL_FILE_SHA256_OPTION);
} else {
expectedSha256 = "";
}

Path filePath = Paths.get(filePathStr);
if (!Files.exists(filePath)) {
throw new SQLException("Specified session init SQL file not found, path: " + filePath);
}

final String origFileText;
final String actualSha256;
try {
long fileSize = Files.size(filePath);
if (fileSize > SESSION_INIT_SQL_FILE_MAX_SIZE_BYTES) {
throw new SQLException("Specified session init SQL file size: " + fileSize +
" exceeds max allowed size: " + SESSION_INIT_SQL_FILE_MAX_SIZE_BYTES);
}
MessageDigest md = MessageDigest.getInstance("SHA-256");
try (InputStream is = new DigestInputStream(
new LimitedInputStream(Files.newInputStream(filePath, READ), fileSize), md)) {
Reader reader = new InputStreamReader(is, UTF_8);
origFileText = readToString(reader);
actualSha256 = bytesToHex(md.digest());
}
} catch (Exception e) {
throw new SQLException(e);
}

if (!expectedSha256.isEmpty() && !expectedSha256.toLowerCase().equals(actualSha256)) {
throw new SQLException("Session init SQL file SHA-256 mismatch, expected: " + expectedSha256 +
", actual: " + actualSha256);
}

String[] parts = origFileText.split(SESSION_INIT_SQL_CONN_INIT_MARKER);
if (parts.length > 2) {
throw new SQLException("Connection init marker: '" + SESSION_INIT_SQL_CONN_INIT_MARKER +
"' can only be specified once");
}
if (1 == parts.length) {
return new SessionInitSQLFile(origFileText, parts[0].trim());
} else {
return new SessionInitSQLFile(origFileText, parts[0].trim(), parts[1].trim());
}
}

private static void runSessionInitSQLFile(Connection conn, String url, SessionInitSQLFile sf) throws SQLException {
if (sf.isEmpty()) {
return;
}
sessionInitSQLFileLock.lock();
try {

if (!sf.dbInitSQL.isEmpty()) {
String dbName = dbNameFromUrl(url);
if (MEMORY_DB.equals(dbName) || !sessionInitSQLFileDbNames.contains(dbName)) {
try (Statement stmt = conn.createStatement()) {
stmt.execute(sf.dbInitSQL);
}
}
sessionInitSQLFileDbNames.add(dbName);
}

if (!sf.connInitSQL.isEmpty()) {
try (Statement stmt = conn.createStatement()) {
stmt.execute(sf.connInitSQL);
}
}

} finally {
sessionInitSQLFileLock.unlock();
}
}

private static class ParsedProps {
final String shortUrl;
final LinkedHashMap<String, String> props;
final List<String> origPropNames;

private ParsedProps(String url) {
this(url, new LinkedHashMap<>());
this(url, new LinkedHashMap<>(), new ArrayList<>());
}

private ParsedProps(String shortUrl, LinkedHashMap<String, String> props) {
private ParsedProps(String shortUrl, LinkedHashMap<String, String> props, List<String> origPropNames) {
this.shortUrl = shortUrl;
this.props = props;
this.origPropNames = origPropNames;
}
}

Expand All @@ -338,4 +476,28 @@ public void run() {
}
}
}

private static class SessionInitSQLFile {
final String dbInitSQL;
final String connInitSQL;
final String origFileText;

private SessionInitSQLFile() {
this(null, null, null);
}

private SessionInitSQLFile(String origFileText, String dbInitSQL) {
this(origFileText, dbInitSQL, "");
}

private SessionInitSQLFile(String origFileText, String dbInitSQL, String connInitSQL) {
this.origFileText = origFileText;
this.dbInitSQL = dbInitSQL;
this.connInitSQL = connInitSQL;
}

boolean isEmpty() {
return null == dbInitSQL && null == connInitSQL && null == origFileText;
}
}
}
22 changes: 22 additions & 0 deletions src/main/java/org/duckdb/JdbcUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,26 @@ static String dbNameFromUrl(String url) throws SQLException {
}
return dbName;
}

static String bytesToHex(byte[] bytes) {
if (null == bytes) {
return "";
}
StringBuilder sb = new StringBuilder(bytes.length * 2);
for (byte b : bytes) {
sb.append(String.format("%02x", b));
}
return sb.toString();
}

static void closeQuietly(AutoCloseable closeable) {
if (null == closeable) {
return;
}
try {
closeable.close();
} catch (Exception e) {
// suppress
}
}
}
2 changes: 1 addition & 1 deletion src/test/java/org/duckdb/TestDuckDBJDBC.java
Original file line number Diff line number Diff line change
Expand Up @@ -3674,7 +3674,7 @@ public static void main(String[] args) throws Exception {
} else {
statusCode = runTests(args, TestDuckDBJDBC.class, TestBatch.class, TestClosure.class,
TestExtensionTypes.class, TestSpatial.class, TestParameterMetadata.class,
TestPrepare.class, TestResults.class, TestTimestamp.class);
TestPrepare.class, TestResults.class, TestSessionInit.class, TestTimestamp.class);
}
System.exit(statusCode);
}
Expand Down
Loading
Loading