@@ -71,8 +71,8 @@ const baseURLKey = "ast-base-url"
7171
7272const audienceClaimKey = "aud"
7373
74- var cachedAccessToken string
75- var cachedAccessTime time.Time
74+ var CachedAccessToken string
75+ var CachedAccessTime time.Time
7676var Domains = make (map [string ]struct {})
7777
7878func retryHTTPRequest (requestFunc func () (* http.Response , error ), retries int , baseDelayInMilliSec time.Duration ) (* http.Response , error ) {
@@ -85,7 +85,12 @@ func retryHTTPRequest(requestFunc func() (*http.Response, error), retries int, b
8585 if err != nil {
8686 return nil , err
8787 }
88- if resp .StatusCode != http .StatusBadGateway {
88+ if resp .StatusCode == http .StatusBadGateway {
89+ logger .PrintIfVerbose ("Bad Gateway (502), retrying" )
90+ } else if resp .StatusCode == http .StatusUnauthorized {
91+ logger .PrintIfVerbose ("Unauthorized request (401), refreshing token" )
92+ _ , _ = configureClientCredentialsAndGetNewToken ()
93+ } else {
8994 return resp , nil
9095 }
9196 _ = resp .Body .Close ()
@@ -398,27 +403,21 @@ func GetWithQueryParamsAndCustomRequest(client *http.Client, customReq *http.Req
398403 customReq = addReqMonitor (customReq )
399404 return request (client , customReq , true )
400405}
406+
401407func GetAccessToken () (string , error ) {
402- authURI , err := GetAuthURI ()
403- if err != nil {
404- return "" , err
405- }
408+ var err error
406409 tokenExpirySeconds := viper .GetInt (commonParams .TokenExpirySecondsKey )
410+
407411 accessToken := getClientCredentialsFromCache (tokenExpirySeconds )
408- accessKeyID := viper .GetString (commonParams .AccessKeyIDConfigKey )
409- accessKeySecret := viper .GetString (commonParams .AccessKeySecretConfigKey )
410- astAPIKey := viper .GetString (commonParams .AstAPIKey )
411- if accessKeyID == "" && astAPIKey == "" {
412- return "" , errors .Errorf (fmt .Sprintf (FailedToAuth , "access key ID" ))
413- } else if accessKeySecret == "" && astAPIKey == "" {
414- return "" , errors .Errorf (fmt .Sprintf (FailedToAuth , "access key secret" ))
415- }
412+
416413 if accessToken == "" {
417- accessToken , err = getClientCredentials (accessKeyID , accessKeySecret , astAPIKey , authURI )
414+ logger .PrintIfVerbose ("Fetching API access token." )
415+ accessToken , err = configureClientCredentialsAndGetNewToken ()
418416 if err != nil {
419417 return "" , err
420418 }
421419 }
420+
422421 return accessToken , nil
423422}
424423
@@ -445,38 +444,45 @@ func enrichWithPasswordCredentials(
445444 return nil
446445}
447446
448- func getClientCredentials (accessKeyID , accessKeySecret , astAPKey , authURI string ) (string , error ) {
449- logger .PrintIfVerbose ("Fetching API access token." )
450- tokenExpirySeconds := viper .GetInt (commonParams .TokenExpirySecondsKey )
447+ func configureClientCredentialsAndGetNewToken () (string , error ) {
448+ accessKeyID := viper .GetString (commonParams .AccessKeyIDConfigKey )
449+ accessKeySecret := viper .GetString (commonParams .AccessKeySecretConfigKey )
450+ astAPIKey := viper .GetString (commonParams .AstAPIKey )
451+ var accessToken string
451452
452- var err error
453- accessToken := getClientCredentialsFromCache (tokenExpirySeconds )
453+ if accessKeyID == "" && astAPIKey == "" {
454+ return "" , errors .Errorf (fmt .Sprintf (FailedToAuth , "access key ID" ))
455+ } else if accessKeySecret == "" && astAPIKey == "" {
456+ return "" , errors .Errorf (fmt .Sprintf (FailedToAuth , "access key secret" ))
457+ }
454458
455- if accessToken == "" {
456- // If the token is present the default to that.
457- if astAPKey != "" {
458- accessToken , err = getNewToken (getAPIKeyPayload (astAPKey ), authURI )
459- } else {
460- accessToken , err = getNewToken (getCredentialsPayload (accessKeyID , accessKeySecret ), authURI )
461- }
459+ authURI , err := GetAuthURI ()
460+ if err != nil {
461+ return "" , err
462+ }
462463
463- if err != nil {
464- return "" , errors .Errorf ("%s" , err )
465- }
464+ if astAPIKey != "" {
465+ accessToken , err = getNewToken (getAPIKeyPayload (astAPIKey ), authURI )
466+ } else {
467+ accessToken , err = getNewToken (getCredentialsPayload (accessKeyID , accessKeySecret ), authURI )
468+ }
466469
467- writeCredentialsToCache (accessToken )
470+ if err != nil {
471+ return "" , errors .Errorf ("%s" , err )
468472 }
469473
474+ writeCredentialsToCache (accessToken )
475+
470476 return accessToken , nil
471477}
472478
473479func getClientCredentialsFromCache (tokenExpirySeconds int ) string {
474480 logger .PrintIfVerbose ("Checking cache for API access token." )
475481
476- expired := time .Since (cachedAccessTime ) > time .Duration (tokenExpirySeconds - expiryGraceSeconds )* time .Second
482+ expired := time .Since (CachedAccessTime ) > time .Duration (tokenExpirySeconds - expiryGraceSeconds )* time .Second
477483 if ! expired {
478484 logger .PrintIfVerbose ("Using cached API access token!" )
479- return cachedAccessToken
485+ return CachedAccessToken
480486 }
481487 logger .PrintIfVerbose ("API access token not found in cache!" )
482488 return ""
@@ -488,8 +494,8 @@ func writeCredentialsToCache(accessToken string) {
488494
489495 logger .PrintIfVerbose ("Storing API access token to cache." )
490496 viper .Set (commonParams .AstToken , accessToken )
491- cachedAccessToken = accessToken
492- cachedAccessTime = time .Now ()
497+ CachedAccessToken = accessToken
498+ CachedAccessTime = time .Now ()
493499}
494500
495501func getNewToken (credentialsPayload , authServerURI string ) (string , error ) {
@@ -555,13 +561,13 @@ func getCredentialsPayload(accessKeyID, accessKeySecret string) string {
555561
556562func getAPIKeyPayload (astToken string ) string {
557563 logger .PrintIfVerbose ("Using API key credentials." )
558-
564+
559565 clientID , err := extractAZPFromToken (astToken )
560566 if err != nil {
561567 logger .PrintIfVerbose ("Failed to extract azp from token, using default client_id" )
562568 clientID = "ast-app"
563569 }
564-
570+
565571 return fmt .Sprintf ("grant_type=refresh_token&client_id=%s&refresh_token=%s" , clientID , astToken )
566572}
567573
0 commit comments