Skip to content

Commit 7c005f3

Browse files
chore: tests fixes
1 parent 8fd0545 commit 7c005f3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

79 files changed

+1728
-379
lines changed

.github/workflows/scripts/release-core.sh

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,14 @@ fi
3232
# Building core
3333
go mod download
3434
go build ./...
35-
go test ./...
3635
cd ..
3736
echo "✅ Core build validation successful"
3837

38+
# Run core provider tests
39+
echo "🔧 Running core provider tests..."
40+
cd tests/core-providers
41+
go test -v ./...
42+
cd ../..
3943

4044
# Capturing changelog
4145
CHANGELOG_BODY=$(cat core/changelog.md)

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,6 @@ go.work.sum
3333
# Sqlite DBs
3434
*.db
3535
*.db-shm
36-
*.db-wal
36+
*.db-wal
37+
38+
.claude

core/bifrost.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,7 +1229,11 @@ func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostR
12291229
primaryResult, primaryErr := bifrost.tryRequest(ctx, req)
12301230

12311231
if primaryErr != nil {
1232-
bifrost.logger.Debug(fmt.Sprintf("Primary provider %s with model %s returned error: %v", provider, model, primaryErr))
1232+
if primaryErr.Error != nil {
1233+
bifrost.logger.Debug(fmt.Sprintf("Primary provider %s with model %s returned error: %s", provider, model, primaryErr.Error.Message))
1234+
} else {
1235+
bifrost.logger.Debug(fmt.Sprintf("Primary provider %s with model %s returned error: %v", provider, model, primaryErr))
1236+
}
12331237
if len(fallbacks) > 0 {
12341238
bifrost.logger.Debug(fmt.Sprintf("Check if we should try %d fallbacks", len(fallbacks)))
12351239
}
@@ -1629,7 +1633,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
16291633
time.Sleep(backoff)
16301634
}
16311635

1632-
bifrost.logger.Debug("attempting request for provider %s", provider.GetProviderKey())
1636+
bifrost.logger.Debug("attempting %s request for provider %s", req.RequestType, provider.GetProviderKey())
16331637

16341638
// Attempt the request
16351639
if IsStreamRequestType(req.RequestType) {
@@ -1644,7 +1648,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
16441648
}
16451649
}
16461650

1647-
bifrost.logger.Debug("request for provider %s completed", provider.GetProviderKey())
1651+
bifrost.logger.Debug("request %s for provider %s completed", req.RequestType, provider.GetProviderKey())
16481652

16491653
// Check if successful or if we should retry
16501654
if bifrostError == nil ||

core/changelog.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
<!-- The pattern we follow here is to keep the changelog for the latest version -->
22
<!-- Old changelogs are automatically attached to the GitHub releases -->
33

4-
- fix: openai specific parameters filtered for openai compatibile providers
5-
- fix: error response unmarshalling for gemini provider
6-
- BREAKING FIX: json_schema field correctly renamed to schema; ResponsesTextConfigFormatJSONSchema restructured
4+
- bug: fixed embedding request not being handled in `GetExtraFields()` method of `BifrostResponse`
5+
- fix: added latency calculation for vertex native requests
6+
- feat: added cached tokens and reasoning tokens to the usage metadata for chat completions

core/providers/groq.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,18 @@ func (provider *GroqProvider) TextCompletionStream(ctx context.Context, postHook
124124
responseChan <- response
125125
continue
126126
}
127-
response.ToTextCompletionResponse()
128-
if response.BifrostTextCompletionResponse != nil {
129-
response.BifrostTextCompletionResponse.ExtraFields.RequestType = schemas.TextCompletionRequest
130-
response.BifrostTextCompletionResponse.ExtraFields.Provider = provider.GetProviderKey()
131-
response.BifrostTextCompletionResponse.ExtraFields.ModelRequested = request.Model
127+
if response.BifrostChatResponse != nil {
128+
textCompletionResponse := response.BifrostChatResponse.ToTextCompletionResponse()
129+
if textCompletionResponse != nil {
130+
textCompletionResponse.ExtraFields.RequestType = schemas.TextCompletionRequest
131+
textCompletionResponse.ExtraFields.Provider = provider.GetProviderKey()
132+
textCompletionResponse.ExtraFields.ModelRequested = request.Model
133+
134+
responseChan <- &schemas.BifrostStream{
135+
BifrostTextCompletionResponse: textCompletionResponse,
136+
}
137+
}
132138
}
133-
responseChan <- response
134139
}
135140
}()
136141
return responseChan, nil

core/providers/openai_test.go

Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
package providers
2+
3+
import (
4+
"context"
5+
"os"
6+
"testing"
7+
8+
"github.com/maximhq/bifrost/core/schemas"
9+
)
10+
11+
func TestOpenAIChatCompletion(t *testing.T) {
12+
apiKey := os.Getenv("OPENAI_API_KEY")
13+
if apiKey == "" {
14+
t.Skip("OPENAI_API_KEY not set")
15+
}
16+
17+
provider := NewOpenAIProvider(&schemas.ProviderConfig{
18+
NetworkConfig: schemas.NetworkConfig{
19+
BaseURL: "https://api.openai.com",
20+
DefaultRequestTimeoutInSeconds: 30,
21+
},
22+
ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{
23+
Concurrency: 10,
24+
},
25+
}, newTestLogger())
26+
27+
ctx := context.Background()
28+
key := schemas.Key{Value: apiKey}
29+
30+
request := &schemas.BifrostChatRequest{
31+
Provider: schemas.OpenAI,
32+
Model: "gpt-3.5-turbo",
33+
Input: []schemas.ChatMessage{
34+
{
35+
Role: "user",
36+
Content: &schemas.ChatMessageContent{ContentStr: stringPtr("Say hello in one word")},
37+
},
38+
},
39+
Params: &schemas.ChatParameters{
40+
Temperature: float64Ptr(0.7),
41+
MaxCompletionTokens: intPtr(10),
42+
},
43+
}
44+
45+
resp, err := provider.ChatCompletion(ctx, key, request)
46+
if err != nil {
47+
t.Fatalf("ChatCompletion failed: %v", err)
48+
}
49+
50+
if resp == nil {
51+
t.Fatal("Expected non-nil response")
52+
}
53+
if len(resp.Choices) == 0 {
54+
t.Fatal("Expected at least one choice")
55+
}
56+
if resp.Choices[0].Message.Content == nil {
57+
t.Fatal("Expected message content")
58+
}
59+
t.Logf("Response: %s", *resp.Choices[0].Message.Content)
60+
}
61+
62+
func TestOpenAIChatCompletionWithTools(t *testing.T) {
63+
apiKey := os.Getenv("OPENAI_API_KEY")
64+
if apiKey == "" {
65+
t.Skip("OPENAI_API_KEY not set")
66+
}
67+
68+
provider := NewOpenAIProvider(&schemas.ProviderConfig{
69+
NetworkConfig: schemas.NetworkConfig{
70+
BaseURL: "https://api.openai.com",
71+
DefaultRequestTimeoutInSeconds: 30,
72+
},
73+
ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{
74+
Concurrency: 10,
75+
},
76+
}, newTestLogger())
77+
78+
ctx := context.Background()
79+
key := schemas.Key{Value: apiKey}
80+
81+
props := map[string]interface{}{
82+
"location": map[string]interface{}{
83+
"type": "string",
84+
"description": "The city name",
85+
},
86+
}
87+
88+
request := &schemas.BifrostChatRequest{
89+
Provider: schemas.OpenAI,
90+
Model: "gpt-3.5-turbo",
91+
Input: []schemas.ChatMessage{
92+
{
93+
Role: "user",
94+
Content: &schemas.ChatMessageContent{ContentStr: stringPtr("What's the weather in San Francisco?")},
95+
},
96+
},
97+
Params: &schemas.ChatParameters{
98+
Temperature: float64Ptr(0.7),
99+
MaxCompletionTokens: intPtr(100),
100+
Tools: []schemas.ChatTool{
101+
{
102+
Type: schemas.ChatToolTypeFunction,
103+
Function: &schemas.ChatToolFunction{
104+
Name: "get_weather",
105+
Description: stringPtr("Get the current weather"),
106+
Parameters: &schemas.ToolFunctionParameters{
107+
Type: "object",
108+
Properties: &props,
109+
Required: []string{"location"},
110+
},
111+
},
112+
},
113+
},
114+
},
115+
}
116+
117+
resp, err := provider.ChatCompletion(ctx, key, request)
118+
if err != nil {
119+
t.Fatalf("ChatCompletion with tools failed: %v", err)
120+
}
121+
122+
if resp == nil {
123+
t.Fatal("Expected non-nil response")
124+
}
125+
if len(resp.Choices) == 0 {
126+
t.Fatal("Expected at least one choice")
127+
}
128+
t.Logf("Tool calls: %d", len(resp.Choices[0].Message.ToolCalls))
129+
}
130+
131+
func TestOpenAIChatCompletionStream(t *testing.T) {
132+
apiKey := os.Getenv("OPENAI_API_KEY")
133+
if apiKey == "" {
134+
t.Skip("OPENAI_API_KEY not set")
135+
}
136+
137+
provider := NewOpenAIProvider(&schemas.ProviderConfig{
138+
NetworkConfig: schemas.NetworkConfig{
139+
BaseURL: "https://api.openai.com",
140+
DefaultRequestTimeoutInSeconds: 30,
141+
},
142+
ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{
143+
Concurrency: 10,
144+
},
145+
}, newTestLogger())
146+
147+
ctx := context.Background()
148+
key := schemas.Key{Value: apiKey}
149+
150+
request := &schemas.BifrostChatRequest{
151+
Provider: schemas.OpenAI,
152+
Model: "gpt-3.5-turbo",
153+
Input: []schemas.ChatMessage{
154+
{
155+
Role: "user",
156+
Content: &schemas.ChatMessageContent{ContentStr: stringPtr("Count from 1 to 3")},
157+
},
158+
},
159+
Params: &schemas.ChatParameters{
160+
Temperature: float64Ptr(0.7),
161+
},
162+
}
163+
164+
streamChan, err := provider.ChatCompletionStream(ctx, mockPostHookRunner, key, request)
165+
if err != nil {
166+
t.Fatalf("ChatCompletionStream failed: %v", err)
167+
}
168+
169+
count := 0
170+
for chunk := range streamChan {
171+
if chunk.Error != nil {
172+
t.Fatalf("Stream error: %v", chunk.Error)
173+
}
174+
count++
175+
}
176+
177+
if count == 0 {
178+
t.Fatal("Expected at least one chunk")
179+
}
180+
t.Logf("Received %d chunks", count)
181+
}
182+
183+
func TestOpenAITextCompletion(t *testing.T) {
184+
apiKey := os.Getenv("OPENAI_API_KEY")
185+
if apiKey == "" {
186+
t.Skip("OPENAI_API_KEY not set")
187+
}
188+
189+
provider := NewOpenAIProvider(&schemas.ProviderConfig{
190+
NetworkConfig: schemas.NetworkConfig{
191+
BaseURL: "https://api.openai.com",
192+
DefaultRequestTimeoutInSeconds: 30,
193+
},
194+
ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{
195+
Concurrency: 10,
196+
},
197+
}, newTestLogger())
198+
199+
ctx := context.Background()
200+
key := schemas.Key{Value: apiKey}
201+
202+
request := &schemas.BifrostTextCompletionRequest{
203+
Provider: schemas.OpenAI,
204+
Model: "gpt-3.5-turbo-instruct",
205+
Input: &schemas.TextCompletionInput{PromptStr: stringPtr("Say hello")},
206+
Params: &schemas.TextCompletionParameters{
207+
Temperature: float64Ptr(0.7),
208+
MaxTokens: intPtr(10),
209+
},
210+
}
211+
212+
resp, err := provider.TextCompletion(ctx, key, request)
213+
if err != nil {
214+
t.Fatalf("TextCompletion failed: %v", err)
215+
}
216+
217+
if resp == nil {
218+
t.Fatal("Expected non-nil response")
219+
}
220+
if len(resp.Choices) == 0 {
221+
t.Fatal("Expected at least one choice")
222+
}
223+
t.Logf("Response: %s", resp.Choices[0].Text)
224+
}
225+
226+
func TestOpenAIEmbedding(t *testing.T) {
227+
apiKey := os.Getenv("OPENAI_API_KEY")
228+
if apiKey == "" {
229+
t.Skip("OPENAI_API_KEY not set")
230+
}
231+
232+
provider := NewOpenAIProvider(&schemas.ProviderConfig{
233+
NetworkConfig: schemas.NetworkConfig{
234+
BaseURL: "https://api.openai.com",
235+
DefaultRequestTimeoutInSeconds: 30,
236+
},
237+
ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{
238+
Concurrency: 10,
239+
},
240+
}, newTestLogger())
241+
242+
ctx := context.Background()
243+
key := schemas.Key{Value: apiKey}
244+
245+
request := &schemas.BifrostEmbeddingRequest{
246+
Provider: schemas.OpenAI,
247+
Model: "text-embedding-ada-002",
248+
Input: &schemas.EmbeddingInput{Texts: []string{"Hello world"}},
249+
}
250+
251+
resp, err := provider.Embedding(ctx, key, request)
252+
if err != nil {
253+
t.Fatalf("Embedding failed: %v", err)
254+
}
255+
256+
if resp == nil {
257+
t.Fatal("Expected non-nil response")
258+
}
259+
if len(resp.Data) == 0 {
260+
t.Fatal("Expected at least one embedding")
261+
}
262+
if resp.Data[0].Embedding.EmbeddingArray == nil || len(resp.Data[0].Embedding.EmbeddingArray) == 0 {
263+
t.Fatal("Expected non-empty embedding vector")
264+
}
265+
t.Logf("Embedding dimension: %d", len(resp.Data[0].Embedding.EmbeddingArray))
266+
}
267+
268+
func TestOpenAIGetProviderKey(t *testing.T) {
269+
provider := NewOpenAIProvider(&schemas.ProviderConfig{
270+
NetworkConfig: schemas.NetworkConfig{
271+
BaseURL: "https://api.openai.com",
272+
},
273+
ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{
274+
Concurrency: 10,
275+
},
276+
}, newTestLogger())
277+
278+
key := provider.GetProviderKey()
279+
if key != schemas.OpenAI {
280+
t.Errorf("Expected provider key %s, got %s", schemas.OpenAI, key)
281+
}
282+
}

0 commit comments

Comments
 (0)