Skip to content

Commit 83263a8

Browse files
chore: tests fixes
1 parent 8fd0545 commit 83263a8

File tree

89 files changed

+1685
-751
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

89 files changed

+1685
-751
lines changed

.github/workflows/scripts/release-core.sh

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,14 @@ fi
3232
# Building core
3333
go mod download
3434
go build ./...
35-
go test ./...
3635
cd ..
3736
echo "✅ Core build validation successful"
3837

38+
# Run core provider tests
39+
echo "🔧 Running core provider tests..."
40+
cd tests/core-providers
41+
go test -v ./...
42+
cd ../..
3943

4044
# Capturing changelog
4145
CHANGELOG_BODY=$(cat core/changelog.md)

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,6 @@ go.work.sum
3333
# Sqlite DBs
3434
*.db
3535
*.db-shm
36-
*.db-wal
36+
*.db-wal
37+
38+
.claude

core/bifrost.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,7 +1229,11 @@ func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostR
12291229
primaryResult, primaryErr := bifrost.tryRequest(ctx, req)
12301230

12311231
if primaryErr != nil {
1232-
bifrost.logger.Debug(fmt.Sprintf("Primary provider %s with model %s returned error: %v", provider, model, primaryErr))
1232+
if primaryErr.Error != nil {
1233+
bifrost.logger.Debug(fmt.Sprintf("Primary provider %s with model %s returned error: %s", provider, model, primaryErr.Error.Message))
1234+
} else {
1235+
bifrost.logger.Debug(fmt.Sprintf("Primary provider %s with model %s returned error: %v", provider, model, primaryErr))
1236+
}
12331237
if len(fallbacks) > 0 {
12341238
bifrost.logger.Debug(fmt.Sprintf("Check if we should try %d fallbacks", len(fallbacks)))
12351239
}
@@ -1629,7 +1633,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
16291633
time.Sleep(backoff)
16301634
}
16311635

1632-
bifrost.logger.Debug("attempting request for provider %s", provider.GetProviderKey())
1636+
bifrost.logger.Debug("attempting %s request for provider %s", req.RequestType, provider.GetProviderKey())
16331637

16341638
// Attempt the request
16351639
if IsStreamRequestType(req.RequestType) {
@@ -1644,7 +1648,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
16441648
}
16451649
}
16461650

1647-
bifrost.logger.Debug("request for provider %s completed", provider.GetProviderKey())
1651+
bifrost.logger.Debug("request %s for provider %s completed", req.RequestType, provider.GetProviderKey())
16481652

16491653
// Check if successful or if we should retry
16501654
if bifrostError == nil ||

core/changelog.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
<!-- The pattern we follow here is to keep the changelog for the latest version -->
22
<!-- Old changelogs are automatically attached to the GitHub releases -->
33

4-
- fix: openai specific parameters filtered for openai compatibile providers
5-
- fix: error response unmarshalling for gemini provider
6-
- BREAKING FIX: json_schema field correctly renamed to schema; ResponsesTextConfigFormatJSONSchema restructured
4+
- bug: fixed embedding request not being handled in `GetExtraFields()` method of `BifrostResponse`
5+
- fix: added latency calculation for vertex native requests
6+
- feat: added cached tokens and reasoning tokens to the usage metadata for chat completions
7+
- feat: added global region support for vertex API
8+
- fix: added filter for extra fields in chat completions request for Mistral provider

core/providers/gemini.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -381,8 +381,8 @@ func (provider *GeminiProvider) SpeechStream(ctx context.Context, postHookRunner
381381

382382
scanner := bufio.NewScanner(resp.Body)
383383
// Increase buffer size to handle large chunks (especially for audio data)
384-
buf := make([]byte, 0, 256*1024) // 256KB buffer
385-
scanner.Buffer(buf, 1024*1024) // Allow up to 1MB tokens
384+
buf := make([]byte, 0, 1024*1024) // 1MB initial buffer
385+
scanner.Buffer(buf, 10*1024*1024) // Allow up to 10MB tokens
386386
chunkIndex := -1
387387
usage := &schemas.SpeechUsage{}
388388
startTime := time.Now()
@@ -658,6 +658,9 @@ func (provider *GeminiProvider) TranscriptionStream(ctx context.Context, postHoo
658658
defer resp.Body.Close()
659659

660660
scanner := bufio.NewScanner(resp.Body)
661+
// Increase buffer size to handle large chunks (especially for audio data)
662+
buf := make([]byte, 0, 1024*1024) // 1MB initial buffer
663+
scanner.Buffer(buf, 10*1024*1024) // Allow up to 10MB tokens
661664
chunkIndex := -1
662665
usage := &schemas.TranscriptionUsage{}
663666
startTime := time.Now()
@@ -674,8 +677,8 @@ func (provider *GeminiProvider) TranscriptionStream(ctx context.Context, postHoo
674677
}
675678
var jsonData string
676679
// Parse SSE data
677-
if strings.HasPrefix(line, "data: ") {
678-
jsonData = strings.TrimPrefix(line, "data: ")
680+
if after, ok := strings.CutPrefix(line, "data: "); ok {
681+
jsonData = after
679682
} else {
680683
// Handle raw JSON errors (without "data: " prefix)
681684
jsonData = line

core/providers/groq.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,18 @@ func (provider *GroqProvider) TextCompletionStream(ctx context.Context, postHook
124124
responseChan <- response
125125
continue
126126
}
127-
response.ToTextCompletionResponse()
128-
if response.BifrostTextCompletionResponse != nil {
129-
response.BifrostTextCompletionResponse.ExtraFields.RequestType = schemas.TextCompletionRequest
130-
response.BifrostTextCompletionResponse.ExtraFields.Provider = provider.GetProviderKey()
131-
response.BifrostTextCompletionResponse.ExtraFields.ModelRequested = request.Model
127+
if response.BifrostChatResponse != nil {
128+
textCompletionResponse := response.BifrostChatResponse.ToTextCompletionResponse()
129+
if textCompletionResponse != nil {
130+
textCompletionResponse.ExtraFields.RequestType = schemas.TextCompletionRequest
131+
textCompletionResponse.ExtraFields.Provider = provider.GetProviderKey()
132+
textCompletionResponse.ExtraFields.ModelRequested = request.Model
133+
134+
responseChan <- &schemas.BifrostStream{
135+
BifrostTextCompletionResponse: textCompletionResponse,
136+
}
137+
}
132138
}
133-
responseChan <- response
134139
}
135140
}()
136141
return responseChan, nil

core/providers/test_helpers.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package providers
2+
3+
import (
4+
"context"
5+
6+
"github.com/maximhq/bifrost/core/schemas"
7+
)
8+
9+
// newTestLogger creates a simple test logger
10+
func newTestLogger() schemas.Logger {
11+
return &testLogger{}
12+
}
13+
14+
type testLogger struct{}
15+
16+
func (l *testLogger) Debug(msg string, args ...interface{}) {}
17+
func (l *testLogger) Info(msg string, args ...interface{}) {}
18+
func (l *testLogger) Warn(msg string, args ...interface{}) {}
19+
func (l *testLogger) Error(msg string, args ...interface{}) {}
20+
func (l *testLogger) Fatal(msg string, args ...interface{}) {}
21+
func (l *testLogger) SetLevel(level schemas.LogLevel) {}
22+
func (l *testLogger) SetOutputType(outputType schemas.LoggerOutputType) {}
23+
24+
// Helper functions for tests
25+
func stringPtr(s string) *string {
26+
return &s
27+
}
28+
29+
func float64Ptr(f float64) *float64 {
30+
return &f
31+
}
32+
33+
func intPtr(i int) *int {
34+
return &i
35+
}
36+
37+
// mockPostHookRunner is a mock implementation of PostHookRunner for testing
38+
func mockPostHookRunner(ctx *context.Context, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
39+
return result, err
40+
}

core/providers/vertex.go

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,19 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas.
203203
return nil, newConfigurationError("region is not set in key config", schemas.Vertex)
204204
}
205205

206-
url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", region, projectID, region)
207-
206+
var url string
208207
if strings.Contains(request.Model, "claude") {
209-
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, request.Model)
208+
if region == "global" {
209+
url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:rawPredict", projectID, request.Model)
210+
} else {
211+
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, request.Model)
212+
}
213+
} else {
214+
if region == "global" {
215+
url = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/openapi/chat/completions", projectID)
216+
} else {
217+
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", region, projectID, region)
218+
}
210219
}
211220

212221
// Create request
@@ -286,12 +295,19 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas.
286295
}
287296

288297
var openAIErr schemas.BifrostError
289-
var vertexErr []VertexError
290298

299+
var vertexErr []VertexError
291300
if err := sonic.Unmarshal(body, &openAIErr); err != nil {
292301
// Try Vertex error format if OpenAI format fails
293302
if err := sonic.Unmarshal(body, &vertexErr); err != nil {
294-
return nil, newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.Vertex)
303+
304+
//try with single Vertex error format
305+
var vertexErr VertexError
306+
if err := sonic.Unmarshal(body, &vertexErr); err != nil {
307+
return nil, newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.Vertex)
308+
}
309+
310+
return nil, newProviderAPIError(vertexErr.Error.Message, nil, resp.StatusCode, schemas.Vertex, nil, nil)
295311
}
296312

297313
if len(vertexErr) > 0 {
@@ -395,7 +411,12 @@ func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHo
395411
delete(requestBody, "model")
396412
delete(requestBody, "region")
397413

398-
url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, request.Model)
414+
var url string
415+
if region == "global" {
416+
url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:streamRawPredict", projectID, request.Model)
417+
} else {
418+
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, request.Model)
419+
}
399420

400421
// Prepare headers for Vertex Anthropic
401422
headers := map[string]string{
@@ -418,7 +439,12 @@ func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHo
418439
provider.logger,
419440
)
420441
} else {
421-
url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", region, projectID, region)
442+
var url string
443+
if region == "global" {
444+
url = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/openapi/chat/completions", projectID)
445+
} else {
446+
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", region, projectID, region)
447+
}
422448
authHeader := map[string]string{}
423449
if key.Value != "" {
424450
authHeader["Authorization"] = "Bearer " + key.Value

core/schemas/bifrost.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,8 @@ func (r *BifrostResponse) GetExtraFields() *BifrostResponseExtraFields {
232232
return &r.ResponsesResponse.ExtraFields
233233
case r.ResponsesStreamResponse != nil:
234234
return &r.ResponsesStreamResponse.ExtraFields
235+
case r.EmbeddingResponse != nil:
236+
return &r.EmbeddingResponse.ExtraFields
235237
case r.SpeechResponse != nil:
236238
return &r.SpeechResponse.ExtraFields
237239
case r.SpeechStreamResponse != nil:

core/schemas/providers/gemini/chat.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ func (request *GeminiGenerationRequest) ToBifrostChatRequest() *schemas.BifrostC
3131

3232
allGenAiMessages := []Content{}
3333
if request.SystemInstruction != nil {
34-
allGenAiMessages = append(allGenAiMessages, request.SystemInstruction.ToGenAIContent())
34+
allGenAiMessages = append(allGenAiMessages, *request.SystemInstruction)
3535
}
3636
for _, content := range request.Contents {
37-
allGenAiMessages = append(allGenAiMessages, content.ToGenAIContent())
37+
allGenAiMessages = append(allGenAiMessages, content)
3838
}
3939

4040
for _, content := range allGenAiMessages {

0 commit comments

Comments
 (0)