Skip to content

Commit 11ae8fe

Browse files
Add Request.WithRetryFunc (#435)
1 parent 2c5f15c commit 11ae8fe

File tree

3 files changed

+189
-6
lines changed

3 files changed

+189
-6
lines changed

README.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ e.POST("/path").
395395
Expect().
396396
Status(http.StatusOK)
397397

398-
// custom retry policy
398+
// custom built-in retry policy
399399
e.POST("/path").
400400
WithMaxRetries(5).
401401
WithRetryPolicy(httpexpect.RetryAllErrors).
@@ -408,6 +408,15 @@ e.POST("/path").
408408
WithRetryDelay(time.Second, time.Minute).
409409
Expect().
410410
Status(http.StatusOK)
411+
412+
// custom user-defined retry policy
413+
e.POST("/path").
414+
WithMaxRetries(5).
415+
WithRetryPolicyFunc(func(resp *http.Response, err error) bool {
416+
return resp.StatusCode == http.StatusTeapot
417+
}).
418+
Expect().
419+
Status(http.StatusOK)
411420
```
412421

413422
##### Subdomains and per-request URL

request.go

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,13 @@ type Request struct {
3737
redirectPolicy RedirectPolicy
3838
maxRedirects int
3939

40-
retryPolicy RetryPolicy
41-
maxRetries int
42-
minRetryDelay time.Duration
43-
maxRetryDelay time.Duration
44-
sleepFn func(d time.Duration) <-chan time.Time
40+
retryPolicy RetryPolicy
41+
withRetryPolicyCalled bool
42+
maxRetries int
43+
minRetryDelay time.Duration
44+
maxRetryDelay time.Duration
45+
sleepFn func(d time.Duration) <-chan time.Time
46+
retryPolicyFn func(*http.Response, error) bool
4547

4648
timeout time.Duration
4749

@@ -756,7 +758,66 @@ func (r *Request) WithRetryPolicy(policy RetryPolicy) *Request {
756758
return r
757759
}
758760

761+
if r.retryPolicyFn != nil {
762+
opChain.fail(AssertionFailure{
763+
Type: AssertUsage,
764+
Errors: []error{
765+
fmt.Errorf("expected: " +
766+
"WithRetryPolicyFunc() and WithRetryPolicy() should be mutual exclusive, " +
767+
"WithRetryPolicyFunc() is already called"),
768+
},
769+
})
770+
return r
771+
}
772+
759773
r.retryPolicy = policy
774+
r.withRetryPolicyCalled = true
775+
776+
return r
777+
}
778+
779+
// WithRetryPolicyFunc sets a function to replace built-in policies
780+
// with user-defined policy.
781+
//
782+
// The function expects you to return true to perform a retry. And false to
783+
// not perform a retry.
784+
//
785+
// Example:
786+
//
787+
// req := NewRequestC(config, "POST", "/path")
788+
// req.WithRetryPolicyFunc(func(res *http.Response, err error) bool {
789+
// return resp.StatusCode == http.StatusTeapot
790+
// })
791+
func (r *Request) WithRetryPolicyFunc(
792+
fn func(res *http.Response, err error) bool,
793+
) *Request {
794+
opChain := r.chain.enter("WithRetryPolicyFunc()")
795+
defer opChain.leave()
796+
797+
r.mu.Lock()
798+
defer r.mu.Unlock()
799+
800+
if opChain.failed() {
801+
return r
802+
}
803+
804+
if !r.checkOrder(opChain, "WithRetryPolicyFunc()") {
805+
return r
806+
}
807+
808+
if r.withRetryPolicyCalled {
809+
opChain.fail(AssertionFailure{
810+
Type: AssertUsage,
811+
Errors: []error{
812+
fmt.Errorf("expected: " +
813+
"WithRetryPolicyFunc() and WithRetryPolicy() should be mutual exclusive, " +
814+
"WithRetryPolicy() is already called"),
815+
},
816+
})
817+
return r
818+
}
819+
820+
r.retryPolicyFn = fn
760821

761822
return r
762823
}
@@ -2339,6 +2400,10 @@ func (r *Request) retryRequest(reqFunc func() (*http.Response, error)) (
23392400
}
23402401

23412402
func (r *Request) shouldRetry(resp *http.Response, err error) bool {
2403+
if r.retryPolicyFn != nil {
2404+
return r.retryPolicyFn(resp, err)
2405+
}
2406+
23422407
var (
23432408
isTemporaryNetworkError bool // Deprecated
23442409
isTimeoutError bool

request_test.go

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3367,6 +3367,60 @@ func TestRequest_RetriesCancellation(t *testing.T) {
33673367
assert.Equal(t, 1, callCount)
33683368
}
33693369

3370+
func TestRequest_WithRetryPolicyFunc(t *testing.T) {
3371+
tests := []struct {
3372+
name string
3373+
fn func(res *http.Response, err error) bool
3374+
callCount int
3375+
}{
3376+
{
3377+
name: "should not retry",
3378+
fn: func(res *http.Response, err error) bool {
3379+
return false
3380+
},
3381+
callCount: 1,
3382+
},
3383+
{
3384+
name: "should retry",
3385+
fn: func(res *http.Response, err error) bool {
3386+
return true
3387+
},
3388+
callCount: 2,
3389+
},
3390+
}
3391+
3392+
for _, tt := range tests {
3393+
t.Run(tt.name, func(t *testing.T) {
3394+
callCount := 0
3395+
3396+
client := &mockClient{
3397+
resp: http.Response{
3398+
StatusCode: http.StatusTeapot,
3399+
},
3400+
cb: func(req *http.Request) {
3401+
callCount++
3402+
},
3403+
}
3404+
3405+
cfg := Config{
3406+
Client: client,
3407+
Reporter: newMockReporter(t),
3408+
}
3409+
3410+
req := NewRequestC(cfg, http.MethodGet, "/url").
3411+
WithMaxRetries(1).
3412+
WithRetryDelay(0, 0).
3413+
WithRetryPolicyFunc(tt.fn)
3414+
req.chain.assert(t, success)
3415+
3416+
resp := req.Expect()
3417+
resp.chain.assert(t, success)
3418+
3419+
assert.Equal(t, tt.callCount, callCount)
3420+
})
3421+
}
3422+
}
3423+
33703424
func TestRequest_Conflicts(t *testing.T) {
33713425
client := &mockClient{}
33723426

@@ -3510,6 +3564,44 @@ func TestRequest_Conflicts(t *testing.T) {
35103564
})
35113565
}
35123566
})
3567+
3568+
t.Run("retry policy conflict", func(t *testing.T) {
3569+
cases := []struct {
3570+
name string
3571+
fn func(req *Request)
3572+
}{
3573+
{
3574+
"WithRetryPolicyFunc",
3575+
func(req *Request) {
3576+
req.WithRetryPolicyFunc(func(res *http.Response, err error) bool {
3577+
return res.StatusCode == http.StatusTeapot
3578+
})
3579+
},
3580+
},
3581+
}
3582+
3583+
for _, tc := range cases {
3584+
t.Run(tc.name, func(t *testing.T) {
3585+
req := NewRequestC(config, "GET", "url")
3586+
3587+
tc.fn(req)
3588+
req.chain.assert(t, success)
3589+
3590+
req.WithRetryPolicy(RetryAllErrors)
3591+
req.chain.assert(t, failure)
3592+
})
3593+
3594+
t.Run(tc.name+" - reversed", func(t *testing.T) {
3595+
req := NewRequestC(config, "GET", "url")
3596+
3597+
req.WithRetryPolicy(RetryAllErrors)
3598+
req.chain.assert(t, success)
3599+
3600+
tc.fn(req)
3601+
req.chain.assert(t, failure)
3602+
})
3603+
}
3604+
})
35133605
}
35143606

35153607
func TestRequest_Usage(t *testing.T) {
@@ -3660,6 +3752,15 @@ func TestRequest_Usage(t *testing.T) {
36603752
prepFails: false,
36613753
expectFails: true,
36623754
},
3755+
{
3756+
name: "WithRetryPolicyFunc - nil argument",
3757+
client: &mockClient{},
3758+
prepFunc: func(req *Request) {
3759+
req.WithRetryPolicyFunc(nil)
3760+
},
3761+
prepFails: false,
3762+
expectFails: false,
3763+
},
36633764
}
36643765

36653766
for _, tc := range cases {
@@ -3952,6 +4053,14 @@ func TestRequest_Order(t *testing.T) {
39524053
req.WithMultipart()
39534054
},
39544055
},
4056+
{
4057+
name: "WithRetryPolicyFunc after Expect",
4058+
afterFunc: func(req *Request) {
4059+
req.WithRetryPolicyFunc(func(res *http.Response, err error) bool {
4060+
return res.StatusCode == http.StatusTeapot
4061+
})
4062+
},
4063+
},
39554064
}
39564065

39574066
for _, tc := range cases {

0 commit comments

Comments
 (0)