
Generalize CLI token source into progressive command list #1378

Draft
mihaimitrea-db wants to merge 2 commits into main from mihaimitrea-db/stack/cli-progressive-token-commands


@mihaimitrea-db
Contributor

@mihaimitrea-db mihaimitrea-db commented Mar 31, 2026

🥞 Stacked PR

Use this link to review incremental changes.


Summary

Generalize CliTokenSource: replace the explicit _force_cmd field and the hand-written fallback override in DatabricksCliTokenSource.refresh() with a CliCommand dataclass and an optional commands list on CliTokenSource. An _active_command_index caches which command works, so subsequent token fetches skip probing.

See: databricks/databricks-sdk-go#1605, databricks/databricks-sdk-java#752

Why

The parent PR (#1377) introduced --force-refresh support by adding a _force_cmd field to DatabricksCliTokenSource and overriding refresh() with hand-written fallback logic. This works, but every new flag would require adding another field, another try/except block, another error check, and another set of tests — the pattern doesn't scale.

We expect future flags like --scopes (forwarding custom OAuth scopes to the CLI). Rather than growing the class linearly with each flag, this PR extracts the repeating pattern into a loop over a command list.

Why try-and-retry over version detection or --help parsing

Three approaches were evaluated for resolving which flags the installed CLI supports:

  • Version detection (databricks version + static version table) was rejected because it creates a maintenance burden and a second source of truth. Every SDK (Go, Python, Java) would need to independently maintain a table mapping flags to the CLI version that introduced them. If any SDK's table falls out of sync with the CLI's actual releases, users silently get degraded commands.
  • --help flag parsing (databricks auth token --help + substring check) was rejected because it depends on the output format of --help — which is not a stable API. Cobra format changes could break detection, and naive substring matching is fragile.
  • Feature probing with try-and-retry (the approach taken here) uses the CLI itself as the authority on what it supports. Commands are built at init time from most-featured to simplest. On the first refresh() call, each command is tried in order; when the CLI responds with "unknown flag:", the next simpler command is tried. The working command index is cached so subsequent calls skip probing entirely. This approach has zero maintenance burden (no version numbers or flag registries to keep in sync), zero overhead on the happy path (newest CLI succeeds on the first command), and requires no signature changes. The key insight that makes this practical is that CLI flags are introduced incrementally — if the CLI doesn't support --profile, it certainly doesn't support --force-refresh.
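The try-and-retry approach described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the PR's code: `ProgressiveRunner`, `fake_old_cli`, and the example host are hypothetical names, and the sketch falls back on any "unknown flag:" message, whereas the PR matches the specific flags of each command.

```python
from typing import Callable, List


class ProgressiveRunner:
    """Hypothetical sketch: try commands from most-featured to simplest, cache the winner."""

    def __init__(self, commands: List[List[str]], run: Callable[[List[str]], str]) -> None:
        if not commands:
            raise ValueError("commands must not be empty")
        self._commands = commands
        self._run = run
        self._active_command_index = -1  # -1 means "not resolved yet"

    def refresh(self) -> str:
        if self._active_command_index >= 0:
            # Already resolved: run the cached command directly, no probing.
            return self._run(self._commands[self._active_command_index])
        last = len(self._commands) - 1
        for index, args in enumerate(self._commands):
            try:
                output = self._run(args)
            except IOError as err:
                # Only an "unknown flag" rejection justifies trying a simpler
                # command; the last command has nothing left to fall back to.
                if index == last or "unknown flag:" not in str(err):
                    raise
                continue
            self._active_command_index = index
            return output
        raise AssertionError("unreachable: the last command either returns or raises")


calls: List[List[str]] = []


def fake_old_cli(args: List[str]) -> str:
    """Hypothetical CLI that predates --force-refresh; records every invocation."""
    calls.append(args)
    if "--force-refresh" in args:
        raise IOError("cannot get access token: Error: unknown flag: --force-refresh")
    return "token via " + args[-2]


runner = ProgressiveRunner(
    [
        ["databricks", "auth", "token", "--profile", "p", "--force-refresh"],
        ["databricks", "auth", "token", "--profile", "p"],
        ["databricks", "auth", "token", "--host", "https://example.test"],
    ],
    fake_old_cli,
)
first = runner.refresh()
calls_after_first = len(calls)   # probing: one rejected command, one success
second = runner.refresh()
calls_after_second = len(calls)  # cached: exactly one more invocation
```

With a CLI that supports every flag, the first command succeeds and the first refresh costs a single invocation, which is the "zero overhead on the happy path" property claimed above.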

What changed

Interface changes

None. CliTokenSource is not part of the public API surface.

Behavioral changes

None. The set of commands tried is identical to the parent PR. The only observable difference is that _active_command_index caches the working command, so subsequent refresh() calls execute that command directly without re-probing — a pure performance improvement. AzureCliTokenSource is completely unchanged; it does not pass commands and uses _refresh_single() which is an exact copy of the original refresh() logic.

Internal changes

  • CliCommand dataclass: replaces the _force_cmd field from the parent PR. Each entry holds args (the full CLI command), flags (used for error matching), and warning (logged when falling back from this command).
  • CliTokenSource.__init__: gains an optional commands: Optional[List[CliCommand]] parameter and initializes _active_command_index = -1 (unresolved).
  • CliTokenSource._is_unknown_flag_error: new static method that checks whether an IOError matches one of the given flags, preventing false-positive fallbacks.
  • CliTokenSource.refresh(): now a delegating method — calls _refresh_progressive() when commands is set, otherwise calls _refresh_single().
  • _refresh_single(): exact copy of the original CliTokenSource.refresh() logic (cmd → fallback_cmd on "unknown flag: --profile"). Preserves full backward compatibility for AzureCliTokenSource and any other caller not using the progressive chain.
  • _refresh_progressive(): checks _active_command_index — if resolved (≥ 0), calls the cached command directly. Otherwise delegates to _probe_and_exec().
  • _probe_and_exec(): walks commands from index 0, falls back on unknown flag errors, stores _active_command_index on success.
  • DatabricksCliTokenSource.__init__: calls _build_commands() and passes the result to super().__init__(commands=...). Removes the _force_cmd field from the parent PR.
  • DatabricksCliTokenSource._build_commands(): static method that constructs the ordered List[CliCommand]: --profile + --force-refresh first, plain --profile second, --host as a terminal fallback. --force-refresh is only paired with --profile, never with --host. Adding a future flag means adding one more CliCommand literal here.
  • DatabricksCliTokenSource.refresh(): now a one-liner that calls super().refresh() and validates token scopes.
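The data shapes listed above can be sketched in isolation. The dataclass fields and the flag-matching helper follow the description and the diff in this PR; `build_commands` is a simplified, hypothetical stand-in for `_build_commands()` that takes `profile` and `host` directly instead of a `Config`, and assumes the host command is just `auth token --host <host>` (the real `_build_host_args(cfg)` may add more, such as `--account-id`).

```python
import dataclasses
from typing import List, Optional


@dataclasses.dataclass
class CliCommand:
    """A single CLI command variant with metadata for progressive fallback."""

    args: List[str]   # full command line, including the CLI path
    flags: List[str]  # flags whose rejection should trigger a fallback
    warning: str      # logged when falling back from this command


def is_unknown_flag_error(error: IOError, flags: List[str]) -> bool:
    """Check if the error indicates the CLI rejected one of the given flags."""
    msg = str(error)
    return any(f"unknown flag: {flag}" in msg for flag in flags)


def build_commands(cli_path: str, profile: Optional[str], host: Optional[str]) -> List[CliCommand]:
    """Order commands from most-featured to simplest (hypothetical simplified signature)."""
    commands: List[CliCommand] = []
    if profile:
        profile_args = [cli_path, "auth", "token", "--profile", profile]
        commands.append(
            CliCommand(
                args=profile_args + ["--force-refresh"],
                flags=["--force-refresh", "--profile"],
                warning="Databricks CLI does not support --force-refresh. "
                "Please upgrade your CLI to the latest version.",
            )
        )
        commands.append(
            CliCommand(
                args=profile_args,
                flags=["--profile"],
                warning="Databricks CLI does not support --profile flag. "
                "Falling back to --host. "
                "Please upgrade your CLI to the latest version.",
            )
        )
    if host and (profile or not commands):
        # Terminal fallback: no flags to match, so any error surfaces to the caller.
        # Assumption: simplified host arguments, see the note above.
        commands.append(CliCommand(args=[cli_path, "auth", "token", "--host", host], flags=[], warning=""))
    return commands


full_chain = build_commands("databricks", "my-profile", "https://example.test")
host_only = build_commands("databricks", None, "https://example.test")
```

Because each entry names the flags it depends on, a rejection of `--force-refresh` does not count as a reason to abandon a command that only uses `--profile`, which is how the helper avoids false-positive fallbacks.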

How is this tested?

Unit tests in tests/test_credentials_provider.py:

  • TestDatabricksCliTokenSourceArgs:
    • test_profile_with_host_builds_three_commands — verifies 3 CliCommand entries with correct args for profile+host config.
    • test_profile_without_host_builds_two_commands — verifies 2 commands (no host fallback).
    • test_host_only_builds_one_command — verifies single --host command, no --force-refresh.
    • test_account_client_passes_account_id — verifies --account-id is appended for account hosts.
  • TestDatabricksCliForceRefresh:
    • test_force_refresh_tried_first_with_profile — force command succeeds, no further commands tried.
    • test_host_only_skips_force_refresh — host-only config does not attempt --force-refresh.
    • test_force_refresh_fallback_when_unsupported — force command fails with "unknown flag: --force-refresh", falls back to plain --profile.
    • test_profile_fallback_to_host — both force and profile fail, falls back to --host.
    • test_full_fallback_chain — all three commands tried, last one succeeds.
    • test_active_command_index_caching — first call probes and caches index; second call uses cached index directly (1 subprocess call instead of 2).
    • test_real_auth_error_does_not_trigger_fallback — non-flag error surfaces immediately.
    • test_is_unknown_flag_error — unit test for the static helper.
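The fallback tests above script the CLI's behavior by giving the patched subprocess runner a list of outcomes. A minimal sketch of that mocking pattern, independent of the SDK, is shown here; `make_process_error` mirrors the `_make_process_error` test helper visible in the diff, while the command lines and the token payload are illustrative assumptions.

```python
import json
import subprocess
from unittest.mock import Mock


def make_process_error(stderr: str, stdout: str = "") -> subprocess.CalledProcessError:
    """Build the error a failed CLI invocation would raise (mirrors the test helper)."""
    err = subprocess.CalledProcessError(1, ["databricks"])
    err.stdout = stdout.encode()
    err.stderr = stderr.encode()
    return err


# Each list entry is consumed by one call: exception instances are raised,
# any other value is returned.
mock_run = Mock()
mock_run.side_effect = [
    make_process_error("Error: unknown flag: --force-refresh"),
    Mock(stdout=json.dumps({"access_token": "first"}).encode()),
]

first_stderr = None
try:
    mock_run(["databricks", "auth", "token", "--profile", "p", "--force-refresh"])
except subprocess.CalledProcessError as err:
    first_stderr = err.stderr.decode()

result = mock_run(["databricks", "auth", "token", "--profile", "p"])
token = json.loads(result.stdout.decode())["access_token"]
```

Asserting on `mock_run.call_count` afterwards is what lets a test distinguish a probing refresh from a cached one.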

@mihaimitrea-db
Contributor Author

Range-diff: stack/cli-force-refresh (bedfa97 -> f979a2c)
NEXT_CHANGELOG.md
@@ -0,0 +1,10 @@
+diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md
+--- a/NEXT_CHANGELOG.md
++++ b/NEXT_CHANGELOG.md
+ ### Documentation
+ 
+ ### Internal Changes
++* Generalize CLI token source into a progressive command list for forward-compatible flag support.
+ 
+ ### API Changes
+ * Add `disable_gov_tag_creation` field for `databricks.sdk.service.settings.RestrictWorkspaceAdminsMessage`.
\ No newline at end of file
databricks/sdk/credentials_provider.py
@@ -126,32 +126,40 @@
 +        commands: List[CliCommand] = []
 +        if cfg.profile:
 +            profile_args = [cli_path, "auth", "token", "--profile", cfg.profile]
-+            commands.append(CliCommand(
-+                args=profile_args + ["--force-refresh"],
-+                flags=["--force-refresh", "--profile"],
-+                warning="Databricks CLI does not support --force-refresh. "
-+                        "Please upgrade your CLI to the latest version.",
-+            ))
-+            commands.append(CliCommand(
-+                args=profile_args,
-+                flags=["--profile"],
-+                warning="Databricks CLI does not support --profile flag. "
-+                        "Falling back to --host. "
-+                        "Please upgrade your CLI to the latest version.",
-+            ))
++            commands.append(
++                CliCommand(
++                    args=profile_args + ["--force-refresh"],
++                    flags=["--force-refresh", "--profile"],
++                    warning="Databricks CLI does not support --force-refresh. "
++                    "Please upgrade your CLI to the latest version.",
++                )
++            )
++            commands.append(
++                CliCommand(
++                    args=profile_args,
++                    flags=["--profile"],
++                    warning="Databricks CLI does not support --profile flag. "
++                    "Falling back to --host. "
++                    "Please upgrade your CLI to the latest version.",
++                )
++            )
 +            if cfg.host:
-+                commands.append(CliCommand(
-+                    args=[cli_path, *DatabricksCliTokenSource._build_host_args(cfg)],
++                commands.append(
++                    CliCommand(
++                        args=[cli_path, *DatabricksCliTokenSource._build_host_args(cfg)],
++                        flags=[],
++                        warning="",
++                    )
++                )
++        else:
++            host_args = [cli_path, *DatabricksCliTokenSource._build_host_args(cfg)]
++            commands.append(
++                CliCommand(
++                    args=host_args,
 +                    flags=[],
 +                    warning="",
-+                ))
-+        else:
-+            host_args = [cli_path, *DatabricksCliTokenSource._build_host_args(cfg)]
-+            commands.append(CliCommand(
-+                args=host_args,
-+                flags=[],
-+                warning="",
-+            ))
++                )
++            )
 +        return commands
  
      def refresh(self) -> oauth.Token:
@@ -161,8 +169,7 @@
 -            flag = self._get_unsupported_flag(e)
 -            if flag in self._KNOWN_CLI_FLAGS:
 -                logger.warning(
--                    "Databricks CLI does not support %s. "
--                    "Please upgrade your CLI to the latest version.",
+-                    "Databricks CLI does not support %s. " "Please upgrade your CLI to the latest version.",
 -                    flag,
 -                )
 -                token = super().refresh()
tests/test_credentials_provider.py
@@ -69,16 +69,16 @@
 -
 -    def _make_process_error(self, stderr: str, stdout: str = ""):
 -        import subprocess
--
--        err = subprocess.CalledProcessError(1, ["databricks"])
--        err.stdout = stdout.encode()
--        err.stderr = stderr.encode()
--        return err
 +        commands = call_kwargs.kwargs["commands"]
 +        assert len(commands) == 2
 +        assert "--force-refresh" in commands[0].args
 +        assert "--force-refresh" not in commands[1].args
  
+-        err = subprocess.CalledProcessError(1, ["databricks"])
+-        err.stdout = stdout.encode()
+-        err.stderr = stderr.encode()
+-        return err
+-
 -    def test_fallback_on_unknown_profile_flag(self, mocker):
 -        """When --profile fails with 'unknown flag: --profile', falls back to --host command."""
 -        import json
@@ -217,13 +217,13 @@
          mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
          mock_run.side_effect = [
 -            # force_cmd: --profile + --force-refresh → unknown --profile
-+            # 1st: --profile + --force-refresh → unknown --profile
++            # 1st: --profile + --force-refresh -> unknown --profile
              self._make_process_error("Error: unknown flag: --profile"),
 -            # _refresh_without_force cmd: --profile → unknown --profile
-+            # 2nd: --profile → unknown --profile
++            # 2nd: --profile -> unknown --profile
              self._make_process_error("Error: unknown flag: --profile"),
 -            # _refresh_without_force fallback_cmd: --host → success
-+            # 3rd: --host (terminal) → success
++            # 3rd: --host (terminal) -> success
              Mock(stdout=self._valid_response_json("host-token").encode()),
          ]
  
@@ -239,13 +239,13 @@
          mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
          mock_run.side_effect = [
 -            # 1st: force_cmd (--profile + --force-refresh) → unknown --force-refresh
-+            # 1st: --profile + --force-refresh → unknown --force-refresh
++            # 1st: --profile + --force-refresh -> unknown --force-refresh
              self._make_process_error("Error: unknown flag: --force-refresh"),
 -            # 2nd: _refresh_without_force cmd (--profile) → unknown --profile
-+            # 2nd: --profile → unknown --profile
++            # 2nd: --profile -> unknown --profile
              self._make_process_error("Error: unknown flag: --profile"),
 -            # 3rd: _refresh_without_force fallback_cmd (--host) → success
-+            # 3rd: --host (terminal) → success
++            # 3rd: --host (terminal) -> success
              Mock(stdout=self._valid_response_json("plain").encode()),
          ]
  
@@ -275,7 +275,11 @@
 +
      def test_real_auth_error_does_not_trigger_fallback(self, mocker):
          """Real auth failures (not unknown-flag) surface immediately."""
-         ts = self._make_token_source()
+-        ts = self._make_token_source()
++        ts = self._make_token_source(profile="my-profile")
+ 
+         mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
+         mock_run.side_effect = self._make_process_error("cache: databricks OAuth is not configured for this host")
          assert "databricks OAuth is not configured" in str(exc_info.value)
          assert mock_run.call_count == 1
  

Reproduce locally: git range-diff 1d03f6d..bedfa97 9e82d18..f979a2c | Disable: git config gitstack.push-range-diff false

@mihaimitrea-db mihaimitrea-db force-pushed the mihaimitrea-db/stack/cli-progressive-token-commands branch from f979a2c to b0cd110 Compare March 31, 2026 14:09
@mihaimitrea-db
Contributor Author

Range-diff: stack/cli-force-refresh (f979a2c -> b0cd110)
NEXT_CHANGELOG.md
@@ -1,10 +1,10 @@
 diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md
 --- a/NEXT_CHANGELOG.md
 +++ b/NEXT_CHANGELOG.md
- ### Documentation
- 
  ### Internal Changes
+ * Replace the async-disabling mechanism on token refresh failure with a 1-minute retry backoff. Previously, a single failed async refresh would disable proactive token renewal until the token expired. Now, the SDK waits a short cooldown period and retries, improving resilience to transient errors.
+ * Extract `_resolve_profile` to simplify config file loading and improve `__settings__` error messages.
 +* Generalize CLI token source into a progressive command list for forward-compatible flag support.
  
  ### API Changes
- * Add `disable_gov_tag_creation` field for `databricks.sdk.service.settings.RestrictWorkspaceAdminsMessage`.
\ No newline at end of file
+ * Add `create_catalog()`, `create_synced_table()`, `delete_catalog()`, `delete_synced_table()`, `get_catalog()` and `get_synced_table()` methods for [w.postgres](https://databricks-sdk-py.readthedocs.io/en/latest/workspace/postgres/postgres.html) workspace-level service.
\ No newline at end of file
tests/test_credentials_provider.py
@@ -11,9 +11,10 @@
          mock_init = mocker.patch.object(
              credentials_provider.CliTokenSource,
              "__init__",
+         mock_cfg = Mock()
          mock_cfg.profile = "my-profile"
          mock_cfg.host = "https://workspace.databricks.com"
-         mock_cfg.experimental_is_unified_host = False
+-
 +        mock_cfg.client_type = ClientType.WORKSPACE
 +        mock_cfg.account_id = None
          mock_cfg.databricks_cli_path = "/path/to/databricks"
@@ -53,7 +54,11 @@
 -        assert cmd == ["/path/to/databricks", "auth", "token", "--profile", "my-profile"]
 -        assert host_cmd is None
 -
--
++        commands = call_kwargs.kwargs["commands"]
++        assert len(commands) == 2
++        assert "--force-refresh" in commands[0].args
++        assert "--force-refresh" not in commands[1].args
+ 
 -# Tests for CliTokenSource fallback on unknown --profile flag
 -class TestCliTokenSourceFallback:
 -    """Tests that CliTokenSource falls back to --host when CLI doesn't support --profile."""
@@ -69,11 +74,7 @@
 -
 -    def _make_process_error(self, stderr: str, stdout: str = ""):
 -        import subprocess
-+        commands = call_kwargs.kwargs["commands"]
-+        assert len(commands) == 2
-+        assert "--force-refresh" in commands[0].args
-+        assert "--force-refresh" not in commands[1].args
- 
+-
 -        err = subprocess.CalledProcessError(1, ["databricks"])
 -        err.stdout = stdout.encode()
 -        err.stderr = stderr.encode()

Reproduce locally: git range-diff 9e82d18..f979a2c 3195f5d..b0cd110 | Disable: git config gitstack.push-range-diff false

@mihaimitrea-db mihaimitrea-db force-pushed the mihaimitrea-db/stack/cli-progressive-token-commands branch from b0cd110 to bdc451f Compare March 31, 2026 15:12
@mihaimitrea-db
Contributor Author

Range-diff: stack/cli-force-refresh (b0cd110 -> bdc451f)
databricks/sdk/credentials_provider.py
@@ -7,6 +7,13 @@
  import functools
  import io
  import json
+ import os
+ import pathlib
+ import platform
++import re
+ import subprocess
+ import sys
+ import threading
      return OAuthCredentialsProvider(refreshed_headers, token)
  
  
@@ -20,8 +27,10 @@
 +
 +
  class CliTokenSource(oauth.Refreshable):
-     _UNKNOWN_FLAG_RE = re.compile(r"unknown flag: (--[a-z-]+)")
++    _UNKNOWN_FLAG_RE = re.compile(r"unknown flag: (--[a-z-]+)")
  
+     def __init__(
+         self,
          access_token_field: str,
          expiry_field: str,
          disable_async: bool = True,
@@ -43,18 +52,15 @@
  
      @staticmethod
      def _parse_expiry(expiry: str) -> datetime:
+             message = "\n".join(filter(None, [stdout, stderr]))
              raise IOError(f"cannot get access token: {message}") from e
  
-     @staticmethod
--    def _get_unsupported_flag(error: IOError) -> Optional[str]:
--        """Extract the flag name if the error is an 'unknown flag' CLI rejection."""
--        match = CliTokenSource._UNKNOWN_FLAG_RE.search(str(error))
--        return match.group(1) if match else None
++    @staticmethod
 +    def _is_unknown_flag_error(error: IOError, flags: List[str]) -> bool:
 +        """Check if the error indicates the CLI rejected one of the given flags."""
 +        msg = str(error)
 +        return any(f"unknown flag: {flag}" in msg for flag in flags)
- 
++
      def refresh(self) -> oauth.Token:
 -        try:
 -            return self._exec_cli_command(self._cmd)
@@ -120,7 +126,15 @@
 +            commands=commands,
          )
  
--    _KNOWN_CLI_FLAGS = {"--force-refresh", "--profile"}
+-    def refresh(self) -> oauth.Token:
+-        try:
+-            token = self._exec_cli_command(self._force_cmd)
+-        except IOError as e:
+-            err_msg = str(e)
+-            if "unknown flag: --force-refresh" in err_msg or "unknown flag: --profile" in err_msg:
+-                logger.warning(
+-                    "Databricks CLI does not support --force-refresh. "
+-                    "Please upgrade your CLI to the latest version."
 +    @staticmethod
 +    def _build_commands(cli_path: str, cfg: "Config") -> List[CliCommand]:
 +        commands: List[CliCommand] = []
@@ -132,7 +146,10 @@
 +                    flags=["--force-refresh", "--profile"],
 +                    warning="Databricks CLI does not support --force-refresh. "
 +                    "Please upgrade your CLI to the latest version.",
-+                )
+                 )
+-                token = super().refresh()
+-            else:
+-                raise
 +            )
 +            commands.append(
 +                CliCommand(
@@ -161,20 +178,8 @@
 +                )
 +            )
 +        return commands
- 
-     def refresh(self) -> oauth.Token:
--        try:
--            token = self._exec_cli_command(self._force_cmd)
--        except IOError as e:
--            flag = self._get_unsupported_flag(e)
--            if flag in self._KNOWN_CLI_FLAGS:
--                logger.warning(
--                    "Databricks CLI does not support %s. " "Please upgrade your CLI to the latest version.",
--                    flag,
--                )
--                token = super().refresh()
--            else:
--                raise
++
++    def refresh(self) -> oauth.Token:
 +        token = super().refresh()
          if self._requested_scopes:
              self._validate_token_scopes(token)
tests/test_credentials_provider.py
@@ -54,11 +54,7 @@
 -        assert cmd == ["/path/to/databricks", "auth", "token", "--profile", "my-profile"]
 -        assert host_cmd is None
 -
-+        commands = call_kwargs.kwargs["commands"]
-+        assert len(commands) == 2
-+        assert "--force-refresh" in commands[0].args
-+        assert "--force-refresh" not in commands[1].args
- 
+-
 -# Tests for CliTokenSource fallback on unknown --profile flag
 -class TestCliTokenSourceFallback:
 -    """Tests that CliTokenSource falls back to --host when CLI doesn't support --profile."""
@@ -103,7 +99,11 @@
 -    def test_fallback_triggered_when_unknown_flag_in_stderr_only(self, mocker):
 -        """Fallback triggers even when CLI also writes usage text to stdout."""
 -        import json
--
++        commands = call_kwargs.kwargs["commands"]
++        assert len(commands) == 2
++        assert "--force-refresh" in commands[0].args
++        assert "--force-refresh" not in commands[1].args
+ 
 -        expiry = (datetime.now() + timedelta(hours=1)).strftime("%Y-%m-%dT%H:%M:%S")
 -        valid_response = json.dumps({"access_token": "fallback-token", "token_type": "Bearer", "expiry": expiry})
 -
@@ -284,12 +284,6 @@
          assert "databricks OAuth is not configured" in str(exc_info.value)
          assert mock_run.call_count == 1
  
--    def test_get_unsupported_flag_extracts_flag(self):
--        """The classifier correctly parses the flag name from CLI error output."""
--        get = credentials_provider.CliTokenSource._get_unsupported_flag
--        assert get(IOError("Error: unknown flag: --force-refresh")) == "--force-refresh"
--        assert get(IOError("Error: unknown flag: --profile")) == "--profile"
--        assert get(IOError("some other error")) is None
 +    def test_is_unknown_flag_error(self):
 +        """_is_unknown_flag_error matches against specific flag list."""
 +        check = credentials_provider.CliTokenSource._is_unknown_flag_error
@@ -297,6 +291,7 @@
 +        assert check(IOError("Error: unknown flag: --profile"), ["--profile"])
 +        assert not check(IOError("Error: unknown flag: --force-refresh"), ["--profile"])
 +        assert not check(IOError("some other error"), ["--force-refresh"])
- 
++
  
- # Tests for cloud-agnostic hosts and removed cloud checks
\ No newline at end of file
+ # Tests for cloud-agnostic hosts and removed cloud checks
+ class TestCloudAgnosticHosts:
\ No newline at end of file

Reproduce locally: git range-diff 3195f5d..b0cd110 4115f37..bdc451f | Disable: git config gitstack.push-range-diff false

@mihaimitrea-db mihaimitrea-db force-pushed the mihaimitrea-db/stack/cli-progressive-token-commands branch from bdc451f to 8821700 Compare March 31, 2026 15:33
@mihaimitrea-db
Contributor Author

Range-diff: stack/cli-force-refresh (bdc451f -> 8821700)
databricks/sdk/credentials_provider.py
@@ -7,16 +7,10 @@
  import functools
  import io
  import json
- import os
- import pathlib
- import platform
-+import re
- import subprocess
- import sys
- import threading
      return OAuthCredentialsProvider(refreshed_headers, token)
  
  
+-class CliTokenSource(oauth.Refreshable):
 +@dataclasses.dataclass
 +class CliCommand:
 +    """A single CLI command variant with metadata for progressive fallback."""
@@ -24,26 +18,19 @@
 +    args: List[str]
 +    flags: List[str]
 +    warning: str
-+
+ 
 +
- class CliTokenSource(oauth.Refreshable):
-+    _UNKNOWN_FLAG_RE = re.compile(r"unknown flag: (--[a-z-]+)")
- 
++class CliTokenSource(oauth.Refreshable):
      def __init__(
          self,
-         access_token_field: str,
+         cmd: List[str],
          expiry_field: str,
          disable_async: bool = True,
--        fallback_cmd: Optional[List[str]] = None,
+         fallback_cmd: Optional[List[str]] = None,
 +        commands: Optional[List[CliCommand]] = None,
      ):
          super().__init__(disable_async=disable_async)
          self._cmd = cmd
--        # fallback_cmd is tried when the primary command fails with "unknown flag: --profile",
--        # indicating the CLI is too old to support --profile. Can be removed once support
--        # for CLI versions predating --profile is dropped.
--        # See: https://github.com/databricks/databricks-sdk-go/pull/1497
--        self._fallback_cmd = fallback_cmd
          self._token_type_field = token_type_field
          self._access_token_field = access_token_field
          self._expiry_field = expiry_field
@@ -62,23 +49,17 @@
 +        return any(f"unknown flag: {flag}" in msg for flag in flags)
 +
      def refresh(self) -> oauth.Token:
--        try:
--            return self._exec_cli_command(self._cmd)
--        except IOError as e:
--            if self._fallback_cmd is not None and "unknown flag: --profile" in str(e):
--                logger.warning(
--                    "Databricks CLI does not support --profile flag. Falling back to --host. "
--                    "Please upgrade your CLI to the latest version."
--                )
--                return self._exec_cli_command(self._fallback_cmd)
--            raise
 +        if self._commands is not None:
 +            return self._refresh_progressive()
 +        return self._refresh_single()
 +
 +    def _refresh_single(self) -> oauth.Token:
-+        return self._exec_cli_command(self._cmd)
-+
+         try:
+             return self._exec_cli_command(self._cmd)
+         except IOError as e:
+                 return self._exec_cli_command(self._fallback_cmd)
+             raise
+ 
 +    def _refresh_progressive(self) -> oauth.Token:
 +        last_err: Optional[IOError] = None
 +        for i in range(self._active_command_index, len(self._commands)):
@@ -94,9 +75,10 @@
 +                logger.warning(cmd.warning)
 +                last_err = e
 +        raise last_err
- 
++
  
  def _run_subprocess(
+     popenargs,
          elif cli_path.count("/") == 0:
              cli_path = self.__class__._find_executable(cli_path)
  

Reproduce locally: git range-diff 4115f37..bdc451f 4115f37..8821700 | Disable: git config gitstack.push-range-diff false

@mihaimitrea-db mihaimitrea-db force-pushed the mihaimitrea-db/stack/cli-progressive-token-commands branch from 8821700 to 336b9be Compare March 31, 2026 16:28
@mihaimitrea-db
Contributor Author

Range-diff: stack/cli-force-refresh (8821700 -> 336b9be)
databricks/sdk/credentials_provider.py
@@ -10,7 +10,6 @@
      return OAuthCredentialsProvider(refreshed_headers, token)
  
  
--class CliTokenSource(oauth.Refreshable):
 +@dataclasses.dataclass
 +class CliCommand:
 +    """A single CLI command variant with metadata for progressive fallback."""
@@ -18,12 +17,11 @@
 +    args: List[str]
 +    flags: List[str]
 +    warning: str
++
++
+ class CliTokenSource(oauth.Refreshable):
  
-+
-+class CliTokenSource(oauth.Refreshable):
      def __init__(
-         self,
-         cmd: List[str],
          expiry_field: str,
          disable_async: bool = True,
          fallback_cmd: Optional[List[str]] = None,
tests/test_credentials_provider.py
@@ -22,7 +22,7 @@
  
          credentials_provider.DatabricksCliTokenSource(mock_cfg)
  
-         call_kwargs = mock_init.call_args
+-        call_kwargs = mock_init.call_args
 -        cmd = call_kwargs.kwargs["cmd"]
 -        host_cmd = call_kwargs.kwargs["fallback_cmd"]
 -
@@ -31,23 +31,23 @@
 -        assert "--host" in host_cmd
 -        assert "https://workspace.databricks.com" in host_cmd
 -        assert "--profile" not in host_cmd
--
--    def test_profile_without_host_no_fallback(self, mocker):
--        """When profile is set but host is absent, no fallback is built."""
-+        commands = call_kwargs.kwargs["commands"]
++        commands = mock_init.call_args.kwargs["commands"]
 +        assert len(commands) == 3
 +        assert "--force-refresh" in commands[0].args and "--profile" in commands[0].args
 +        assert "--profile" in commands[1].args and "--force-refresh" not in commands[1].args
 +        assert "--host" in commands[2].args and "--force-refresh" not in commands[2].args
-+
+ 
+-    def test_profile_without_host_no_fallback(self, mocker):
+-        """When profile is set but host is absent, no fallback is built."""
 +    def test_profile_without_host_builds_two_commands(self, mocker):
 +        """With profile only: force-refresh and plain profile."""
          mock_init = mocker.patch.object(
              credentials_provider.CliTokenSource,
              "__init__",
+ 
          credentials_provider.DatabricksCliTokenSource(mock_cfg)
  
-         call_kwargs = mock_init.call_args
+-        call_kwargs = mock_init.call_args
 -        cmd = call_kwargs.kwargs["cmd"]
 -        host_cmd = call_kwargs.kwargs["fallback_cmd"]
 -
@@ -99,14 +99,14 @@
 -    def test_fallback_triggered_when_unknown_flag_in_stderr_only(self, mocker):
 -        """Fallback triggers even when CLI also writes usage text to stdout."""
 -        import json
-+        commands = call_kwargs.kwargs["commands"]
+-
+-        expiry = (datetime.now() + timedelta(hours=1)).strftime("%Y-%m-%dT%H:%M:%S")
+-        valid_response = json.dumps({"access_token": "fallback-token", "token_type": "Bearer", "expiry": expiry})
++        commands = mock_init.call_args.kwargs["commands"]
 +        assert len(commands) == 2
 +        assert "--force-refresh" in commands[0].args
 +        assert "--force-refresh" not in commands[1].args
  
--        expiry = (datetime.now() + timedelta(hours=1)).strftime("%Y-%m-%dT%H:%M:%S")
--        valid_response = json.dumps({"access_token": "fallback-token", "token_type": "Bearer", "expiry": expiry})
--
 -        mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
 -        mock_run.side_effect = [
 -            self._make_process_error(stderr="Error: unknown flag: --profile", stdout="Usage: databricks auth token"),
@@ -123,7 +123,7 @@
 -        mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
 -        mock_run.side_effect = self._make_process_error("cache: databricks OAuth is not configured for this host")
 +    def test_host_only_builds_one_command(self, mocker):
-+        """With host only: single plain host command, no progressive fallback."""
++        """With host only: single plain host command."""
 +        mock_init = mocker.patch.object(
 +            credentials_provider.CliTokenSource,
 +            "__init__",
@@ -155,8 +155,7 @@
 -            ts.refresh()
 -        assert "unknown flag: --profile" in str(exc_info.value)
 -        assert mock_run.call_count == 1
-+        call_kwargs = mock_init.call_args
-+        commands = call_kwargs.kwargs["commands"]
++        commands = mock_init.call_args.kwargs["commands"]
 +        assert len(commands) == 1
 +        assert "--host" in commands[0].args
 +        assert "--force-refresh" not in commands[0].args
@@ -204,10 +203,8 @@
          mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
          mock_run.side_effect = [
              self._make_process_error("Error: unknown flag: --force-refresh"),
-         second_cmd = mock_run.call_args_list[1][0][0]
          assert "--force-refresh" in first_cmd
          assert "--force-refresh" not in second_cmd
-+        assert "--profile" in second_cmd
  
 -    def test_profile_fallback_when_unsupported(self, mocker):
 -        """Old CLI without --profile: force_cmd fails, fallback retries with --host."""
@@ -259,10 +256,8 @@
 +
 +        mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
 +        mock_run.side_effect = [
-+            # First refresh: force-refresh fails, plain profile succeeds
 +            self._make_process_error("Error: unknown flag: --force-refresh"),
 +            Mock(stdout=self._valid_response_json("first").encode()),
-+            # Second refresh: starts at cached index (plain profile), succeeds immediately
 +            Mock(stdout=self._valid_response_json("second").encode()),
 +        ]
 +

Reproduce locally: git range-diff 4115f37..8821700 4115f37..336b9be | Disable: git config gitstack.push-range-diff false

@mihaimitrea-db mihaimitrea-db self-assigned this Apr 1, 2026
@mihaimitrea-db mihaimitrea-db force-pushed the mihaimitrea-db/stack/cli-progressive-token-commands branch from 336b9be to 96fcf7f Compare April 1, 2026 08:44
@mihaimitrea-db mihaimitrea-db force-pushed the mihaimitrea-db/stack/cli-progressive-token-commands branch from 96fcf7f to b9159e8 Compare April 1, 2026 08:57
@mihaimitrea-db
Contributor Author

Range-diff: stack/cli-force-refresh (96fcf7f -> b9159e8)
databricks/sdk/credentials_provider.py
@@ -81,14 +81,14 @@
              cli_path = self.__class__._find_executable(cli_path)
  
 -        fallback_cmd = None
+-        self._force_cmd = None
 -        if cfg.profile:
 -            args = ["auth", "token", "--profile", cfg.profile]
+-            self._force_cmd = [cli_path, *args, "--force-refresh"]
 -            if cfg.host:
 -                fallback_cmd = [cli_path, *self.__class__._build_host_args(cfg)]
 -        else:
 -            args = self.__class__._build_host_args(cfg)
--
--        self._force_cmd = [cli_path, *args, "--force-refresh"]
 +        commands = self.__class__._build_commands(cli_path, cfg)
  
          # get_scopes() defaults to ["all-apis"] when nothing is configured, which would
@@ -107,14 +107,17 @@
          )
  
 -    def refresh(self) -> oauth.Token:
--        try:
--            token = self._exec_cli_command(self._force_cmd)
--        except IOError as e:
--            err_msg = str(e)
--            if "unknown flag: --force-refresh" in err_msg or "unknown flag: --profile" in err_msg:
--                logger.warning(
--                    "Databricks CLI does not support --force-refresh. "
--                    "Please upgrade your CLI to the latest version."
+-        if self._force_cmd is None:
+-            token = super().refresh()
+-        else:
+-            try:
+-                token = self._exec_cli_command(self._force_cmd)
+-            except IOError as e:
+-                err_msg = str(e)
+-                if "unknown flag: --force-refresh" in err_msg or "unknown flag: --profile" in err_msg:
+-                    logger.warning(
+-                        "Databricks CLI does not support --force-refresh. "
+-                        "Please upgrade your CLI to the latest version."
 +    @staticmethod
 +    def _build_commands(cli_path: str, cfg: "Config") -> List[CliCommand]:
 +        commands: List[CliCommand] = []
@@ -126,10 +129,7 @@
 +                    flags=["--force-refresh", "--profile"],
 +                    warning="Databricks CLI does not support --force-refresh. "
 +                    "Please upgrade your CLI to the latest version.",
-                 )
--                token = super().refresh()
--            else:
--                raise
++                )
 +            )
 +            commands.append(
 +                CliCommand(
@@ -146,7 +146,10 @@
 +                        args=[cli_path, *DatabricksCliTokenSource._build_host_args(cfg)],
 +                        flags=[],
 +                        warning="",
-+                    )
+                     )
+-                    token = super().refresh()
+-                else:
+-                    raise
 +                )
 +        else:
 +            host_args = [cli_path, *DatabricksCliTokenSource._build_host_args(cfg)]
tests/test_credentials_provider.py
@@ -102,11 +102,7 @@
 -
 -        expiry = (datetime.now() + timedelta(hours=1)).strftime("%Y-%m-%dT%H:%M:%S")
 -        valid_response = json.dumps({"access_token": "fallback-token", "token_type": "Bearer", "expiry": expiry})
-+        commands = mock_init.call_args.kwargs["commands"]
-+        assert len(commands) == 2
-+        assert "--force-refresh" in commands[0].args
-+        assert "--force-refresh" not in commands[1].args
- 
+-
 -        mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
 -        mock_run.side_effect = [
 -            self._make_process_error(stderr="Error: unknown flag: --profile", stdout="Usage: databricks auth token"),
@@ -117,7 +113,11 @@
 -        ts = self._make_token_source(fallback_cmd=fallback_cmd)
 -        token = ts.refresh()
 -        assert token.access_token == "fallback-token"
--
++        commands = mock_init.call_args.kwargs["commands"]
++        assert len(commands) == 2
++        assert "--force-refresh" in commands[0].args
++        assert "--force-refresh" not in commands[1].args
+ 
 -    def test_no_fallback_on_real_auth_error(self, mocker):
 -        """When --profile fails with a real error (not unknown flag), no fallback is attempted."""
 -        mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
@@ -162,49 +162,18 @@
  
  
  class TestDatabricksCliForceRefresh:
-         expiry = (datetime.now() + timedelta(hours=1)).strftime("%Y-%m-%dT%H:%M:%S")
-         return json.dumps({"access_token": access_token, "token_type": "Bearer", "expiry": expiry})
- 
--    def test_force_refresh_always_tried_first(self, mocker):
--        """refresh() always tries --force-refresh first."""
--        ts = self._make_token_source()
-+    def test_force_refresh_tried_first_with_profile(self, mocker):
-+        """When profile is configured, refresh() tries --force-refresh first."""
-+        ts = self._make_token_source(profile="my-profile")
- 
-         mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
-         mock_run.return_value = Mock(stdout=self._valid_response_json("refreshed").encode())
- 
-         cmd = mock_run.call_args[0][0]
-         assert "--force-refresh" in cmd
-+        assert "--profile" in cmd
+         assert "--host" in cmd
  
--    def test_force_refresh_fallback_when_unsupported(self, mocker):
+     def test_force_refresh_fallback_when_unsupported(self, mocker):
 -        """Old CLI without --force-refresh: falls back to cmd without --force-refresh."""
-+    def test_host_only_no_force_refresh(self, mocker):
-+        """When only host is configured, --force-refresh is not used."""
-         ts = self._make_token_source()
++        """Old CLI without --force-refresh: falls back to plain --profile command."""
+         ts = self._make_token_source(profile="my-profile")
  
-+        mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
-+        mock_run.return_value = Mock(stdout=self._valid_response_json("token").encode())
-+
-+        token = ts.refresh()
-+        assert token.access_token == "token"
-+        assert mock_run.call_count == 1
-+
-+        cmd = mock_run.call_args[0][0]
-+        assert "--force-refresh" not in cmd
-+        assert "--host" in cmd
-+
-+    def test_force_refresh_fallback_when_unsupported(self, mocker):
-+        """Old CLI without --force-refresh: falls back to plain --profile command."""
-+        ts = self._make_token_source(profile="my-profile")
-+
          mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
-         mock_run.side_effect = [
-             self._make_process_error("Error: unknown flag: --force-refresh"),
+         second_cmd = mock_run.call_args_list[1][0][0]
          assert "--force-refresh" in first_cmd
          assert "--force-refresh" not in second_cmd
++        assert "--profile" in second_cmd
  
 -    def test_profile_fallback_when_unsupported(self, mocker):
 -        """Old CLI without --profile: force_cmd fails, fallback retries with --host."""
@@ -215,13 +184,10 @@
          mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
          mock_run.side_effect = [
 -            # force_cmd: --profile + --force-refresh → unknown --profile
-+            # 1st: --profile + --force-refresh -> unknown --profile
              self._make_process_error("Error: unknown flag: --profile"),
 -            # _refresh_without_force cmd: --profile → unknown --profile
-+            # 2nd: --profile -> unknown --profile
              self._make_process_error("Error: unknown flag: --profile"),
 -            # _refresh_without_force fallback_cmd: --host → success
-+            # 3rd: --host (terminal) -> success
              Mock(stdout=self._valid_response_json("host-token").encode()),
          ]
  
@@ -237,13 +203,10 @@
          mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
          mock_run.side_effect = [
 -            # 1st: force_cmd (--profile + --force-refresh) → unknown --force-refresh
-+            # 1st: --profile + --force-refresh -> unknown --force-refresh
              self._make_process_error("Error: unknown flag: --force-refresh"),
 -            # 2nd: _refresh_without_force cmd (--profile) → unknown --profile
-+            # 2nd: --profile -> unknown --profile
              self._make_process_error("Error: unknown flag: --profile"),
 -            # 3rd: _refresh_without_force fallback_cmd (--host) → success
-+            # 3rd: --host (terminal) -> success
              Mock(stdout=self._valid_response_json("plain").encode()),
          ]
  

Reproduce locally: git range-diff 469cb44..96fcf7f 32c2cbd..b9159e8 | Disable: git config gitstack.push-range-diff false

@mihaimitrea-db mihaimitrea-db force-pushed the mihaimitrea-db/stack/cli-progressive-token-commands branch from b9159e8 to f8fff86 Compare April 1, 2026 09:00
When the SDK's cached CLI token is stale, try `databricks auth token
--force-refresh` to get a freshly minted token from the IdP. If the
installed CLI is too old to recognise the flag, fall back to regular
`auth token` and remember the capability for future refreshes.

Centralise unknown-flag detection in CliTokenSource._exec_cli_command()
via UnsupportedCliFlagError so the same classifier is reused by both the
legacy --profile fallback and the new --force-refresh downgrade path in
DatabricksCliTokenSource.

See: databricks/cli#4767
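
The centralised classifier described above can be sketched roughly as follows. This is a minimal illustration, not the SDK's code: `is_unknown_flag_error` and `classify` are hypothetical helpers, and only the `unknown flag: <flag>` message format is taken from the error strings visible in the diff.

```python
from typing import Sequence


class UnsupportedCliFlagError(IOError):
    """Stand-in for the error named in the commit message (assumed to subclass IOError)."""


def is_unknown_flag_error(message: str, flags: Sequence[str]) -> bool:
    """Return True when the CLI error text reports one of `flags` as unknown."""
    return any(f"unknown flag: {flag}" in message for flag in flags)


def classify(message: str, flags: Sequence[str]) -> IOError:
    """Hypothetical helper: map raw CLI stderr to a specific exception type."""
    if is_unknown_flag_error(message, flags):
        # The installed CLI is too old for one of the flags: callers may downgrade.
        return UnsupportedCliFlagError(message)
    # Anything else is a real failure and should not trigger a fallback.
    return IOError(message)
```

Keeping the check in one place means the `--profile` fallback and the `--force-refresh` downgrade can both ask the same question ("did the CLI reject a flag I passed?") instead of each matching substrings on its own.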
@mihaimitrea-db mihaimitrea-db force-pushed the mihaimitrea-db/stack/cli-progressive-token-commands branch from f8fff86 to 29d8b73 Compare April 1, 2026 09:09
Replace the explicit force_cmd/fallback_cmd fields with a CliCommand
dataclass and an optional commands list on CliTokenSource.

When commands is provided, refresh() delegates to _refresh_progressive(),
which walks the list starting at _active_command_index and falls back on
unsupported-flag errors. When commands is None, refresh() delegates to
_refresh_single(), which preserves the original fallback behavior, so
AzureCliTokenSource needs no changes.

DatabricksCliTokenSource._build_commands() produces the progressive
list: --profile + --force-refresh first, plain --profile second, and
--host as a terminal fallback. --force-refresh is only paired with
--profile, never with --host. Adding future flags (e.g. --scopes)
requires only adding entries to _build_commands().
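
As a rough illustration of the list this commit message describes, the sketch below builds the three-step order for a profile plus host, and a single command for host only. It is a simplified stand-in, not the PR's `_build_commands()`: the free function `build_commands`, its `profile`/`host` parameters, the literal `--host` argument list, and the warning text on the plain `--profile` entry are assumptions made for the example.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class CliCommand:
    args: List[str]                                  # full argv, starting with the CLI path
    flags: List[str] = field(default_factory=list)   # flags whose rejection allows fallback
    warning: str = ""                                # logged when falling back from this entry


def build_commands(cli_path: str, profile: Optional[str], host: Optional[str]) -> List[CliCommand]:
    """Hypothetical builder: most capable command first, plainest command last."""
    commands: List[CliCommand] = []
    if profile:
        base = [cli_path, "auth", "token", "--profile", profile]
        commands.append(
            CliCommand(
                args=[*base, "--force-refresh"],
                flags=["--force-refresh", "--profile"],
                warning="Databricks CLI does not support --force-refresh. "
                "Please upgrade your CLI to the latest version.",
            )
        )
        commands.append(
            CliCommand(
                args=base,
                flags=["--profile"],
                warning="Databricks CLI does not support --profile.",  # illustrative text
            )
        )
        if host:
            # Terminal fallback: no flags listed, so any failure here is raised.
            commands.append(CliCommand(args=[cli_path, "auth", "token", "--host", host]))
    elif host:
        # Host only: a single plain command, never paired with --force-refresh.
        commands.append(CliCommand(args=[cli_path, "auth", "token", "--host", host]))
    else:
        raise ValueError("either profile or host is required")
    return commands
```

Under this shape, supporting a further flag would mean prepending one more entry with that flag in `args` and `flags`, rather than adding a field and a try/except block.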
@mihaimitrea-db mihaimitrea-db force-pushed the mihaimitrea-db/stack/cli-progressive-token-commands branch from 29d8b73 to 1432f4c Compare April 1, 2026 09:24
@mihaimitrea-db
Contributor Author

Range-diff: stack/cli-force-refresh (29d8b73 -> 1432f4c)
databricks/sdk/credentials_provider.py
@@ -33,7 +33,7 @@
          self._access_token_field = access_token_field
          self._expiry_field = expiry_field
 +        self._commands = commands
-+        self._active_command_index = 0
++        self._active_command_index = -1
  
      @staticmethod
      def _parse_expiry(expiry: str) -> datetime:
@@ -59,9 +59,19 @@
              raise
  
 +    def _refresh_progressive(self) -> oauth.Token:
-+        last_err: Optional[IOError] = None
-+        for i in range(self._active_command_index, len(self._commands)):
-+            cmd = self._commands[i]
++        idx = self._active_command_index
++        if idx >= 0:
++            return self._exec_cli_command(self._commands[idx].args)
++        return self._probe_and_exec()
++
++    def _probe_and_exec(self) -> oauth.Token:
++        """Walk the command list to find a CLI command that succeeds.
++
++        When a command fails with "unknown flag" for one of its flags, log a
++        warning and try the next.  On success, store _active_command_index so
++        future calls skip probing.
++        """
++        for i, cmd in enumerate(self._commands):
 +            try:
 +                token = self._exec_cli_command(cmd.args)
 +                self._active_command_index = i
@@ -71,8 +81,6 @@
 +                if is_last or not self._is_unknown_flag_error(e, cmd.flags):
 +                    raise
 +                logger.warning(cmd.warning)
-+                last_err = e
-+        raise last_err
 +
  
  def _run_subprocess(

Reproduce locally: git range-diff cd6c876..29d8b73 cd6c876..1432f4c | Disable: git config gitstack.push-range-diff false
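
The effect of switching the initial `_active_command_index` from `0` to `-1` can be seen in a small standalone sketch. It mirrors the control flow shown in the range-diff, but `ProgressiveRunner` and the injected `run` callable are hypothetical stand-ins for `CliTokenSource` and `_exec_cli_command`, and tokens are reduced to plain strings.

```python
import logging
from dataclasses import dataclass
from typing import Callable, List

logger = logging.getLogger(__name__)


@dataclass
class CliCommand:
    args: List[str]
    flags: List[str]
    warning: str


class ProgressiveRunner:
    """Hypothetical stand-in for the progressive part of CliTokenSource."""

    def __init__(self, commands: List[CliCommand], run: Callable[[List[str]], str]):
        self._commands = commands
        self._run = run
        self._active_command_index = -1  # -1 means "not probed yet"

    def refresh(self) -> str:
        idx = self._active_command_index
        if idx >= 0:
            # A working command is already known: run it directly, no fallback.
            return self._run(self._commands[idx].args)
        return self._probe_and_exec()

    def _probe_and_exec(self) -> str:
        for i, cmd in enumerate(self._commands):
            try:
                result = self._run(cmd.args)
                self._active_command_index = i  # cache the first command that works
                return result
            except IOError as e:
                is_last = i == len(self._commands) - 1
                unknown = any(f"unknown flag: {f}" in str(e) for f in cmd.flags)
                if is_last or not unknown:
                    raise
                logger.warning(cmd.warning)
        raise ValueError("no commands configured")
```

With a sentinel of `-1`, "nothing cached" and "command 0 is cached" are distinct states, so once a command has succeeded a later failure of that same command is raised as-is instead of silently walking further down the list.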

@github-actions

github-actions bot commented Apr 1, 2026

If integration tests don't run automatically, an authorized user can run them manually by following the instructions below:

Trigger:
go/deco-tests-run/sdk-py

Inputs:

  • PR number: 1378
  • Commit SHA: 1432f4c5348be9f69876c60578593d15e6510333

Checks will be approved automatically on success.
