Commit bff252b

chore: tests fixes
1 parent fbdd429 commit bff252b

Note: large commits have some content hidden by default; only a subset of the changed files is shown below.

75 files changed: +1320 −196 lines

.github/workflows/scripts/release-core.sh
Lines changed: 5 additions & 1 deletion

```diff
@@ -32,10 +32,14 @@ fi
 # Building core
 go mod download
 go build ./...
-go test ./...
 cd ..
 echo "✅ Core build validation successful"

+# Run core provider tests
+echo "🔧 Running core provider tests..."
+cd tests/core-providers
+go test -v ./...
+cd ../..

 # Capturing changelog
 CHANGELOG_BODY=$(cat core/changelog.md)
```

.gitignore
Lines changed: 3 additions & 1 deletion

```diff
@@ -33,4 +33,6 @@ go.work.sum
 # Sqlite DBs
 *.db
 *.db-shm
-*.db-wal
+*.db-wal
+
+.claude
```

core/bifrost.go
Lines changed: 5 additions & 1 deletion

```diff
@@ -1229,7 +1229,11 @@ func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostR
 	primaryResult, primaryErr := bifrost.tryRequest(ctx, req)

 	if primaryErr != nil {
-		bifrost.logger.Debug(fmt.Sprintf("Primary provider %s with model %s returned error: %v", provider, model, primaryErr))
+		if primaryErr.Error != nil {
+			bifrost.logger.Debug(fmt.Sprintf("Primary provider %s with model %s returned error: %s", provider, model, primaryErr.Error.Message))
+		} else {
+			bifrost.logger.Debug(fmt.Sprintf("Primary provider %s with model %s returned error: %v", provider, model, primaryErr))
+		}
 		if len(fallbacks) > 0 {
 			bifrost.logger.Debug(fmt.Sprintf("Check if we should try %d fallbacks", len(fallbacks)))
 		}
```
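The guard above prefers the structured provider message when `primaryErr.Error` is populated and falls back to `%v` formatting otherwise. A minimal, self-contained sketch of that pattern, using simplified stand-in types (the real `schemas.BifrostError` carries more fields, and `errorMessage` is a hypothetical helper, not part of this commit):

```go
package main

import "fmt"

// Simplified stand-ins for the schemas types referenced in the diff.
type ErrorField struct {
	Message string
}

type BifrostError struct {
	Error *ErrorField
}

// errorMessage is a hypothetical helper that prefers the structured
// provider message when present and falls back to %v formatting of the
// whole error otherwise — the same two branches the new guard takes.
func errorMessage(err *BifrostError) string {
	if err == nil {
		return "<nil>"
	}
	if err.Error != nil {
		return err.Error.Message
	}
	return fmt.Sprintf("%v", err)
}

func main() {
	structured := &BifrostError{Error: &ErrorField{Message: "rate limit exceeded"}}
	bare := &BifrostError{}
	fmt.Println(errorMessage(structured)) // rate limit exceeded
	fmt.Println(errorMessage(bare))       // &{<nil>}
}
```

Extracting the branch into a helper like this would keep the two `Debug` calls from drifting apart, at the cost of a small indirection.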

core/changelog.md
Lines changed: 3 additions & 3 deletions

```diff
@@ -1,6 +1,6 @@
 <!-- The pattern we follow here is to keep the changelog for the latest version -->
 <!-- Old changelogs are automatically attached to the GitHub releases -->

-- fix: openai specific parameters filtered for openai compatibile providers
-- fix: error response unmarshalling for gemini provider
-- BREAKING FIX: json_schema field correctly renamed to schema; ResponsesTextConfigFormatJSONSchema restructured
+- bug: fixed embedding request not being handled in `GetExtraFields()` method of `BifrostResponse`
+- fix: added latency calculation for vertex native requests
+- feat: added cached tokens and reasoning tokens to the usage metadata for chat completions
```

core/providers/openai_test.go (new file)
Lines changed: 282 additions & 0 deletions

```go
package providers

import (
	"context"
	"os"
	"testing"

	"github.com/maximhq/bifrost/core/schemas"
)

func TestOpenAIChatCompletion(t *testing.T) {
	apiKey := os.Getenv("OPENAI_API_KEY")
	if apiKey == "" {
		t.Skip("OPENAI_API_KEY not set")
	}

	provider := NewOpenAIProvider(&schemas.ProviderConfig{
		NetworkConfig: schemas.NetworkConfig{
			BaseURL:                        "https://api.openai.com",
			DefaultRequestTimeoutInSeconds: 30,
		},
		ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{
			Concurrency: 10,
		},
	}, newTestLogger())

	ctx := context.Background()
	key := schemas.Key{Value: apiKey}

	request := &schemas.BifrostChatRequest{
		Provider: schemas.OpenAI,
		Model:    "gpt-3.5-turbo",
		Input: []schemas.ChatMessage{
			{
				Role:    "user",
				Content: &schemas.ChatMessageContent{ContentStr: stringPtr("Say hello in one word")},
			},
		},
		Params: &schemas.ChatParameters{
			Temperature:         float64Ptr(0.7),
			MaxCompletionTokens: intPtr(10),
		},
	}

	resp, err := provider.ChatCompletion(ctx, key, request)
	if err != nil {
		t.Fatalf("ChatCompletion failed: %v", err)
	}

	if resp == nil {
		t.Fatal("Expected non-nil response")
	}
	if len(resp.Choices) == 0 {
		t.Fatal("Expected at least one choice")
	}
	if resp.Choices[0].Message.Content == nil {
		t.Fatal("Expected message content")
	}
	t.Logf("Response: %s", *resp.Choices[0].Message.Content)
}

func TestOpenAIChatCompletionWithTools(t *testing.T) {
	apiKey := os.Getenv("OPENAI_API_KEY")
	if apiKey == "" {
		t.Skip("OPENAI_API_KEY not set")
	}

	provider := NewOpenAIProvider(&schemas.ProviderConfig{
		NetworkConfig: schemas.NetworkConfig{
			BaseURL:                        "https://api.openai.com",
			DefaultRequestTimeoutInSeconds: 30,
		},
		ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{
			Concurrency: 10,
		},
	}, newTestLogger())

	ctx := context.Background()
	key := schemas.Key{Value: apiKey}

	props := map[string]interface{}{
		"location": map[string]interface{}{
			"type":        "string",
			"description": "The city name",
		},
	}

	request := &schemas.BifrostChatRequest{
		Provider: schemas.OpenAI,
		Model:    "gpt-3.5-turbo",
		Input: []schemas.ChatMessage{
			{
				Role:    "user",
				Content: &schemas.ChatMessageContent{ContentStr: stringPtr("What's the weather in San Francisco?")},
			},
		},
		Params: &schemas.ChatParameters{
			Temperature:         float64Ptr(0.7),
			MaxCompletionTokens: intPtr(100),
			Tools: []schemas.ChatTool{
				{
					Type: schemas.ChatToolTypeFunction,
					Function: &schemas.ChatToolFunction{
						Name:        "get_weather",
						Description: stringPtr("Get the current weather"),
						Parameters: &schemas.ToolFunctionParameters{
							Type:       "object",
							Properties: &props,
							Required:   []string{"location"},
						},
					},
				},
			},
		},
	}

	resp, err := provider.ChatCompletion(ctx, key, request)
	if err != nil {
		t.Fatalf("ChatCompletion with tools failed: %v", err)
	}

	if resp == nil {
		t.Fatal("Expected non-nil response")
	}
	if len(resp.Choices) == 0 {
		t.Fatal("Expected at least one choice")
	}
	t.Logf("Tool calls: %d", len(resp.Choices[0].Message.ToolCalls))
}

func TestOpenAIChatCompletionStream(t *testing.T) {
	apiKey := os.Getenv("OPENAI_API_KEY")
	if apiKey == "" {
		t.Skip("OPENAI_API_KEY not set")
	}

	provider := NewOpenAIProvider(&schemas.ProviderConfig{
		NetworkConfig: schemas.NetworkConfig{
			BaseURL:                        "https://api.openai.com",
			DefaultRequestTimeoutInSeconds: 30,
		},
		ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{
			Concurrency: 10,
		},
	}, newTestLogger())

	ctx := context.Background()
	key := schemas.Key{Value: apiKey}

	request := &schemas.BifrostChatRequest{
		Provider: schemas.OpenAI,
		Model:    "gpt-3.5-turbo",
		Input: []schemas.ChatMessage{
			{
				Role:    "user",
				Content: &schemas.ChatMessageContent{ContentStr: stringPtr("Count from 1 to 3")},
			},
		},
		Params: &schemas.ChatParameters{
			Temperature: float64Ptr(0.7),
		},
	}

	streamChan, err := provider.ChatCompletionStream(ctx, mockPostHookRunner, key, request)
	if err != nil {
		t.Fatalf("ChatCompletionStream failed: %v", err)
	}

	count := 0
	for chunk := range streamChan {
		if chunk.Error != nil {
			t.Fatalf("Stream error: %v", chunk.Error)
		}
		count++
	}

	if count == 0 {
		t.Fatal("Expected at least one chunk")
	}
	t.Logf("Received %d chunks", count)
}

func TestOpenAITextCompletion(t *testing.T) {
	apiKey := os.Getenv("OPENAI_API_KEY")
	if apiKey == "" {
		t.Skip("OPENAI_API_KEY not set")
	}

	provider := NewOpenAIProvider(&schemas.ProviderConfig{
		NetworkConfig: schemas.NetworkConfig{
			BaseURL:                        "https://api.openai.com",
			DefaultRequestTimeoutInSeconds: 30,
		},
		ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{
			Concurrency: 10,
		},
	}, newTestLogger())

	ctx := context.Background()
	key := schemas.Key{Value: apiKey}

	request := &schemas.BifrostTextCompletionRequest{
		Provider: schemas.OpenAI,
		Model:    "gpt-3.5-turbo-instruct",
		Input:    &schemas.TextCompletionInput{PromptStr: stringPtr("Say hello")},
		Params: &schemas.TextCompletionParameters{
			Temperature:         float64Ptr(0.7),
			MaxCompletionTokens: intPtr(10),
		},
	}

	resp, err := provider.TextCompletion(ctx, key, request)
	if err != nil {
		t.Fatalf("TextCompletion failed: %v", err)
	}

	if resp == nil {
		t.Fatal("Expected non-nil response")
	}
	if len(resp.Choices) == 0 {
		t.Fatal("Expected at least one choice")
	}
	t.Logf("Response: %s", resp.Choices[0].Text)
}

func TestOpenAIEmbedding(t *testing.T) {
	apiKey := os.Getenv("OPENAI_API_KEY")
	if apiKey == "" {
		t.Skip("OPENAI_API_KEY not set")
	}

	provider := NewOpenAIProvider(&schemas.ProviderConfig{
		NetworkConfig: schemas.NetworkConfig{
			BaseURL:                        "https://api.openai.com",
			DefaultRequestTimeoutInSeconds: 30,
		},
		ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{
			Concurrency: 10,
		},
	}, newTestLogger())

	ctx := context.Background()
	key := schemas.Key{Value: apiKey}

	request := &schemas.BifrostEmbeddingRequest{
		Provider: schemas.OpenAI,
		Model:    "text-embedding-ada-002",
		Input:    &schemas.EmbeddingInput{Texts: []string{"Hello world"}},
	}

	resp, err := provider.Embedding(ctx, key, request)
	if err != nil {
		t.Fatalf("Embedding failed: %v", err)
	}

	if resp == nil {
		t.Fatal("Expected non-nil response")
	}
	if len(resp.Data) == 0 {
		t.Fatal("Expected at least one embedding")
	}
	if resp.Data[0].Embedding.EmbeddingArray == nil || len(resp.Data[0].Embedding.EmbeddingArray) == 0 {
		t.Fatal("Expected non-empty embedding vector")
	}
	t.Logf("Embedding dimension: %d", len(resp.Data[0].Embedding.EmbeddingArray))
}

func TestOpenAIGetProviderKey(t *testing.T) {
	provider := NewOpenAIProvider(&schemas.ProviderConfig{
		NetworkConfig: schemas.NetworkConfig{
			BaseURL: "https://api.openai.com",
		},
		ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{
			Concurrency: 10,
		},
	}, newTestLogger())

	key := provider.GetProviderKey()
	if key != schemas.OpenAI {
		t.Errorf("Expected provider key %s, got %s", schemas.OpenAI, key)
	}
}
```
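The five network-hitting tests above construct the provider with an identical `ProviderConfig`. If that repetition becomes a maintenance burden, a shared config helper is one option — a sketch under the assumption that only the fields used above matter (`testOpenAIConfig` is hypothetical and not part of this commit):

```go
// testOpenAIConfig returns the ProviderConfig shared by the OpenAI tests
// above, so a future change to the timeout or concurrency settings is
// made in one place. Hypothetical helper — not part of this commit.
func testOpenAIConfig() *schemas.ProviderConfig {
	return &schemas.ProviderConfig{
		NetworkConfig: schemas.NetworkConfig{
			BaseURL:                        "https://api.openai.com",
			DefaultRequestTimeoutInSeconds: 30,
		},
		ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{
			Concurrency: 10,
		},
	}
}
```

Each test body would then start with `provider := NewOpenAIProvider(testOpenAIConfig(), newTestLogger())`.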

core/providers/test_helpers.go (new file)
Lines changed: 40 additions & 0 deletions

```go
package providers

import (
	"context"

	"github.com/maximhq/bifrost/core/schemas"
)

// newTestLogger creates a simple test logger
func newTestLogger() schemas.Logger {
	return &testLogger{}
}

type testLogger struct{}

func (l *testLogger) Debug(msg string, args ...interface{})             {}
func (l *testLogger) Info(msg string, args ...interface{})              {}
func (l *testLogger) Warn(msg string, args ...interface{})              {}
func (l *testLogger) Error(msg string, args ...interface{})             {}
func (l *testLogger) Fatal(msg string, args ...interface{})             {}
func (l *testLogger) SetLevel(level schemas.LogLevel)                   {}
func (l *testLogger) SetOutputType(outputType schemas.LoggerOutputType) {}

// Helper functions for tests
func stringPtr(s string) *string {
	return &s
}

func float64Ptr(f float64) *float64 {
	return &f
}

func intPtr(i int) *int {
	return &i
}

// mockPostHookRunner is a mock implementation of PostHookRunner for testing
func mockPostHookRunner(ctx *context.Context, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
	return result, err
}
```
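`mockPostHookRunner` is a pure pass-through; a test that needs to observe the hook can wrap the same signature in a closure. A sketch assuming only the signature shown above (the named hook type expected by `ChatCompletionStream` is not shown in this commit, so an identically-shaped func is assumed assignable; `countingPostHook` is hypothetical):

```go
// countingPostHook returns a post-hook with the same signature as
// mockPostHookRunner that tallies how many non-nil results flow through —
// useful when asserting on chunk counts in streaming tests.
// Hypothetical helper — not part of this commit.
func countingPostHook(counter *int) func(*context.Context, *schemas.BifrostResponse, *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
	return func(ctx *context.Context, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
		if result != nil {
			*counter++
		}
		return result, err
	}
}
```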

core/schemas/bifrost.go
Lines changed: 2 additions & 0 deletions

```diff
@@ -232,6 +232,8 @@ func (r *BifrostResponse) GetExtraFields() *BifrostResponseExtraFields {
 		return &r.ResponsesResponse.ExtraFields
 	case r.ResponsesStreamResponse != nil:
 		return &r.ResponsesStreamResponse.ExtraFields
+	case r.EmbeddingResponse != nil:
+		return &r.EmbeddingResponse.ExtraFields
 	case r.SpeechResponse != nil:
 		return &r.SpeechResponse.ExtraFields
 	case r.SpeechStreamResponse != nil:
```
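The new `case` closes the gap named in the changelog: before it, an embedding response matched no branch of `GetExtraFields()`. A compilable miniature of the fix, with heavily simplified stand-ins for the schemas types (names and field sets here are illustrative, not the library's real definitions):

```go
package main

import "fmt"

// Simplified stand-ins for the schemas types — illustrative only.
type ExtraFields struct {
	Provider string
}

type EmbeddingResponse struct {
	ExtraFields ExtraFields
}

type BifrostResponse struct {
	EmbeddingResponse *EmbeddingResponse
}

func (r *BifrostResponse) GetExtraFields() *ExtraFields {
	switch {
	case r.EmbeddingResponse != nil: // the case added in this commit
		return &r.EmbeddingResponse.ExtraFields
	}
	// Fall-through: before the fix, embedding responses ended up here and
	// got the method's default result instead of their own metadata.
	// (The real method has many more cases than this sketch.)
	return nil
}

func main() {
	resp := &BifrostResponse{
		EmbeddingResponse: &EmbeddingResponse{ExtraFields: ExtraFields{Provider: "openai"}},
	}
	fmt.Println(resp.GetExtraFields().Provider) // "openai" — unreachable before the fix
}
```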

core/version
Lines changed: 1 addition & 1 deletion

```diff
@@ -1 +1 @@
-1.2.12
+1.2.13
```
