Skip to content

Commit 86c8b95

Browse files
Custom scopes support in WIF authentication (#1389)
## Summary - 5-line change (in `auth_default.go` and `tokensource.go`) passing scopes through in WIF authentication instead of using "all-apis" ## Testing - Integration tests with mocking ensuring scopes are propogated and used for token exchange (`auth_default_test.go`). - Unit tests to ensure custom scopes propogated in different account and workspace configurations (`tokensource_test.go`). --- NO_CHANGELOG=true
1 parent c3088cb commit 86c8b95

File tree

4 files changed

+216
-1
lines changed

4 files changed

+216
-1
lines changed

config/auth_default.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ func oidcStrategy(cfg *Config, name string, ts oidc.IDTokenSource) CredentialsSt
149149
if cfg.HostType() != WorkspaceHost {
150150
oidcConfig.AccountID = cfg.AccountID
151151
}
152+
oidcConfig.SetScopes(cfg.GetScopes())
152153
tokenSource := oidc.NewDatabricksOIDCTokenSource(oidcConfig)
153154
return NewTokenSourceStrategy(name, tokenSource)
154155
}

config/auth_default_test.go

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,13 @@ package config
22

33
import (
44
"context"
5+
"encoding/json"
6+
"net/http"
7+
"net/http/httptest"
58
"strings"
69
"testing"
10+
11+
"github.com/databricks/databricks-sdk-go/credentials/u2m"
712
)
813

914
func TestDefaultCredentials_Configure(t *testing.T) {
@@ -47,3 +52,101 @@ func TestDefaultCredentials_Configure(t *testing.T) {
4752
})
4853
}
4954
}
55+
56+
func TestGithubOIDC_Scopes(t *testing.T) {
57+
const oidcTokenPath = "/oidc/v1/token"
58+
59+
tests := []struct {
60+
name string
61+
scopes []string
62+
want string
63+
}{
64+
{
65+
name: "nil scopes uses default",
66+
scopes: nil,
67+
want: "all-apis",
68+
},
69+
{
70+
name: "empty scopes uses default",
71+
scopes: []string{},
72+
want: "all-apis",
73+
},
74+
{
75+
name: "single scope",
76+
scopes: []string{"clusters"},
77+
want: "clusters",
78+
},
79+
{
80+
name: "multiple scopes are sorted",
81+
scopes: []string{"jobs", "clusters", "files:read"},
82+
want: "clusters files:read jobs",
83+
},
84+
}
85+
86+
for _, tt := range tests {
87+
t.Run(tt.name, func(t *testing.T) {
88+
// Mock GitHub server for OIDC token requests.
89+
githubServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
90+
w.Header().Set("Content-Type", "application/json")
91+
json.NewEncoder(w).Encode(map[string]string{"value": "github-id-token"})
92+
}))
93+
defer githubServer.Close()
94+
95+
// Mock Databricks server to verify the SDK passes the correct scopes.
96+
var databricksServer *httptest.Server
97+
databricksServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
98+
switch r.URL.Path {
99+
case "/oidc/.well-known/oauth-authorization-server":
100+
w.Header().Set("Content-Type", "application/json")
101+
json.NewEncoder(w).Encode(u2m.OAuthAuthorizationServer{
102+
AuthorizationEndpoint: "https://host.com/oidc/v1/authorize",
103+
TokenEndpoint: databricksServer.URL + oidcTokenPath,
104+
})
105+
106+
case oidcTokenPath:
107+
if err := r.ParseForm(); err != nil {
108+
t.Fatalf("Failed to parse form: %v", err)
109+
}
110+
// The scope assertion: verifies the SDK sends the correct scope parameter.
111+
if got := r.Form.Get("scope"); got != tt.want {
112+
t.Errorf("scope: got %q, want %q", got, tt.want)
113+
}
114+
w.Header().Set("Content-Type", "application/json")
115+
json.NewEncoder(w).Encode(map[string]interface{}{
116+
"token_type": "Bearer",
117+
"access_token": "databricks-access-token",
118+
"expires_in": 3600,
119+
})
120+
121+
default:
122+
t.Errorf("Unexpected request: %s %s", r.Method, r.URL.Path)
123+
http.Error(w, "Not found", http.StatusNotFound)
124+
}
125+
}))
126+
defer databricksServer.Close()
127+
128+
cfg := &Config{
129+
Host: databricksServer.URL,
130+
ClientID: "test-client-id",
131+
ActionsIDTokenRequestURL: githubServer.URL + "/github-token?version=1",
132+
ActionsIDTokenRequestToken: "github-request-token",
133+
TokenAudience: "databricks-test-audience",
134+
AuthType: "github-oidc",
135+
Scopes: tt.scopes,
136+
}
137+
138+
req, err := http.NewRequest("GET", databricksServer.URL+"/api/test", nil)
139+
if err != nil {
140+
t.Fatalf("http.NewRequest(): unexpected error: %v", err)
141+
}
142+
err = cfg.Authenticate(req)
143+
if err != nil {
144+
t.Fatalf("Authenticate(): unexpected error: %v", err)
145+
}
146+
wantAuthHeader := "Bearer databricks-access-token"
147+
if got := req.Header.Get("Authorization"); got != wantAuthHeader {
148+
t.Errorf("Authorization header: got %q, want %q", got, wantAuthHeader)
149+
}
150+
})
151+
}
152+
}

config/experimental/auth/oidc/tokensource.go

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,23 @@ type DatabricksOIDCTokenSourceConfig struct {
3939

4040
// IDTokenSource returns the IDToken to be used for the token exchange.
4141
IDTokenSource IDTokenSource
42+
43+
// scopes is the list of OAuth scopes to request.
44+
scopes []string
45+
}
46+
47+
// GetScopes returns the OAuth scopes to request. If no scopes have been set,
48+
// it returns the default scope "all-apis".
49+
func (c *DatabricksOIDCTokenSourceConfig) GetScopes() []string {
50+
if len(c.scopes) == 0 {
51+
return []string{"all-apis"}
52+
}
53+
return c.scopes
54+
}
55+
56+
// SetScopes sets the OAuth scopes to request.
57+
func (c *DatabricksOIDCTokenSourceConfig) SetScopes(scopes []string) {
58+
c.scopes = scopes
4259
}
4360

4461
// NewDatabricksOIDCTokenSource returns a new Databricks OIDC TokenSource.
@@ -77,11 +94,14 @@ func (w *databricksOIDCTokenSource) Token(ctx context.Context) (*oauth2.Token, e
7794
return nil, err
7895
}
7996

97+
// This nil check is to ensure backwards compatibility for users implementing their own
98+
// OIDC token source.
99+
scopes := w.cfg.GetScopes()
80100
c := &clientcredentials.Config{
81101
ClientID: w.cfg.ClientID,
82102
AuthStyle: oauth2.AuthStyleInParams,
83103
TokenURL: endpoints.TokenEndpoint,
84-
Scopes: []string{"all-apis"},
104+
Scopes: scopes,
85105
EndpointParams: url.Values{
86106
"subject_token_type": {"urn:ietf:params:oauth:token-type:jwt"},
87107
"subject_token": {idToken.Value},

config/experimental/auth/oidc/tokensource_test.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,3 +319,94 @@ func TestDatabricksOidcTokenSource(t *testing.T) {
319319
})
320320
}
321321
}
322+
323+
func TestWIF_Scopes(t *testing.T) {
324+
const (
325+
testClientID = "test-client-id"
326+
testIDToken = "test-id-token"
327+
testAccessToken = "test-access-token"
328+
testTokenPath = "/oidc/v1/token"
329+
testHost = "https://host.com"
330+
)
331+
332+
tests := []struct {
333+
name string
334+
scopes []string
335+
want string
336+
}{
337+
{
338+
name: "nil scopes uses default",
339+
scopes: nil,
340+
want: "all-apis",
341+
},
342+
{
343+
name: "empty scopes uses default",
344+
scopes: []string{},
345+
want: "all-apis",
346+
},
347+
{
348+
name: "single scope",
349+
scopes: []string{"dashboards"},
350+
want: "dashboards",
351+
},
352+
{
353+
name: "multiple scopes",
354+
scopes: []string{"jobs", "files:read", "mlflow"},
355+
want: "jobs files:read mlflow",
356+
},
357+
}
358+
359+
for _, tt := range tests {
360+
t.Run(tt.name, func(t *testing.T) {
361+
cfg := DatabricksOIDCTokenSourceConfig{
362+
ClientID: testClientID,
363+
Host: testHost,
364+
TokenEndpointProvider: func(ctx context.Context) (*u2m.OAuthAuthorizationServer, error) {
365+
return &u2m.OAuthAuthorizationServer{
366+
TokenEndpoint: testHost + testTokenPath,
367+
}, nil
368+
},
369+
Audience: "token-audience",
370+
IDTokenSource: IDTokenSourceFn(func(ctx context.Context, aud string) (*IDToken, error) {
371+
return &IDToken{Value: testIDToken}, nil
372+
}),
373+
scopes: tt.scopes,
374+
}
375+
376+
ts := NewDatabricksOIDCTokenSource(cfg)
377+
378+
// The scope assertion: verifies the token source sends the correct scope parameter.
379+
expectedRequest := url.Values{
380+
"client_id": {testClientID},
381+
"scope": {tt.want},
382+
"subject_token_type": {"urn:ietf:params:oauth:token-type:jwt"},
383+
"subject_token": {testIDToken},
384+
"grant_type": {"urn:ietf:params:oauth:grant-type:token-exchange"},
385+
}
386+
387+
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{
388+
Transport: fixtures.MappingTransport{
389+
"POST " + testTokenPath: {
390+
Status: http.StatusOK,
391+
ExpectedHeaders: map[string]string{
392+
"Content-Type": "application/x-www-form-urlencoded",
393+
},
394+
ExpectedRequest: expectedRequest,
395+
Response: map[string]string{
396+
"token_type": "Bearer",
397+
"access_token": testAccessToken,
398+
},
399+
},
400+
},
401+
})
402+
403+
token, err := ts.Token(ctx)
404+
if err != nil {
405+
t.Fatalf("Token(ctx): got error %q, want none", err)
406+
}
407+
if token.AccessToken != testAccessToken {
408+
t.Errorf("Token(ctx): got access token %q, want %q", token.AccessToken, testAccessToken)
409+
}
410+
})
411+
}
412+
}

0 commit comments

Comments
 (0)