Skip to content

Commit 44628c5

Browse files
committed
add tests
1 parent e0c224d commit 44628c5

7 files changed

+619
-28
lines changed

auth/auth_test.go

+302
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
package auth
2+
3+
import (
4+
"errors"
5+
"testing"
6+
"time"
7+
)
8+
9+
type mockStreamingProvider struct {
10+
credentials Credentials
11+
err error
12+
updates chan Credentials
13+
}
14+
15+
func newMockStreamingProvider(initialCreds Credentials) *mockStreamingProvider {
16+
return &mockStreamingProvider{
17+
credentials: initialCreds,
18+
updates: make(chan Credentials, 10),
19+
}
20+
}
21+
22+
func (m *mockStreamingProvider) Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error) {
23+
if m.err != nil {
24+
return nil, nil, m.err
25+
}
26+
27+
// Send initial credentials
28+
listener.OnNext(m.credentials)
29+
30+
// Start goroutine to handle updates
31+
go func() {
32+
for creds := range m.updates {
33+
listener.OnNext(creds)
34+
}
35+
}()
36+
37+
return m.credentials, func() error {
38+
close(m.updates)
39+
return nil
40+
}, nil
41+
}
42+
43+
func TestStreamingCredentialsProvider(t *testing.T) {
44+
t.Run("successful subscription", func(t *testing.T) {
45+
initialCreds := NewBasicCredentials("user1", "pass1")
46+
provider := newMockStreamingProvider(initialCreds)
47+
48+
var receivedCreds []Credentials
49+
var receivedErrors []error
50+
51+
listener := NewReAuthCredentialsListener(
52+
func(creds Credentials) error {
53+
receivedCreds = append(receivedCreds, creds)
54+
return nil
55+
},
56+
func(err error) {
57+
receivedErrors = append(receivedErrors, err)
58+
},
59+
)
60+
61+
creds, cancel, err := provider.Subscribe(listener)
62+
if err != nil {
63+
t.Fatalf("unexpected error: %v", err)
64+
}
65+
if cancel == nil {
66+
t.Fatal("expected cancel function to be non-nil")
67+
}
68+
if creds != initialCreds {
69+
t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
70+
}
71+
if len(receivedCreds) != 1 {
72+
t.Fatalf("expected 1 received credential, got %d", len(receivedCreds))
73+
}
74+
if receivedCreds[0] != initialCreds {
75+
t.Fatalf("expected received credential %v, got %v", initialCreds, receivedCreds[0])
76+
}
77+
if len(receivedErrors) != 0 {
78+
t.Fatalf("expected no errors, got %d", len(receivedErrors))
79+
}
80+
81+
// Send an update
82+
newCreds := NewBasicCredentials("user2", "pass2")
83+
provider.updates <- newCreds
84+
85+
// Wait for update to be processed
86+
time.Sleep(100 * time.Millisecond)
87+
if len(receivedCreds) != 2 {
88+
t.Fatalf("expected 2 received credentials, got %d", len(receivedCreds))
89+
}
90+
if receivedCreds[1] != newCreds {
91+
t.Fatalf("expected received credential %v, got %v", newCreds, receivedCreds[1])
92+
}
93+
94+
// Cancel subscription
95+
if err := cancel(); err != nil {
96+
t.Fatalf("unexpected error cancelling subscription: %v", err)
97+
}
98+
})
99+
100+
t.Run("subscription error", func(t *testing.T) {
101+
provider := &mockStreamingProvider{
102+
err: errors.New("subscription failed"),
103+
}
104+
105+
var receivedCreds []Credentials
106+
var receivedErrors []error
107+
108+
listener := NewReAuthCredentialsListener(
109+
func(creds Credentials) error {
110+
receivedCreds = append(receivedCreds, creds)
111+
return nil
112+
},
113+
func(err error) {
114+
receivedErrors = append(receivedErrors, err)
115+
},
116+
)
117+
118+
creds, cancel, err := provider.Subscribe(listener)
119+
if err == nil {
120+
t.Fatal("expected error, got nil")
121+
}
122+
if cancel != nil {
123+
t.Fatal("expected cancel function to be nil")
124+
}
125+
if creds != nil {
126+
t.Fatalf("expected nil credentials, got %v", creds)
127+
}
128+
if len(receivedCreds) != 0 {
129+
t.Fatalf("expected no received credentials, got %d", len(receivedCreds))
130+
}
131+
if len(receivedErrors) != 0 {
132+
t.Fatalf("expected no errors, got %d", len(receivedErrors))
133+
}
134+
})
135+
136+
t.Run("re-auth error", func(t *testing.T) {
137+
initialCreds := NewBasicCredentials("user1", "pass1")
138+
provider := newMockStreamingProvider(initialCreds)
139+
140+
reauthErr := errors.New("re-auth failed")
141+
var receivedErrors []error
142+
143+
listener := NewReAuthCredentialsListener(
144+
func(creds Credentials) error {
145+
return reauthErr
146+
},
147+
func(err error) {
148+
receivedErrors = append(receivedErrors, err)
149+
},
150+
)
151+
152+
creds, cancel, err := provider.Subscribe(listener)
153+
if err != nil {
154+
t.Fatalf("unexpected error: %v", err)
155+
}
156+
if cancel == nil {
157+
t.Fatal("expected cancel function to be non-nil")
158+
}
159+
if creds != initialCreds {
160+
t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
161+
}
162+
if len(receivedErrors) != 1 {
163+
t.Fatalf("expected 1 error, got %d", len(receivedErrors))
164+
}
165+
if receivedErrors[0] != reauthErr {
166+
t.Fatalf("expected error %v, got %v", reauthErr, receivedErrors[0])
167+
}
168+
169+
if err := cancel(); err != nil {
170+
t.Fatalf("unexpected error cancelling subscription: %v", err)
171+
}
172+
})
173+
}
174+
175+
func TestBasicCredentials(t *testing.T) {
176+
t.Run("basic auth", func(t *testing.T) {
177+
creds := NewBasicCredentials("user1", "pass1")
178+
username, password := creds.BasicAuth()
179+
if username != "user1" {
180+
t.Fatalf("expected username 'user1', got '%s'", username)
181+
}
182+
if password != "pass1" {
183+
t.Fatalf("expected password 'pass1', got '%s'", password)
184+
}
185+
})
186+
187+
t.Run("raw credentials", func(t *testing.T) {
188+
creds := NewBasicCredentials("user1", "pass1")
189+
raw := creds.RawCredentials()
190+
expected := "user1:pass1"
191+
if raw != expected {
192+
t.Fatalf("expected raw credentials '%s', got '%s'", expected, raw)
193+
}
194+
})
195+
196+
t.Run("empty username", func(t *testing.T) {
197+
creds := NewBasicCredentials("", "pass1")
198+
username, password := creds.BasicAuth()
199+
if username != "" {
200+
t.Fatalf("expected empty username, got '%s'", username)
201+
}
202+
if password != "pass1" {
203+
t.Fatalf("expected password 'pass1', got '%s'", password)
204+
}
205+
})
206+
}
207+
208+
func TestReAuthCredentialsListener(t *testing.T) {
209+
t.Run("successful re-auth", func(t *testing.T) {
210+
var reAuthCalled bool
211+
var onErrCalled bool
212+
var receivedCreds Credentials
213+
214+
listener := NewReAuthCredentialsListener(
215+
func(creds Credentials) error {
216+
reAuthCalled = true
217+
receivedCreds = creds
218+
return nil
219+
},
220+
func(err error) {
221+
onErrCalled = true
222+
},
223+
)
224+
225+
creds := NewBasicCredentials("user1", "pass1")
226+
listener.OnNext(creds)
227+
228+
if !reAuthCalled {
229+
t.Fatal("expected reAuth to be called")
230+
}
231+
if onErrCalled {
232+
t.Fatal("expected onErr not to be called")
233+
}
234+
if receivedCreds != creds {
235+
t.Fatalf("expected credentials %v, got %v", creds, receivedCreds)
236+
}
237+
})
238+
239+
t.Run("re-auth error", func(t *testing.T) {
240+
var reAuthCalled bool
241+
var onErrCalled bool
242+
var receivedErr error
243+
expectedErr := errors.New("re-auth failed")
244+
245+
listener := NewReAuthCredentialsListener(
246+
func(creds Credentials) error {
247+
reAuthCalled = true
248+
return expectedErr
249+
},
250+
func(err error) {
251+
onErrCalled = true
252+
receivedErr = err
253+
},
254+
)
255+
256+
creds := NewBasicCredentials("user1", "pass1")
257+
listener.OnNext(creds)
258+
259+
if !reAuthCalled {
260+
t.Fatal("expected reAuth to be called")
261+
}
262+
if !onErrCalled {
263+
t.Fatal("expected onErr to be called")
264+
}
265+
if receivedErr != expectedErr {
266+
t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
267+
}
268+
})
269+
270+
t.Run("on error", func(t *testing.T) {
271+
var onErrCalled bool
272+
var receivedErr error
273+
expectedErr := errors.New("provider error")
274+
275+
listener := NewReAuthCredentialsListener(
276+
func(creds Credentials) error {
277+
return nil
278+
},
279+
func(err error) {
280+
onErrCalled = true
281+
receivedErr = err
282+
},
283+
)
284+
285+
listener.OnError(expectedErr)
286+
287+
if !onErrCalled {
288+
t.Fatal("expected onErr to be called")
289+
}
290+
if receivedErr != expectedErr {
291+
t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
292+
}
293+
})
294+
295+
t.Run("nil callbacks", func(t *testing.T) {
296+
listener := NewReAuthCredentialsListener(nil, nil)
297+
298+
// Should not panic
299+
listener.OnNext(NewBasicCredentials("user1", "pass1"))
300+
listener.OnError(errors.New("test error"))
301+
})
302+
}

command_recorder_test.go

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package redis_test
2+
3+
import (
4+
"context"
5+
"strings"
6+
"sync"
7+
8+
"github.com/redis/go-redis/v9"
9+
)
10+
11+
// commandRecorder records the last N commands executed by a Redis client.
12+
type commandRecorder struct {
13+
mu sync.Mutex
14+
commands []string
15+
maxSize int
16+
}
17+
18+
// newCommandRecorder creates a new command recorder with the specified maximum size.
19+
func newCommandRecorder(maxSize int) *commandRecorder {
20+
return &commandRecorder{
21+
commands: make([]string, 0, maxSize),
22+
maxSize: maxSize,
23+
}
24+
}
25+
26+
// Record adds a command to the recorder.
27+
func (r *commandRecorder) Record(cmd string) {
28+
cmd = strings.ToLower(cmd)
29+
r.mu.Lock()
30+
defer r.mu.Unlock()
31+
32+
r.commands = append(r.commands, cmd)
33+
if len(r.commands) > r.maxSize {
34+
r.commands = r.commands[1:]
35+
}
36+
}
37+
38+
// LastCommands returns a copy of the recorded commands.
39+
func (r *commandRecorder) LastCommands() []string {
40+
r.mu.Lock()
41+
defer r.mu.Unlock()
42+
return append([]string(nil), r.commands...)
43+
}
44+
45+
// Contains checks if the recorder contains a specific command.
46+
func (r *commandRecorder) Contains(cmd string) bool {
47+
cmd = strings.ToLower(cmd)
48+
r.mu.Lock()
49+
defer r.mu.Unlock()
50+
for _, c := range r.commands {
51+
if strings.Contains(c, cmd) {
52+
return true
53+
}
54+
}
55+
return false
56+
}
57+
58+
// Hook returns a Redis hook that records commands.
59+
func (r *commandRecorder) Hook() redis.Hook {
60+
return &commandHook{recorder: r}
61+
}
62+
63+
// commandHook implements the redis.Hook interface to record commands.
64+
type commandHook struct {
65+
recorder *commandRecorder
66+
}
67+
68+
func (h *commandHook) DialHook(next redis.DialHook) redis.DialHook {
69+
return next
70+
}
71+
72+
func (h *commandHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook {
73+
return func(ctx context.Context, cmd redis.Cmder) error {
74+
h.recorder.Record(cmd.String())
75+
return next(ctx, cmd)
76+
}
77+
}
78+
79+
func (h *commandHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook {
80+
return func(ctx context.Context, cmds []redis.Cmder) error {
81+
for _, cmd := range cmds {
82+
h.recorder.Record(cmd.String())
83+
}
84+
return next(ctx, cmds)
85+
}
86+
}

internal/internal.go

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"github.com/redis/go-redis/v9/internal/rand"
77
)
88

9+
type ParentHooksMixinKey struct{}
10+
911
func RetryBackoff(retry int, minBackoff, maxBackoff time.Duration) time.Duration {
1012
if retry < 0 {
1113
panic("not reached")

0 commit comments

Comments
 (0)