diff --git a/src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java b/src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java index f7a01128cb..0192086355 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java +++ b/src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java @@ -108,14 +108,14 @@ public static ImmutableMap buildPropertiesMap( if (!isNullOrEmpty(connectionParamString)) { String[] urlParts = connectionParamString.split(DatabricksJdbcConstants.URL_DELIMITER); for (String urlPart : urlParts) { - String[] pair = urlPart.split(DatabricksJdbcConstants.PAIR_DELIMITER); - if (pair.length == 1) { - pair = new String[] {pair[0], ""}; - } - if (pair[0].startsWith(DatabricksJdbcUrlParams.HTTP_HEADERS.getParamName())) { - parametersBuilder.put(pair[0], pair[1]); + // Split on first '=' only — values (like httpPath) may contain '=' (e.g. ?o=123) + int delimIdx = urlPart.indexOf(DatabricksJdbcConstants.PAIR_DELIMITER); + String key = delimIdx >= 0 ? urlPart.substring(0, delimIdx) : urlPart; + String value = delimIdx >= 0 ? urlPart.substring(delimIdx + 1) : ""; + if (key.startsWith(DatabricksJdbcUrlParams.HTTP_HEADERS.getParamName())) { + parametersBuilder.put(key, value); } else { - parametersBuilder.put(pair[0].toLowerCase(), pair[1]); + parametersBuilder.put(key.toLowerCase(), value); } } } @@ -1167,14 +1167,39 @@ private String getParameter(DatabricksJdbcUrlParams key, String defaultValue) { return this.parameters.getOrDefault(key.getParamName().toLowerCase(), defaultValue); } + private static final String ORG_ID_HEADER = "x-databricks-org-id"; + private Map parseCustomHeaders(ImmutableMap parameters) { String filterPrefix = DatabricksJdbcUrlParams.HTTP_HEADERS.getParamName(); - return parameters.entrySet().stream() - .filter(entry -> entry.getKey().startsWith(filterPrefix)) - .collect( - Collectors.toMap( - entry -> entry.getKey().substring(filterPrefix.length()), Map.Entry::getValue)); + Map headers = + new HashMap<>( + parameters.entrySet().stream() + .filter(entry -> entry.getKey().startsWith(filterPrefix)) + .collect( + Collectors.toMap( + entry -> entry.getKey().substring(filterPrefix.length()), + Map.Entry::getValue))); + + // Extract org ID from ?o= in httpPath for SPOG routing + if (!headers.containsKey(ORG_ID_HEADER)) { + String httpPath = + parameters.getOrDefault( + DatabricksJdbcUrlParams.HTTP_PATH.getParamName().toLowerCase(), ""); + int queryStart = httpPath.indexOf('?'); + if (queryStart >= 0) { + String queryString = httpPath.substring(queryStart + 1); + for (String param : queryString.split("&")) { + String[] kv = param.split("=", 2); + if (kv.length == 2 && "o".equals(kv[0]) && !kv[1].isEmpty()) { + headers.put(ORG_ID_HEADER, kv[1]); + break; + } + } + } + } + + return headers; } @Override diff --git a/src/main/java/com/databricks/jdbc/api/impl/volume/DBFSVolumeClient.java b/src/main/java/com/databricks/jdbc/api/impl/volume/DBFSVolumeClient.java index 9fa8d8c395..a3e11cfbe0 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/volume/DBFSVolumeClient.java +++ b/src/main/java/com/databricks/jdbc/api/impl/volume/DBFSVolumeClient.java @@ -492,7 +492,8 @@ CreateUploadUrlResponse getCreateUploadUrlResponse(String objectPath) CreateUploadUrlRequest request = new CreateUploadUrlRequest(objectPath); try { Request req = new Request(Request.POST, CREATE_UPLOAD_URL_PATH, apiClient.serialize(request)); - req.withHeaders(JSON_HTTP_HEADERS); + req.withHeaders(JSON_HTTP_HEADERS) + .withHeaders(connectionContext != null ? connectionContext.getCustomHeaders() : Map.of()); return apiClient.execute(req, CreateUploadUrlResponse.class); } catch (IOException | DatabricksException e) { String errorMessage = @@ -514,7 +515,8 @@ CreateDownloadUrlResponse getCreateDownloadUrlResponse(String objectPath) try { Request req = new Request(Request.POST, CREATE_DOWNLOAD_URL_PATH, apiClient.serialize(request)); - req.withHeaders(JSON_HTTP_HEADERS); + req.withHeaders(JSON_HTTP_HEADERS) + .withHeaders(connectionContext != null ? connectionContext.getCustomHeaders() : Map.of()); return apiClient.execute(req, CreateDownloadUrlResponse.class); } catch (IOException | DatabricksException e) { String errorMessage = @@ -534,7 +536,8 @@ CreateDeleteUrlResponse getCreateDeleteUrlResponse(String objectPath) try { Request req = new Request(Request.POST, CREATE_DELETE_URL_PATH, apiClient.serialize(request)); - req.withHeaders(JSON_HTTP_HEADERS); + req.withHeaders(JSON_HTTP_HEADERS) + .withHeaders(connectionContext != null ? connectionContext.getCustomHeaders() : Map.of()); return apiClient.execute(req, CreateDeleteUrlResponse.class); } catch (IOException | DatabricksException e) { String errorMessage = @@ -551,7 +554,8 @@ ListResponse getListResponse(String listPath) throws DatabricksVolumeOperationEx ListRequest request = new ListRequest(listPath); try { Request req = new Request(Request.GET, LIST_PATH); - req.withHeaders(JSON_HTTP_HEADERS); + req.withHeaders(JSON_HTTP_HEADERS) + .withHeaders(connectionContext != null ? connectionContext.getCustomHeaders() : Map.of()); ApiClient.setQuery(req, request); return apiClient.execute(req, ListResponse.class); } catch (IOException | DatabricksException e) { @@ -888,6 +892,9 @@ private CompletableFuture requestPresignedUrlWithRetry( Map authHeaders = workspaceClient.config().authenticate(); authHeaders.forEach(requestBuilder::addHeader); JSON_HTTP_HEADERS.forEach(requestBuilder::addHeader); + if (connectionContext != null) { + connectionContext.getCustomHeaders().forEach(requestBuilder::addHeader); + } requestBuilder.setEntity( AsyncEntityProducers.create(requestBody.getBytes(), ContentType.APPLICATION_JSON)); diff --git a/src/main/java/com/databricks/jdbc/common/DatabricksJdbcConstants.java b/src/main/java/com/databricks/jdbc/common/DatabricksJdbcConstants.java index 6eb8ef12ab..7849411f52 100644 --- a/src/main/java/com/databricks/jdbc/common/DatabricksJdbcConstants.java +++ b/src/main/java/com/databricks/jdbc/common/DatabricksJdbcConstants.java @@ -17,8 +17,10 @@ public final class DatabricksJdbcConstants { "(?:/([^;]*))?" + // Optional Schema (captured without /) "(?:;(.*))?"); // Optional Property=Value pairs (captured without leading ;) - public static final Pattern HTTP_WAREHOUSE_PATH_PATTERN = Pattern.compile(".*/warehouses/(.+)"); - public static final Pattern HTTP_ENDPOINT_PATH_PATTERN = Pattern.compile(".*/endpoints/(.+)"); + public static final Pattern HTTP_WAREHOUSE_PATH_PATTERN = + Pattern.compile(".*/warehouses/([^?&]+).*"); + public static final Pattern HTTP_ENDPOINT_PATH_PATTERN = + Pattern.compile(".*/endpoints/([^?&]+).*"); public static final Pattern HTTP_CLI_PATTERN = Pattern.compile(".*cliservice(.+)"); public static final Pattern HTTP_PATH_CLI_PATTERN = Pattern.compile("cliservice"); public static final Pattern TEST_PATH_PATTERN = Pattern.compile("jdbc:databricks://test"); diff --git a/src/main/java/com/databricks/jdbc/common/safe/DatabricksDriverFeatureFlagsContext.java b/src/main/java/com/databricks/jdbc/common/safe/DatabricksDriverFeatureFlagsContext.java index cfcd594087..8737fb77ee 100644 --- a/src/main/java/com/databricks/jdbc/common/safe/DatabricksDriverFeatureFlagsContext.java +++ b/src/main/java/com/databricks/jdbc/common/safe/DatabricksDriverFeatureFlagsContext.java @@ -102,6 +102,7 @@ private void refreshAllFeatureFlags() { .getDatabricksConfig() .authenticate() .forEach(request::addHeader); + connectionContext.getCustomHeaders().forEach(request::addHeader); fetchAndSetFlagsFromServer(httpClient, request); } catch (Exception e) { LOGGER.trace( diff --git a/src/main/java/com/databricks/jdbc/telemetry/TelemetryPushClient.java b/src/main/java/com/databricks/jdbc/telemetry/TelemetryPushClient.java index 1befdec99f..4ae719e906 100644 --- a/src/main/java/com/databricks/jdbc/telemetry/TelemetryPushClient.java +++ b/src/main/java/com/databricks/jdbc/telemetry/TelemetryPushClient.java @@ -59,6 +59,7 @@ public void pushEvent(TelemetryRequest request) throws Exception { Map authHeaders = isAuthenticated ? databricksConfig.authenticate() : Collections.emptyMap(); authHeaders.forEach(post::addHeader); + connectionContext.getCustomHeaders().forEach(post::addHeader); try (CloseableHttpResponse response = httpClient.execute(post)) { // TODO: check response and add retry for partial failures if (!HttpUtil.isSuccessfulHttpResponse(response)) { diff --git a/src/test/java/com/databricks/jdbc/TestConstants.java b/src/test/java/com/databricks/jdbc/TestConstants.java index 5e51a4ab81..4e800b075c 100644 --- a/src/test/java/com/databricks/jdbc/TestConstants.java +++ b/src/test/java/com/databricks/jdbc/TestConstants.java @@ -321,4 +321,17 @@ public class TestConstants { public static final List ARROW_BATCH_LIST = Collections.singletonList( new TSparkArrowBatch().setRowCount(0).setBatch(new byte[] {65, 66, 67})); + + // SPOG URLs with ?o= query parameter in httpPath + public static final String VALID_SPOG_URL_WAREHOUSE = + "jdbc:databricks://spog.cloud.databricks.com/default;ssl=1;AuthMech=3;" + + "httpPath=/sql/1.0/warehouses/abc123?o=6051921418418893;UseThriftClient=1"; + + public static final String VALID_SPOG_URL_ENDPOINT = + "jdbc:databricks://spog.cloud.databricks.com/default;ssl=1;AuthMech=3;" + + "httpPath=/sql/1.0/endpoints/abc123?o=6051921418418893;UseThriftClient=0"; + + public static final String VALID_SPOG_URL_WAREHOUSE_NO_EXTRA_PARAMS = + "jdbc:databricks://spog.cloud.databricks.com/default;ssl=1;AuthMech=3;" + + "httpPath=/sql/1.0/warehouses/abc123?o=6051921418418893"; } diff --git a/src/test/java/com/databricks/jdbc/api/impl/DatabricksConnectionContextTest.java b/src/test/java/com/databricks/jdbc/api/impl/DatabricksConnectionContextTest.java index 0aa31e366c..b1ec0cefcd 100644 --- a/src/test/java/com/databricks/jdbc/api/impl/DatabricksConnectionContextTest.java +++ b/src/test/java/com/databricks/jdbc/api/impl/DatabricksConnectionContextTest.java @@ -1357,4 +1357,83 @@ public void testOAuthWebServerTimeoutCustom() throws DatabricksSQLException { TestConstants.VALID_URL_1 + ";OAuthWebServerTimeout=300", properties); assertEquals(300, connectionContext.getOAuthWebServerTimeout()); } + + // ==================== SPOG ?o= Tests ==================== + + @Test + void testBuildPropertiesMap_preservesQueryParamInHttpPath() { + String params = "ssl=1;AuthMech=3;httpPath=/sql/1.0/warehouses/abc123?o=999;UseThriftClient=1"; + ImmutableMap result = buildPropertiesMap(params, new Properties()); + + assertEquals("/sql/1.0/warehouses/abc123?o=999", result.get("httppath")); + assertEquals("1", result.get("usethriftclient")); + } + + @Test + void testBuildPropertiesMap_handlesValueWithMultipleEquals() { + String params = "httpPath=/sql/1.0/warehouses/abc?o=999&other=foo"; + ImmutableMap result = buildPropertiesMap(params, new Properties()); + + assertEquals("/sql/1.0/warehouses/abc?o=999&other=foo", result.get("httppath")); + } + + @Test + void testBuildPropertiesMap_handlesValueWithNoEquals() { + String params = "keyonly"; + ImmutableMap result = buildPropertiesMap(params, new Properties()); + + assertEquals("", result.get("keyonly")); + } + + @Test + void testSpogContext_extractsOrgIdFromHttpPath() throws DatabricksSQLException { + Properties props = new Properties(); + props.put("user", "token"); + props.put("password", "test-token"); + IDatabricksConnectionContext ctx = + DatabricksConnectionContext.parse(TestConstants.VALID_SPOG_URL_WAREHOUSE, props); + + Map headers = ctx.getCustomHeaders(); + assertEquals("6051921418418893", headers.get("x-databricks-org-id")); + } + + @Test + void testSpogContext_extractsCleanWarehouseId() throws DatabricksSQLException { + Properties props = new Properties(); + props.put("user", "token"); + props.put("password", "test-token"); + IDatabricksConnectionContext ctx = + DatabricksConnectionContext.parse(TestConstants.VALID_SPOG_URL_WAREHOUSE, props); + + // Warehouse ID should be "abc123" not "abc123?o=6051921418418893" + assertTrue(ctx.getComputeResource() instanceof Warehouse); + assertEquals("abc123", ((Warehouse) ctx.getComputeResource()).getWarehouseId()); + } + + @Test + void testSpogContext_noOrgIdWithoutQueryParam() throws DatabricksSQLException { + Properties props = new Properties(); + props.put("user", "token"); + props.put("password", "test-token"); + IDatabricksConnectionContext ctx = + DatabricksConnectionContext.parse(TestConstants.VALID_URL_1, props); + + Map headers = ctx.getCustomHeaders(); + assertFalse(headers.containsKey("x-databricks-org-id")); + } + + @Test + void testSpogContext_explicitHeaderTakesPrecedence() throws DatabricksSQLException { + String url = + "jdbc:databricks://host/default;ssl=1;AuthMech=3;" + + "httpPath=/sql/1.0/warehouses/abc123?o=frompath;" + + "http.header.x-databricks-org-id=fromheader"; + Properties props = new Properties(); + props.put("user", "token"); + props.put("password", "test-token"); + IDatabricksConnectionContext ctx = DatabricksConnectionContext.parse(url, props); + + Map headers = ctx.getCustomHeaders(); + assertEquals("fromheader", headers.get("x-databricks-org-id")); + } } diff --git a/src/test/java/com/databricks/jdbc/common/util/ValidationUtilTest.java b/src/test/java/com/databricks/jdbc/common/util/ValidationUtilTest.java index 7695b5d11f..b49aa89a66 100644 --- a/src/test/java/com/databricks/jdbc/common/util/ValidationUtilTest.java +++ b/src/test/java/com/databricks/jdbc/common/util/ValidationUtilTest.java @@ -128,6 +128,13 @@ private static Stream jdbcUrlValidityTestCases() { "Valid URL with invalid compression type", true), Arguments.of(INVALID_URL_1, "Invalid non-Databricks JDBC URL", false), - Arguments.of(INVALID_URL_2, "Invalid malformed JDBC scheme", false)); + Arguments.of(INVALID_URL_2, "Invalid malformed JDBC scheme", false), + Arguments.of( + VALID_SPOG_URL_WAREHOUSE, "Valid SPOG URL with ?o= in warehouse httpPath", true), + Arguments.of(VALID_SPOG_URL_ENDPOINT, "Valid SPOG URL with ?o= in endpoint httpPath", true), + Arguments.of( + VALID_SPOG_URL_WAREHOUSE_NO_EXTRA_PARAMS, + "Valid SPOG URL with ?o= at end of URL", + true)); } }