Skip to content

Commit 2afaab5

Browse files
Add custom scopes support in OAuth authentication methods (#1374)
## Summary + Adds new `Scopes` and `DisableOAuthRefreshToken` fields. + Adds support for parsing lists in config files. ## Notes - Environment variable support is **not** provided for scopes because we do not think users would actually want this. It can be easily added as a follow up if there are requests for it. ## Testing - Parsing logic tested by loading profiles from a test `.databrickscfg` file. Subsequent PRs add support for custom scopes in OAuth authentication methods: - M2M: #1388 - WIF: #1389 - U2M: #1390 <!-- This PR adds support for user-provided OAuth scopes across all OAuth authentication flows (M2M, U2M, and WIF/OIDC). Users can now request fine-grained permissions instead of the default `all-apis` scope. ## Review Guide 1. **Start with `config/config.go`** - Review the new `Scopes` and `DisableOAuthRefreshToken` fields and `GetScopes()` method 2. **Review each auth flow integration**: - `config/auth_m2m.go` (one-line change) - `config/auth_u2m.go` (passes scopes to PersistentAuth, adds `persistentAuthFactory` for testability) - `config/auth_default.go` (passes scopes to OIDC token source) - `credentials/u2m/persistent_auth.go` (new options, `offline_access` handling, `GetScopes()` for test introspection) - `config/experimental/auth/oidc/tokensource.go` (accepts scopes in config) 3. **Review `config/config_attribute.go`** - adds slice type support for config file parsing 4. **Review tests** - verify scope assertions match expected behavior and look for missing test cases. #### Backwards Compatibility - All three OAuth flows continue to use `all-apis` as the default scope. - U2M continues to append `offline_access` scope by default. ## Testing #### Shared Config Layer - **`TestConfigFile_Scopes`** - Loads profiles from `.databrickscfg`; calls `cfg.EnsureResolved()`; asserts `cfg.GetScopes()` returns correctly parsed and sorted values. #### M2M Flow - **`TestM2M_Scopes`** - Sets up mock HTTP transport expecting specific `scope` values; calls `Config.Authenticate()`; asserts the token request contains expected scopes. #### U2M Flow Tests are split across two files to test different responsibilities: **`config/auth_u2m_test.go`** - Tests scope propagation from Config to PersistentAuth: - **`TestU2MCredentials_Configure_DefaultScopes`** - Uses a capturing factory that creates a real `PersistentAuth` and spies on it; calls `u2mCredentials.Configure()` with nil scopes; asserts `PersistentAuth.GetScopes()` returns `["all-apis"]`. - **`TestU2MCredentials_Configure_CustomScopes`** - Same setup; calls `Configure()` with custom scopes; asserts they are passed through correctly. **`credentials/u2m/persistent_auth_test.go`** - Tests `offline_access` handling: - **`TestU2M_ScopesAndOfflineAccess`** - Sets up mock browser capturing the authorization URL; calls `PersistentAuth.Challenge()`; asserts the `scope` query parameter contains expected scopes with `offline_access` appended (or omitted when `disableOfflineAccess` is true). ### WIF/OIDC Flow - **`TestWIF_Scopes`** - Sets up mock HTTP transport expecting specific `scope` values; calls `TokenSource.Token()`; asserts the token exchange request contains expected scopes. - **`TestGithubOIDC_Scopes`** - Sets up mock HTTP transport for GitHub and Databricks endpoints; calls `Config.Authenticate()`; asserts scopes flow correctly through to the token exchange request. --> --- NO_CHANGELOG=true --------- Co-authored-by: Renaud Hartert <[email protected]>
1 parent e146526 commit 2afaab5

File tree

9 files changed

+301
-10
lines changed

9 files changed

+301
-10
lines changed

config/config.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"github.com/databricks/databricks-sdk-go/credentials/u2m"
1919
"github.com/databricks/databricks-sdk-go/httpclient"
2020
"github.com/databricks/databricks-sdk-go/logger"
21+
"golang.org/x/exp/slices"
2122
"golang.org/x/oauth2"
2223
)
2324

@@ -63,6 +64,8 @@ const (
6364
InvalidConfig ConfigType = "INVALID_CONFIG"
6465
)
6566

67+
var defaultScopes = []string{"all-apis"}
68+
6669
// Config represents configuration for Databricks Connectivity
6770
type Config struct {
6871
// Credentials holds an instance of Credentials Strategy to authenticate with Databricks REST APIs.
@@ -135,6 +138,26 @@ type Config struct {
135138
ClientID string `name:"client_id" env:"DATABRICKS_CLIENT_ID" auth:"oauth" auth_types:"oauth-m2m"`
136139
ClientSecret string `name:"client_secret" env:"DATABRICKS_CLIENT_SECRET" auth:"oauth,sensitive" auth_types:"oauth-m2m"`
137140

141+
// Scopes is a list of OAuth scopes to request when authenticating.
142+
//
143+
// WARNING:
144+
// - This feature is still in development and may not work as expected
145+
// - This feature is EXPERIMENTAL and may change or be removed without notice.
146+
// - Do NOT use this feature in production environments.
147+
//
148+
// Notes:
149+
// - If Scopes is nil or empty, the default ["all-apis"] scope will be used for backward compatibility.
150+
// - For U2M authentication, the "offline_access" scope will automatically be added to obtain a refresh token
151+
// unless you set DisableOAuthRefreshToken to true.
152+
// - You cannot set Scopes via environment variables.
153+
// - The scopes list will be sorted in-place during configuration resolution.
154+
// - The U2M token cache currently does NOT support differentiated caching for scopes.
155+
Scopes []string `name:"scopes" auth:"-"`
156+
157+
// DisableOAuthRefreshToken controls whether a refresh token should be requested
158+
// during the U2M authentication flow (default to false).
159+
DisableOAuthRefreshToken bool `name:"disable_oauth_refresh_token" env:"DATABRICKS_DISABLE_OAUTH_REFRESH_TOKEN" auth:"-"`
160+
138161
// Path to the Databricks CLI (version >= 0.100.0).
139162
DatabricksCliPath string `name:"databricks_cli_path" env:"DATABRICKS_CLI_PATH" auth_types:"databricks-cli"`
140163

@@ -445,6 +468,8 @@ func (c *Config) EnsureResolved() error {
445468
},
446469
}
447470
}
471+
slices.Sort(c.Scopes)
472+
c.Scopes = slices.Compact(c.Scopes)
448473
c.resolved = true
449474
return nil
450475
}
@@ -460,6 +485,13 @@ func (c *Config) CanonicalHostName() string {
460485
return c.Host
461486
}
462487

488+
func (c *Config) GetScopes() []string {
489+
if len(c.Scopes) == 0 {
490+
return defaultScopes
491+
}
492+
return c.Scopes
493+
}
494+
463495
func (c *Config) wrapDebug(err error) error {
464496
debug := ConfigAttributes.DebugString(c)
465497
if debug == "" {

config/config_attribute.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,17 @@ import (
55
"os"
66
"reflect"
77
"strconv"
8+
"strings"
89
)
910

11+
// getenv is the function used to read environment variables.
12+
// It defaults to os.Getenv but can be overwritten in tests.
13+
var getenv = os.Getenv
14+
15+
// getUserHomeDir is the function used to get user home directory.
16+
// It defaults to os.UserHomeDir but can be overwritten in tests.
17+
var getUserHomeDir = os.UserHomeDir
18+
1019
type Source struct {
1120
Type SourceType `json:"type"`
1221
Name string `json:"name,omitempty"`
@@ -44,7 +53,7 @@ type ConfigAttribute struct {
4453

4554
func (a *ConfigAttribute) ReadEnv() (string, string) {
4655
for _, envName := range a.EnvVars {
47-
v := os.Getenv(envName)
56+
v := getenv(envName)
4857
if v == "" {
4958
continue
5059
}
@@ -69,6 +78,16 @@ func (a *ConfigAttribute) SetS(cfg *Config, v string) error {
6978
return err
7079
}
7180
return a.Set(cfg, vv)
81+
case reflect.Slice:
82+
rawParts := strings.Split(v, ",")
83+
var parts []string
84+
for _, part := range rawParts {
85+
trimmed := strings.TrimSpace(part)
86+
if trimmed != "" {
87+
parts = append(parts, trimmed)
88+
}
89+
}
90+
return a.Set(cfg, parts)
7291
default:
7392
return fmt.Errorf("cannot set %s of unknown type %s",
7493
a.Name, reflectKind(a.Kind))
@@ -85,6 +104,8 @@ func (a *ConfigAttribute) Set(cfg *Config, i interface{}) error {
85104
field.SetBool(i.(bool))
86105
case reflect.Int:
87106
field.SetInt(int64(i.(int)))
107+
case reflect.Slice:
108+
field.Set(reflect.ValueOf(i.([]string)))
88109
default:
89110
// must extensively test with providerFixture to avoid this one
90111
return fmt.Errorf("cannot set %s of unknown type %s", a.Name, reflectKind(a.Kind))

config/config_attributes.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package config
22

33
import (
44
"fmt"
5-
"os"
65
"reflect"
76
"sort"
87
"strings"
@@ -31,7 +30,7 @@ func (a attributes) DebugString(cfg *Config) string {
3130
}
3231
attrsUsed = append(attrsUsed, fmt.Sprintf("%s=%s", attr.Name, v))
3332
for _, envName := range attr.EnvVars {
34-
v := os.Getenv(envName)
33+
v := getenv(envName)
3534
if v == "" {
3635
continue
3736
}

config/config_attributes_test.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package config
2+
3+
import (
4+
"testing"
5+
6+
"github.com/google/go-cmp/cmp"
7+
)
8+
9+
// TestConfigFile_Configure_ListParsing tests that comma-separated list values
10+
// in configuration files are correctly parsed into slices.
11+
func TestConfigFile_Configure_ListParsing(t *testing.T) {
12+
testCases := []struct {
13+
name string
14+
profile string
15+
want []string
16+
}{
17+
{
18+
name: "single item",
19+
profile: "single-item",
20+
want: []string{"clusters"},
21+
},
22+
{
23+
name: "multiple items",
24+
profile: "multiple-items",
25+
want: []string{"alpha", "beta", "gamma"},
26+
},
27+
{
28+
name: "whitespace around items is trimmed",
29+
profile: "whitespace-around-items",
30+
want: []string{"alpha", "beta", "gamma"},
31+
},
32+
{
33+
name: "empty items are filtered out",
34+
profile: "empty-items-filtered",
35+
want: []string{"alpha", "beta"},
36+
},
37+
{
38+
name: "whitespace-only items are filtered out",
39+
profile: "whitespace-only-items-filtered",
40+
want: []string{"alpha", "beta"},
41+
},
42+
}
43+
44+
for _, tc := range testCases {
45+
t.Run(tc.name, func(t *testing.T) {
46+
withMockEnv(t, map[string]string{})
47+
48+
cfg := &Config{
49+
Profile: tc.profile,
50+
ConfigFile: "testdata/list-parsing/.databrickscfg",
51+
}
52+
err := ConfigFile.Configure(cfg)
53+
if err != nil {
54+
t.Fatalf("Configure failed: %v", err)
55+
}
56+
if diff := cmp.Diff(tc.want, cfg.Scopes); diff != "" {
57+
t.Errorf("list mismatch (-want +got):\n%s", diff)
58+
}
59+
})
60+
}
61+
}

config/config_file.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ func LoadFile(path string) (*File, error) {
3838

3939
// Expand ~ to home directory.
4040
if strings.HasPrefix(path, "~") {
41-
homedir, err := os.UserHomeDir()
41+
homedir, err := getUserHomeDir()
4242
if err != nil {
4343
return nil, fmt.Errorf("cannot find homedir: %w", err)
4444
}

config/config_file_test.go

Lines changed: 77 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,93 @@ package config
33
import (
44
"testing"
55

6-
"github.com/stretchr/testify/assert"
7-
"github.com/stretchr/testify/require"
6+
"github.com/google/go-cmp/cmp"
87
)
98

9+
// withMockEnv mocks environment variables for testing config file loading
10+
// without relying on the actual system environment or filesystem.
11+
// getUserHomeDir falls back to the real implementation when HOME is not in
12+
// the env map, allowing tests to optionally override HOME without breaking
13+
// tests that don't need to control the home directory path.
14+
func withMockEnv(t *testing.T, env map[string]string) {
15+
original := getenv
16+
originalUserHomeDir := getUserHomeDir
17+
t.Cleanup(func() {
18+
getenv = original
19+
getUserHomeDir = originalUserHomeDir
20+
})
21+
getenv = func(key string) string {
22+
return env[key]
23+
}
24+
getUserHomeDir = func() (string, error) {
25+
if home, ok := env["HOME"]; ok {
26+
return home, nil
27+
}
28+
return originalUserHomeDir()
29+
}
30+
}
31+
1032
func TestConfigFileLoad(t *testing.T) {
1133
f, err := LoadFile("testdata/.databrickscfg")
12-
require.NoError(t, err)
13-
assert.NotNil(t, f)
34+
if err != nil {
35+
t.Fatalf("LoadFile failed: %v", err)
36+
}
37+
if f == nil {
38+
t.Fatal("expected file to be non-nil")
39+
}
1440

1541
for _, name := range []string{
1642
"password-with-double-quotes",
1743
"password-with-single-quotes",
1844
"password-without-quotes",
1945
} {
2046
section := f.Section(name)
21-
require.NotNil(t, section)
22-
assert.Equal(t, "%Y#X$Z", section.Key("password").String())
47+
if section == nil {
48+
t.Fatalf("expected section %q to be non-nil", name)
49+
}
50+
if got, want := section.Key("password").String(), "%Y#X$Z"; got != want {
51+
t.Errorf("password mismatch for %q: got %q, want %q", name, got, want)
52+
}
53+
}
54+
}
55+
56+
func TestConfigFile_Scopes(t *testing.T) {
57+
tests := []struct {
58+
name string
59+
profile string
60+
want []string
61+
}{
62+
{
63+
name: "empty defaults to all-apis",
64+
profile: "scope-empty",
65+
want: []string{"all-apis"},
66+
},
67+
{
68+
name: "single scope",
69+
profile: "scope-single",
70+
want: []string{"clusters"},
71+
},
72+
{
73+
name: "multiple scopes sorted",
74+
profile: "scope-multiple",
75+
want: []string{"clusters", "files:read", "iam:read", "jobs", "mlflow", "model-serving", "pipelines"},
76+
},
77+
}
78+
79+
for _, tt := range tests {
80+
t.Run(tt.name, func(t *testing.T) {
81+
withMockEnv(t, map[string]string{
82+
"HOME": "testdata/scopes",
83+
})
84+
85+
cfg := &Config{Profile: tt.profile}
86+
err := cfg.EnsureResolved()
87+
if err != nil {
88+
t.Fatalf("EnsureResolved failed: %v", err)
89+
}
90+
if diff := cmp.Diff(tt.want, cfg.GetScopes()); diff != "" {
91+
t.Errorf("GetScopes mismatch (-want +got):\n%s", diff)
92+
}
93+
})
2394
}
2495
}

config/config_test.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,22 @@ import (
77

88
"github.com/databricks/databricks-sdk-go/credentials/u2m"
99
"github.com/databricks/databricks-sdk-go/httpclient/fixtures"
10+
"github.com/google/go-cmp/cmp"
1011
"github.com/stretchr/testify/assert"
1112
"github.com/stretchr/testify/require"
1213
)
1314

15+
// mockLoader is a test helper that implements the Loader interface.
16+
type mockLoader func(cfg *Config) error
17+
18+
func (m mockLoader) Name() string {
19+
return "mockLoader"
20+
}
21+
22+
func (m mockLoader) Configure(cfg *Config) error {
23+
return m(cfg)
24+
}
25+
1426
func TestHostType_AwsAccount(t *testing.T) {
1527
c := &Config{
1628
Host: "https://accounts.cloud.databricks.com",
@@ -299,3 +311,68 @@ func TestConfig_getOAuthArgument_Unified(t *testing.T) {
299311
})
300312
}
301313
}
314+
315+
func TestConfig_EnsureResolved_scopeNormalization(t *testing.T) {
316+
testCases := []struct {
317+
desc string
318+
scopes []string
319+
want []string
320+
}{
321+
{
322+
desc: "nil scopes",
323+
scopes: nil,
324+
want: nil,
325+
},
326+
{
327+
desc: "empty scopes",
328+
scopes: []string{},
329+
want: []string{},
330+
},
331+
{
332+
desc: "single scope",
333+
scopes: []string{"clusters"},
334+
want: []string{"clusters"},
335+
},
336+
{
337+
desc: "already sorted no duplicates",
338+
scopes: []string{"a", "b", "c"},
339+
want: []string{"a", "b", "c"},
340+
},
341+
{
342+
desc: "unsorted scopes are sorted",
343+
scopes: []string{"jobs", "clusters", "pipelines"},
344+
want: []string{"clusters", "jobs", "pipelines"},
345+
},
346+
{
347+
desc: "duplicate scopes are removed",
348+
scopes: []string{"clusters", "jobs", "clusters", "pipelines:read", "jobs"},
349+
want: []string{"clusters", "jobs", "pipelines:read"},
350+
},
351+
{
352+
desc: "all duplicates reduced to one",
353+
scopes: []string{"all-apis", "all-apis", "all-apis"},
354+
want: []string{"all-apis"},
355+
},
356+
}
357+
358+
for _, tc := range testCases {
359+
t.Run(tc.desc, func(t *testing.T) {
360+
cfg := &Config{
361+
Host: "https://example.cloud.databricks.com",
362+
Loaders: []Loader{mockLoader(func(cfg *Config) error {
363+
cfg.Scopes = tc.scopes
364+
return nil
365+
})},
366+
}
367+
368+
err := cfg.EnsureResolved()
369+
if err != nil {
370+
t.Fatalf("EnsureResolved() error: %v", err)
371+
}
372+
373+
if diff := cmp.Diff(tc.want, cfg.Scopes); diff != "" {
374+
t.Errorf("EnsureResolved() scopes mismatch (-want +got):\n%s", diff)
375+
}
376+
})
377+
}
378+
}

0 commit comments

Comments
 (0)