diff --git a/retry.go b/retry.go index 5338985..28499cb 100644 --- a/retry.go +++ b/retry.go @@ -153,6 +153,10 @@ func DoWithData[T any](retryableFunc RetryableFuncWithData[T], opts ...Option) ( return emptyT, err } + if errors.Is(err, context.Cause(config.context)) { + return emptyT, err + } + lastErr = err config.onRetry(n, err) @@ -184,7 +188,7 @@ func DoWithData[T any](retryableFunc RetryableFuncWithData[T], opts ...Option) ( errorLog = append(errorLog, unpackUnrecoverable(err)) - if !config.retryIf(err) { + if !config.retryIf(err) || errors.Is(err, context.Cause(config.context)) { break } diff --git a/retry_test.go b/retry_test.go index 1ee3739..e4b62ff 100644 --- a/retry_test.go +++ b/retry_test.go @@ -642,3 +642,42 @@ func TestIsRecoverable(t *testing.T) { err = fmt.Errorf("wrapping: %w", err) assert.False(t, IsRecoverable(err)) } + +func TestNoRetryIfCtxCanceledAtemptsZero(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + var called bool + err := Do( + func() error { + cancel() + return context.Cause(ctx) + }, + OnRetry(func(n uint, err error) { + called = true + }), + Context(ctx), + Attempts(0), + ) + assert.Equal(t, context.Canceled, err) + assert.False(t, called, "OnRetry was called after cancelation") +} + +func TestNoRetryIfCtxCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + var called bool + err := Do( + func() error { + cancel() + return context.Cause(ctx) + }, + OnRetry(func(n uint, err error) { + called = true + }), + Context(ctx), + Attempts(5), + LastErrorOnly(true), + ) + assert.Equal(t, context.Canceled, err) + assert.False(t, called, "OnRetry was called after cancelation") +}