1
+ package config
2
+
3
+ import (
4
+ "context"
5
+ "errors"
6
+ "net/http"
7
+ "strings"
8
+ "testing"
9
+ "time"
10
+ )
11
+
12
+ func TestDoWithBackoff (t * testing.T ) {
13
+ tests := []struct {
14
+ name string
15
+ maxRetries int
16
+ ctxTimeout time.Duration
17
+ handler func (req * http.Request ) (* http.Response , error )
18
+ expectErr string
19
+ expectCalls int
20
+ expectSuccess bool
21
+ }{
22
+ {
23
+ name : "success on first try" ,
24
+ maxRetries : 3 ,
25
+ handler : func (req * http.Request ) (* http.Response , error ) {
26
+ return & http.Response {StatusCode : 200 , Body : http .NoBody }, nil
27
+ },
28
+ expectErr : "" ,
29
+ expectCalls : 1 ,
30
+ expectSuccess : true ,
31
+ },
32
+ {
33
+ name : "max retries exceeded" ,
34
+ maxRetries : 2 ,
35
+ handler : func (req * http.Request ) (* http.Response , error ) {
36
+ return nil , errors .New ("mock error" )
37
+ },
38
+ expectErr : "max retries exceeded" ,
39
+ expectCalls : 3 ,
40
+ },
41
+ {
42
+ name : "context cancelled before success" ,
43
+ maxRetries : 0 ,
44
+ ctxTimeout : 50 * time .Millisecond ,
45
+ handler : func (req * http.Request ) (* http.Response , error ) {
46
+ return nil , errors .New ("fail" )
47
+ },
48
+ expectErr : "context deadline exceeded" ,
49
+ expectCalls : - 1 ,
50
+ },
51
+ }
52
+
53
+ for _ , tt := range tests {
54
+ t .Run (tt .name , func (t * testing.T ) {
55
+ mock := & mockRoundTripper {handler : tt .handler }
56
+ client := & http.Client {Transport : mock }
57
+ req , _ := http .NewRequest ("GET" , "http://example.com" , nil )
58
+
59
+ ctx := context .Background ()
60
+ if tt .ctxTimeout > 0 {
61
+ var cancel context.CancelFunc
62
+ ctx , cancel = context .WithTimeout (ctx , tt .ctxTimeout )
63
+ defer cancel ()
64
+ }
65
+
66
+ resp , err := DoWithBackoff (ctx , client , req , tt .maxRetries )
67
+
68
+ if tt .expectErr == "" && err != nil {
69
+ t .Fatalf ("expected success, got error: %v" , err )
70
+ }
71
+ if tt .expectErr != "" {
72
+ if err == nil || ! strings .Contains (err .Error (), tt .expectErr ) {
73
+ t .Fatalf ("expected error containing %q, got %v" , tt .expectErr , err )
74
+ }
75
+ }
76
+ if tt .expectSuccess && resp == nil {
77
+ t .Fatalf ("expected response, got nil" )
78
+ }
79
+
80
+ if tt .expectCalls >= 0 && mock .calls != tt .expectCalls {
81
+ t .Errorf ("expected %d calls, got %d" , tt .expectCalls , mock .calls )
82
+ }
83
+ })
84
+ }
85
+ }
0 commit comments