From 857bb0dfd2c2945bbbbd45f483960a690fc65732 Mon Sep 17 00:00:00 2001
From: rjcorwin
Date: Fri, 30 May 2025 15:03:16 -0400
Subject: [PATCH] Add dependency injection to ChatCompletionStream for improved testability

**Describe the change**
This PR refactors `ChatCompletionStream` to use dependency injection by introducing a `ChatStreamReader` interface. Custom stream readers can now be injected, primarily so that code consuming streams can be tested without a live connection to the API.

**Provide OpenAI documentation link**
https://platform.openai.com/docs/api-reference/chat/create

**Describe your solution**
The changes include:
- Added a `ChatStreamReader` interface that defines the contract for reading chat completion streams
- Refactored `ChatCompletionStream` to use composition with a `ChatStreamReader` instead of embedding `streamReader`
- Added a `NewChatCompletionStream()` constructor function to enable dependency injection
- Implemented explicit delegation methods (`Recv()`, `Close()`, `Header()`, `GetRateLimitHeaders()`) on `ChatCompletionStream`, replacing the methods previously promoted from the embedded `streamReader`
- Added an interface compliance check via `var _ ChatStreamReader = (*streamReader[ChatCompletionStreamResponse])(nil)`

This approach keeps the existing `CreateChatCompletionStream` call path intact while enabling easier mocking and testing of streaming functionality.

**Tests**
Added tests demonstrating the new functionality:
- `TestChatCompletionStream_MockInjection`: tests basic mock injection with the new constructor
- `mock_streaming_demo_test.go`: a complete demonstration file showing how to create mock clients and stream readers for testing, including:
  - `MockOpenAIStreamClient`: mock client implementation
  - `mockStreamReader`: custom stream reader for controlled test responses
  - `TestMockOpenAIStreamClient_Demo`: demonstrates assembling multiple stream chunks
  - `TestMockOpenAIStreamClient_ErrorHandling`: shows error handling patterns

**Additional context**
This refactoring improves the testability of code that depends on go-openai streaming without changing how streams are created or consumed in the common case. `Recv()` and `Close()` behave exactly as before; `Header()` and `GetRateLimitHeaders()` degrade gracefully (returning empty values) when the injected reader does not expose HTTP headers. The new demo test file doubles as documentation for users who want to mock streaming responses in their own tests. A follow-up lint fix is included.
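**Example usage (illustrative)**
A minimal sketch of how a downstream test can take advantage of this, using only the API added in this PR. `fakeReader`, `TestConsumer`, and the package name are illustrative, not part of the change:

```go
package myapp_test

import (
	"errors"
	"io"
	"testing"

	"github.com/sashabaranov/go-openai"
)

// fakeReader is a hypothetical test double that satisfies openai.ChatStreamReader.
type fakeReader struct{ sent bool }

func (f *fakeReader) Recv() (openai.ChatCompletionStreamResponse, error) {
	if f.sent {
		// Signal end-of-stream the same way the real reader does.
		return openai.ChatCompletionStreamResponse{}, io.EOF
	}
	f.sent = true
	return openai.ChatCompletionStreamResponse{ID: "fake-1"}, nil
}

func (f *fakeReader) Close() error { return nil }

func TestConsumer(t *testing.T) {
	// The code under test receives a fully functional *openai.ChatCompletionStream.
	stream := openai.NewChatCompletionStream(&fakeReader{})
	defer stream.Close()

	resp, err := stream.Recv()
	if err != nil || resp.ID != "fake-1" {
		t.Fatalf("unexpected resp %+v, err %v", resp, err)
	}
	if _, err := stream.Recv(); !errors.Is(err, io.EOF) {
		t.Fatalf("expected io.EOF, got %v", err)
	}
}
```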
---
 chat_stream.go              |  45 +++++++++-
 chat_stream_test.go         |  28 +++++
 mock_streaming_demo_test.go | 201 ++++++++++++++++++++++++++++++++++++
 stream_reader.go            |   2 +
 4 files changed, 274 insertions(+), 2 deletions(-)
 create mode 100644 mock_streaming_demo_test.go

diff --git a/chat_stream.go b/chat_stream.go
index 80d16cc63..7b0bc40c2 100644
--- a/chat_stream.go
+++ b/chat_stream.go
@@ -65,10 +65,21 @@ type ChatCompletionStreamResponse struct {
 	Usage *Usage `json:"usage,omitempty"`
 }
 
+// ChatStreamReader is an interface for reading chat completion streams.
+type ChatStreamReader interface {
+	Recv() (ChatCompletionStreamResponse, error)
+	Close() error
+}
+
 // ChatCompletionStream
 // Note: Perhaps it is more elegant to abstract Stream using generics.
 type ChatCompletionStream struct {
-	*streamReader[ChatCompletionStreamResponse]
+	reader ChatStreamReader
+}
+
+// NewChatCompletionStream allows injecting a custom ChatStreamReader (for testing).
+func NewChatCompletionStream(reader ChatStreamReader) *ChatCompletionStream {
+	return &ChatCompletionStream{reader: reader}
 }
 
 // CreateChatCompletionStream — API call to create a chat completion w/ streaming
@@ -106,7 +117,37 @@ func (c *Client) CreateChatCompletionStream(
 		return
 	}
 	stream = &ChatCompletionStream{
-		streamReader: resp,
+		reader: resp,
 	}
 	return
 }
+
+func (s *ChatCompletionStream) Recv() (ChatCompletionStreamResponse, error) {
+	return s.reader.Recv()
+}
+
+func (s *ChatCompletionStream) Close() error {
+	return s.reader.Close()
+}
+
+func (s *ChatCompletionStream) Header() http.Header {
+	if h, ok := s.reader.(interface{ Header() http.Header }); ok {
+		return h.Header()
+	}
+	return http.Header{}
+}
+
+func (s *ChatCompletionStream) GetRateLimitHeaders() map[string]interface{} {
+	if h, ok := s.reader.(interface{ GetRateLimitHeaders() RateLimitHeaders }); ok {
+		headers := h.GetRateLimitHeaders()
+		return map[string]interface{}{
+			"x-ratelimit-limit-requests":     headers.LimitRequests,
+			"x-ratelimit-limit-tokens":       headers.LimitTokens,
+			"x-ratelimit-remaining-requests": headers.RemainingRequests,
+			"x-ratelimit-remaining-tokens":   headers.RemainingTokens,
+			"x-ratelimit-reset-requests":     headers.ResetRequests.String(),
+			"x-ratelimit-reset-tokens":       headers.ResetTokens.String(),
+		}
+	}
+	return map[string]interface{}{}
+}
diff --git a/chat_stream_test.go b/chat_stream_test.go
index eabb0f3a2..65d92a702 100644
--- a/chat_stream_test.go
+++ b/chat_stream_test.go
@@ -767,6 +767,34 @@ func TestCreateChatCompletionStreamStreamOptions(t *testing.T) {
 	}
 }
 
+type mockStream struct {
+	calls int
+}
+
+// Implement ChatStreamReader.
+func (m *mockStream) Recv() (openai.ChatCompletionStreamResponse, error) {
+	m.calls++
+	if m.calls == 1 {
+		return openai.ChatCompletionStreamResponse{ID: "mock1"}, nil
+	}
+	return openai.ChatCompletionStreamResponse{}, io.EOF
+}
+func (m *mockStream) Close() error { return nil }
+
+func TestChatCompletionStream_MockInjection(t *testing.T) {
+	mock := &mockStream{}
+	stream := openai.NewChatCompletionStream(mock)
+
+	resp, err := stream.Recv()
+	if err != nil || resp.ID != "mock1" {
+		t.Errorf("expected mock1, got %v, err %v", resp.ID, err)
+	}
+	_, err = stream.Recv()
+	if !errors.Is(err, io.EOF) {
+		t.Errorf("expected EOF, got %v", err)
+	}
+}
+
 // Helper funcs.
 func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool {
 	if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
diff --git a/mock_streaming_demo_test.go b/mock_streaming_demo_test.go
new file mode 100644
index 000000000..d235766f2
--- /dev/null
+++ b/mock_streaming_demo_test.go
@@ -0,0 +1,201 @@
+package openai_test
+
+import (
+	"context"
+	"errors"
+	"io"
+	"testing"
+
+	"github.com/sashabaranov/go-openai"
+)
+
+// This file demonstrates how to create mock clients for go-openai streaming
+// functionality. This pattern is useful when testing code that depends on
+// go-openai streaming but needs to control the responses it sees.
+
+// MockOpenAIStreamClient demonstrates how to create a mock client for go-openai.
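+// It deliberately implements only the method these tests exercise
+// (CreateChatCompletionStream); extend it if your code uses more of the client.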
+type MockOpenAIStreamClient struct {
+	// Configure canned responses
+	ChatCompletionResponse  openai.ChatCompletionResponse
+	ChatCompletionStreamErr error
+
+	// Allow function overrides for more complex scenarios
+	CreateChatCompletionStreamFn func(
+		ctx context.Context, req openai.ChatCompletionRequest) (*openai.ChatCompletionStream, error)
+}
+
+func (m *MockOpenAIStreamClient) CreateChatCompletionStream(
+	ctx context.Context,
+	req openai.ChatCompletionRequest,
+) (*openai.ChatCompletionStream, error) {
+	if m.CreateChatCompletionStreamFn != nil {
+		return m.CreateChatCompletionStreamFn(ctx, req)
+	}
+	return nil, m.ChatCompletionStreamErr
+}
+
+// mockStreamReader returns a scripted sequence of responses for testing.
+type mockStreamReader struct {
+	responses []openai.ChatCompletionStreamResponse
+	index     int
+}
+
+func (m *mockStreamReader) Recv() (openai.ChatCompletionStreamResponse, error) {
+	if m.index >= len(m.responses) {
+		return openai.ChatCompletionStreamResponse{}, io.EOF
+	}
+	resp := m.responses[m.index]
+	m.index++
+	return resp, nil
+}
+
+func (m *mockStreamReader) Close() error {
+	return nil
+}
+
+func TestMockOpenAIStreamClient_Demo(t *testing.T) {
+	// Create expected responses that our mock stream will return
+	expectedResponses := []openai.ChatCompletionStreamResponse{
+		{
+			ID:     "test-1",
+			Object: "chat.completion.chunk",
+			Model:  "gpt-3.5-turbo",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Index: 0,
+					Delta: openai.ChatCompletionStreamChoiceDelta{
+						Role:    "assistant",
+						Content: "Hello",
+					},
+				},
+			},
+		},
+		{
+			ID:     "test-2",
+			Object: "chat.completion.chunk",
+			Model:  "gpt-3.5-turbo",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Index: 0,
+					Delta: openai.ChatCompletionStreamChoiceDelta{
+						Content: " World",
+					},
+				},
+			},
+		},
+		{
+			ID:     "test-3",
+			Object: "chat.completion.chunk",
+			Model:  "gpt-3.5-turbo",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Index:        0,
+					Delta:        openai.ChatCompletionStreamChoiceDelta{},
+					FinishReason: "stop",
+				},
+			},
+		},
+	}
+
+	// Create mock client with custom stream function
+	mockClient := &MockOpenAIStreamClient{
+		CreateChatCompletionStreamFn: func(
+			_ context.Context, _ openai.ChatCompletionRequest,
+		) (*openai.ChatCompletionStream, error) {
+			// Create a mock stream reader with our expected responses
+			reader := &mockStreamReader{
+				responses: expectedResponses,
+				index:     0,
+			}
+			// Return a new ChatCompletionStream with our mock reader
+			return openai.NewChatCompletionStream(reader), nil
+		},
+	}
+
+	// Test the mock client
+	stream, err := mockClient.CreateChatCompletionStream(
+		context.Background(),
+		openai.ChatCompletionRequest{
+			Model: openai.GPT3Dot5Turbo,
+			Messages: []openai.ChatCompletionMessage{
+				{
+					Role:    openai.ChatMessageRoleUser,
+					Content: "Hello!",
+				},
+			},
+		},
+	)
+	if err != nil {
+		t.Fatalf("CreateChatCompletionStream returned error: %v", err)
+	}
+	defer stream.Close()
+
+	// Verify we get back exactly the responses we configured
+	fullResponse := ""
+	for i, expectedResponse := range expectedResponses {
+		receivedResponse, streamErr := stream.Recv()
+		if streamErr != nil {
+			t.Fatalf("stream.Recv() failed at index %d: %v", i, streamErr)
+		}
+
+		// Additional specific checks
+		if receivedResponse.ID != expectedResponse.ID {
+			t.Errorf("Response %d ID mismatch. Expected: %s, Got: %s",
Expected: %s, Got: %s", + i, expectedResponse.ID, receivedResponse.ID) + } + if len(receivedResponse.Choices) > 0 && len(expectedResponse.Choices) > 0 { + expectedContent := expectedResponse.Choices[0].Delta.Content + receivedContent := receivedResponse.Choices[0].Delta.Content + if receivedContent != expectedContent { + t.Errorf("Response %d content mismatch. Expected: %s, Got: %s", + i, expectedContent, receivedContent) + } + fullResponse += receivedContent + } + } + + // Verify EOF at the end + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("Expected EOF at end of stream, got: %v", streamErr) + } + + // Verify the full assembled response + expectedFullResponse := "Hello World" + if fullResponse != expectedFullResponse { + t.Errorf("Full response mismatch. Expected: %s, Got: %s", expectedFullResponse, fullResponse) + } + + t.Log("✅ Successfully demonstrated mock OpenAI client with streaming responses!") + t.Logf(" Full response assembled: %q", fullResponse) +} + +// TestMockOpenAIStreamClient_ErrorHandling demonstrates error handling. +func TestMockOpenAIStreamClient_ErrorHandling(t *testing.T) { + expectedError := errors.New("mock stream error") + + mockClient := &MockOpenAIStreamClient{ + ChatCompletionStreamErr: expectedError, + } + + _, err := mockClient.CreateChatCompletionStream( + context.Background(), + openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }, + ) + + if !errors.Is(err, expectedError) { + t.Errorf("Expected error %v, got %v", expectedError, err) + } + + t.Log("✅ Successfully demonstrated mock OpenAI client error handling!") +} diff --git a/stream_reader.go b/stream_reader.go index 6faefe0a7..4dbcfc4b6 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -16,6 +16,8 @@ var ( errorPrefix = regexp.MustCompile(`^data:\s*{"error":`) ) +var _ ChatStreamReader = (*streamReader[ChatCompletionStreamResponse])(nil) + type streamable interface { ChatCompletionStreamResponse | CompletionResponse }