Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 35 additions & 25 deletions core/providers/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -100,7 +101,7 @@ func (provider *GeminiProvider) ChatCompletion(ctx context.Context, key schemas.
jsonBody, err := sonic.Marshal(reqBody)
if err != nil {
return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerName)
}
}

// Create request
req := fasthttp.AcquireRequest()
Expand All @@ -126,19 +127,7 @@ func (provider *GeminiProvider) ChatCompletion(ctx context.Context, key schemas.

// Handle error response
if resp.StatusCode() != fasthttp.StatusOK {
var errorResp []struct {
Error struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
Details []struct {
Type string `json:"@type"`
FieldViolations []struct {
Description string `json:"description"`
} `json:"fieldViolations"`
} `json:"details"`
} `json:"error"`
}
var errorResp []gemini.GeminiGenerationError

bifrostErr := handleProviderAPIError(resp, &errorResp)
errorMessage := ""
Expand Down Expand Up @@ -907,29 +896,50 @@ func parseStreamGeminiError(providerName schemas.ModelProvider, resp *http.Respo
}

// Try to parse as JSON first
var errorResp map[string]interface{}
var errorResp gemini.GeminiGenerationError
if err := sonic.Unmarshal(body, &errorResp); err == nil {
// Successfully parsed as JSON
return newBifrostOperationError(fmt.Sprintf("Gemini streaming error: %v", errorResp), fmt.Errorf("HTTP %d", resp.StatusCode), providerName)
bifrostErr := &schemas.BifrostError{
IsBifrostError: false,
StatusCode: schemas.Ptr(resp.StatusCode),
Error: &schemas.ErrorField{
Code: schemas.Ptr(strconv.Itoa(errorResp.Error.Code)),
Message: errorResp.Error.Message,
},
}
return bifrostErr
}

// If JSON parsing fails, treat as plain text
bodyStr := string(body)
if bodyStr == "" {
bodyStr = "empty response body"
// If JSON parsing fails, use the raw response body
var rawResponse interface{}
if err := sonic.Unmarshal(body, &rawResponse); err != nil {
return newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName)
}

return newBifrostOperationError(fmt.Sprintf("Gemini streaming error (HTTP %d): %s", resp.StatusCode, bodyStr), fmt.Errorf("HTTP %d", resp.StatusCode), providerName)
return newBifrostOperationError(fmt.Sprintf("Gemini streaming error (HTTP %d): %v", resp.StatusCode, rawResponse), fmt.Errorf("HTTP %d", resp.StatusCode), providerName)
}

// parseGeminiError parses Gemini error responses
func parseGeminiError(providerName schemas.ModelProvider, resp *fasthttp.Response) *schemas.BifrostError {
var errorResp map[string]interface{}
body := resp.Body()

if err := sonic.Unmarshal(body, &errorResp); err != nil {
// Try to parse as JSON first
var errorResp gemini.GeminiGenerationError
if err := sonic.Unmarshal(body, &errorResp); err == nil {
bifrostErr := &schemas.BifrostError{
IsBifrostError: false,
StatusCode: schemas.Ptr(resp.StatusCode()),
Error: &schemas.ErrorField{
Code: schemas.Ptr(strconv.Itoa(errorResp.Error.Code)),
Message: errorResp.Error.Message,
},
}
return bifrostErr
}

var rawResponse map[string]interface{}
if err := sonic.Unmarshal(body, &rawResponse); err != nil {
return newBifrostOperationError("failed to parse error response", err, providerName)
}

return newBifrostOperationError(fmt.Sprintf("Gemini error: %v", errorResp), fmt.Errorf("HTTP %d", resp.StatusCode()), providerName)
return newBifrostOperationError(fmt.Sprintf("Gemini error: %v", rawResponse), fmt.Errorf("HTTP %d", resp.StatusCode()), providerName)
}
28 changes: 15 additions & 13 deletions core/providers/vertex.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"net/http"
"strings"
"sync"
"time"

"golang.org/x/oauth2/google"

Expand Down Expand Up @@ -245,6 +246,8 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas.
return nil, newBifrostOperationError("error creating auth client", err, schemas.Vertex)
}

startTime := time.Now()

// Make request
resp, err := client.Do(req)
if err != nil {
Expand All @@ -267,6 +270,8 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas.
}
defer resp.Body.Close()

latency := time.Since(startTime)

// Handle error response
// Read response body
body, err := io.ReadAll(resp.Body)
Expand Down Expand Up @@ -314,6 +319,7 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas.
RequestType: schemas.ChatCompletionRequest,
Provider: schemas.Vertex,
ModelRequested: request.Model,
Latency: latency.Milliseconds(),
}

if provider.sendBackRawResponse {
Expand All @@ -322,10 +328,7 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas.

return response, nil
} else {
// Pre-allocate response structs from pools
// response := acquireOpenAIResponse()
response := &schemas.BifrostChatResponse{}
// defer releaseOpenAIResponse(response)

// Use enhanced response handler with pre-allocated response
rawResponse, bifrostErr := handleProviderResponse(body, response, provider.sendBackRawResponse)
Expand All @@ -336,6 +339,7 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas.
response.ExtraFields.RequestType = schemas.ChatCompletionRequest
response.ExtraFields.Provider = schemas.Vertex
response.ExtraFields.ModelRequested = request.Model
response.ExtraFields.Latency = latency.Milliseconds()

if provider.sendBackRawResponse {
response.ExtraFields.RawResponse = rawResponse
Expand Down Expand Up @@ -484,22 +488,15 @@ func (provider *VertexProvider) Embedding(ctx context.Context, key schemas.Key,
return nil, newConfigurationError("embedding input texts are empty", schemas.Vertex)
}

// All Vertex AI embedding models use the same native Vertex embedding API
return provider.handleVertexEmbedding(ctx, request.Model, key, reqBody, request.Params)
}

// handleVertexEmbedding handles embedding requests using Vertex's native embedding API
// This is used for all Vertex AI embedding models as they all use the same response format
func (provider *VertexProvider) handleVertexEmbedding(ctx context.Context, model string, key schemas.Key, vertexReq *vertex.VertexEmbeddingRequest, params *schemas.EmbeddingParameters) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
// Use the typed request directly
jsonBody, err := sonic.Marshal(vertexReq)
jsonBody, err := sonic.Marshal(reqBody)
if err != nil {
return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Vertex)
}

// Build the native Vertex embedding API endpoint
url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict",
key.VertexKeyConfig.Region, key.VertexKeyConfig.ProjectID, key.VertexKeyConfig.Region, model)
key.VertexKeyConfig.Region, key.VertexKeyConfig.ProjectID, key.VertexKeyConfig.Region, request.Model)

// Create request
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody))
Expand Down Expand Up @@ -532,6 +529,8 @@ func (provider *VertexProvider) handleVertexEmbedding(ctx context.Context, model
return nil, newBifrostOperationError("error creating auth client", err, schemas.Vertex)
}

startTime := time.Now()

// Make request
resp, err := client.Do(req)
if err != nil {
Expand All @@ -554,6 +553,8 @@ func (provider *VertexProvider) handleVertexEmbedding(ctx context.Context, model
}
defer resp.Body.Close()

latency := time.Since(startTime)

// Handle error response
body, err := io.ReadAll(resp.Body)
if err != nil {
Expand Down Expand Up @@ -598,8 +599,9 @@ func (provider *VertexProvider) handleVertexEmbedding(ctx context.Context, model

// Set ExtraFields
bifrostResponse.ExtraFields.Provider = schemas.Vertex
bifrostResponse.ExtraFields.ModelRequested = model
bifrostResponse.ExtraFields.ModelRequested = request.Model
bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()

// Set raw response if enabled
if provider.sendBackRawResponse {
Expand Down
20 changes: 17 additions & 3 deletions core/schemas/chatcompletions.go
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,21 @@ type ContentLogProb struct {

// BifrostLLMUsage represents token usage information
type BifrostLLMUsage struct {
PromptTokens int `json:"prompt_tokens,omitempty"`
CompletionTokens int `json:"completion_tokens,omitempty"`
TotalTokens int `json:"total_tokens"`
PromptTokens int `json:"prompt_tokens,omitempty"`
PromptTokensDetails *ChatPromptTokensDetails `json:"prompt_tokens_details,omitempty"`
CompletionTokens int `json:"completion_tokens,omitempty"`
CompletionTokensDetails *ChatCompletionTokensDetails `json:"completion_tokens_details,omitempty"`
TotalTokens int `json:"total_tokens"`
}

// ChatPromptTokensDetails breaks down prompt (input) token usage by category,
// mirroring the OpenAI `usage.prompt_tokens_details` object.
type ChatPromptTokensDetails struct {
// AudioTokens is the number of audio input tokens billed in the prompt.
AudioTokens int `json:"audio_tokens,omitempty"`
// CachedTokens is the number of prompt tokens served from a provider-side cache.
CachedTokens int `json:"cached_tokens,omitempty"`
}

// ChatCompletionTokensDetails breaks down completion (output) token usage by
// category, mirroring the OpenAI `usage.completion_tokens_details` object.
type ChatCompletionTokensDetails struct {
// AcceptedPredictionTokens counts predicted-output tokens accepted into the completion.
AcceptedPredictionTokens int `json:"accepted_prediction_tokens,omitempty"`
// AudioTokens is the number of audio output tokens generated.
AudioTokens int `json:"audio_tokens,omitempty"`
// ReasoningTokens counts hidden reasoning tokens produced by reasoning models.
ReasoningTokens int `json:"reasoning_tokens,omitempty"`
// RejectedPredictionTokens counts predicted-output tokens that were rejected.
RejectedPredictionTokens int `json:"rejected_prediction_tokens,omitempty"`
}
84 changes: 61 additions & 23 deletions core/schemas/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,64 @@ func ToChatMessages(rms []ResponsesMessage) []ChatMessage {
return chatMessages
}

// ToResponsesResponseUsage converts Chat Completions token usage into the
// Responses API usage shape, including the per-category token detail
// breakdowns when present. Returns nil when the receiver is nil so callers
// can convert optional usage without a separate nil check.
func (cu *BifrostLLMUsage) ToResponsesResponseUsage() *ResponsesResponseUsage {
	if cu == nil {
		return nil
	}

	usage := &ResponsesResponseUsage{
		InputTokens:  cu.PromptTokens,
		OutputTokens: cu.CompletionTokens,
		TotalTokens:  cu.TotalTokens,
	}

	// Some providers omit total_tokens; derive it from the parts so callers
	// always see a consistent total. (This fallback previously lived at the
	// call sites and was dropped when the conversion moved into this helper.)
	if usage.TotalTokens == 0 {
		usage.TotalTokens = cu.PromptTokens + cu.CompletionTokens
	}

	if cu.PromptTokensDetails != nil {
		usage.InputTokensDetails = &ResponsesResponseInputTokens{
			AudioTokens:  cu.PromptTokensDetails.AudioTokens,
			CachedTokens: cu.PromptTokensDetails.CachedTokens,
		}
	}
	if cu.CompletionTokensDetails != nil {
		usage.OutputTokensDetails = &ResponsesResponseOutputTokens{
			AcceptedPredictionTokens: cu.CompletionTokensDetails.AcceptedPredictionTokens,
			AudioTokens:              cu.CompletionTokensDetails.AudioTokens,
			ReasoningTokens:          cu.CompletionTokensDetails.ReasoningTokens,
			RejectedPredictionTokens: cu.CompletionTokensDetails.RejectedPredictionTokens,
		}
	}

	return usage
}

// ToBifrostLLMUsage converts Responses API token usage into the Chat
// Completions usage shape, including the per-category token detail
// breakdowns when present. Returns nil when the receiver is nil so callers
// can convert optional usage without a separate nil check.
func (ru *ResponsesResponseUsage) ToBifrostLLMUsage() *BifrostLLMUsage {
	if ru == nil {
		return nil
	}

	usage := &BifrostLLMUsage{
		PromptTokens:     ru.InputTokens,
		CompletionTokens: ru.OutputTokens,
		TotalTokens:      ru.TotalTokens,
	}

	// Some providers omit total_tokens; derive it from the parts so callers
	// always see a consistent total. (This fallback previously lived at the
	// call sites and was dropped when the conversion moved into this helper.)
	if usage.TotalTokens == 0 {
		usage.TotalTokens = ru.InputTokens + ru.OutputTokens
	}

	if ru.InputTokensDetails != nil {
		usage.PromptTokensDetails = &ChatPromptTokensDetails{
			AudioTokens:  ru.InputTokensDetails.AudioTokens,
			CachedTokens: ru.InputTokensDetails.CachedTokens,
		}
	}
	if ru.OutputTokensDetails != nil {
		usage.CompletionTokensDetails = &ChatCompletionTokensDetails{
			AcceptedPredictionTokens: ru.OutputTokensDetails.AcceptedPredictionTokens,
			AudioTokens:              ru.OutputTokensDetails.AudioTokens,
			ReasoningTokens:          ru.OutputTokensDetails.ReasoningTokens,
			RejectedPredictionTokens: ru.OutputTokensDetails.RejectedPredictionTokens,
		}
	}

	return usage
}

// =============================================================================
// REQUEST CONVERSION METHODS
// =============================================================================
Expand Down Expand Up @@ -805,15 +863,7 @@ func (cr *BifrostChatResponse) ToBifrostResponsesResponse() *BifrostResponsesRes

// Convert Usage if needed
if cr.Usage != nil {
responsesResp.Usage = &ResponsesResponseUsage{
InputTokens: cr.Usage.PromptTokens,
OutputTokens: cr.Usage.CompletionTokens,
TotalTokens: cr.Usage.TotalTokens,
}

if responsesResp.Usage.TotalTokens == 0 {
responsesResp.Usage.TotalTokens = cr.Usage.PromptTokens + cr.Usage.CompletionTokens
}
responsesResp.Usage = cr.Usage.ToResponsesResponseUsage()
}

// Copy other relevant fields
Expand Down Expand Up @@ -859,15 +909,7 @@ func (responsesResp *BifrostResponsesResponse) ToBifrostChatResponse() *BifrostC
// Convert Usage if needed
if responsesResp.Usage != nil {
// Map Responses usage to Chat usage
chatResp.Usage = &BifrostLLMUsage{
PromptTokens: responsesResp.Usage.InputTokens,
CompletionTokens: responsesResp.Usage.OutputTokens,
TotalTokens: responsesResp.Usage.TotalTokens,
}

if chatResp.Usage.TotalTokens == 0 {
chatResp.Usage.TotalTokens = chatResp.Usage.PromptTokens + chatResp.Usage.CompletionTokens
}
chatResp.Usage = responsesResp.Usage.ToBifrostLLMUsage()
}

// Copy other relevant fields
Expand Down Expand Up @@ -976,11 +1018,7 @@ func (cr *BifrostChatResponse) ToBifrostResponsesStreamResponse() *BifrostRespon
// Add usage information if present in the response
if cr.Usage != nil {
streamResp.Response = &BifrostResponsesResponse{
Usage: &ResponsesResponseUsage{
InputTokens: cr.Usage.PromptTokens,
OutputTokens: cr.Usage.CompletionTokens,
TotalTokens: cr.Usage.TotalTokens,
},
Usage: cr.Usage.ToResponsesResponseUsage(),
}
}
} else {
Expand Down
11 changes: 10 additions & 1 deletion core/schemas/providers/anthropic/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"time"

"github.com/maximhq/bifrost/core/schemas"
)

Expand Down Expand Up @@ -350,7 +351,10 @@ func (response *AnthropicMessageResponse) ToBifrostChatResponse() *schemas.Bifro
// Convert usage information
if response.Usage != nil {
bifrostResponse.Usage = &schemas.BifrostLLMUsage{
PromptTokens: response.Usage.InputTokens,
PromptTokens: response.Usage.InputTokens,
PromptTokensDetails: &schemas.ChatPromptTokensDetails{
CachedTokens: response.Usage.CacheCreationInputTokens + response.Usage.CacheReadInputTokens,
},
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
}
Expand Down Expand Up @@ -613,6 +617,11 @@ func ToAnthropicChatCompletionResponse(bifrostResp *schemas.BifrostChatResponse)
InputTokens: bifrostResp.Usage.PromptTokens,
OutputTokens: bifrostResp.Usage.CompletionTokens,
}

//NOTE: We cannot segregate between cache creation and cache read tokens, so we will use the total cached tokens as the cache read tokens
if bifrostResp.Usage.PromptTokensDetails != nil && bifrostResp.Usage.PromptTokensDetails.CachedTokens > 0 {
anthropicResp.Usage.CacheReadInputTokens = bifrostResp.Usage.PromptTokensDetails.CachedTokens
}
}

// Convert choices to content
Expand Down
5 changes: 5 additions & 0 deletions core/schemas/providers/cohere/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,11 @@ func (response *CohereChatResponse) ToBifrostChatResponse() *schemas.BifrostChat
if response.Usage.Tokens.OutputTokens != nil {
usage.CompletionTokens = int(*response.Usage.Tokens.OutputTokens)
}
if response.Usage.CachedTokens != nil {
usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{
CachedTokens: int(*response.Usage.CachedTokens),
}
}
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
}

Expand Down
Loading