Skip to content

Commit fbdd429

Browse files
feat: enhancements to token details for chat completions and added latency calculation in Vertex
1 parent b0c2195 commit fbdd429

File tree

13 files changed

+211
-57
lines changed

13 files changed

+211
-57
lines changed

core/providers/vertex.go

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"net/http"
1414
"strings"
1515
"sync"
16+
"time"
1617

1718
"golang.org/x/oauth2/google"
1819

@@ -245,6 +246,8 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas.
245246
return nil, newBifrostOperationError("error creating auth client", err, schemas.Vertex)
246247
}
247248

249+
startTime := time.Now()
250+
248251
// Make request
249252
resp, err := client.Do(req)
250253
if err != nil {
@@ -267,6 +270,8 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas.
267270
}
268271
defer resp.Body.Close()
269272

273+
latency := time.Since(startTime)
274+
270275
// Handle error response
271276
// Read response body
272277
body, err := io.ReadAll(resp.Body)
@@ -314,6 +319,7 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas.
314319
RequestType: schemas.ChatCompletionRequest,
315320
Provider: schemas.Vertex,
316321
ModelRequested: request.Model,
322+
Latency: latency.Milliseconds(),
317323
}
318324

319325
if provider.sendBackRawResponse {
@@ -322,10 +328,7 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas.
322328

323329
return response, nil
324330
} else {
325-
// Pre-allocate response structs from pools
326-
// response := acquireOpenAIResponse()
327331
response := &schemas.BifrostChatResponse{}
328-
// defer releaseOpenAIResponse(response)
329332

330333
// Use enhanced response handler with pre-allocated response
331334
rawResponse, bifrostErr := handleProviderResponse(body, response, provider.sendBackRawResponse)
@@ -336,6 +339,7 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas.
336339
response.ExtraFields.RequestType = schemas.ChatCompletionRequest
337340
response.ExtraFields.Provider = schemas.Vertex
338341
response.ExtraFields.ModelRequested = request.Model
342+
response.ExtraFields.Latency = latency.Milliseconds()
339343

340344
if provider.sendBackRawResponse {
341345
response.ExtraFields.RawResponse = rawResponse
@@ -484,22 +488,15 @@ func (provider *VertexProvider) Embedding(ctx context.Context, key schemas.Key,
484488
return nil, newConfigurationError("embedding input texts are empty", schemas.Vertex)
485489
}
486490

487-
// All Vertex AI embedding models use the same native Vertex embedding API
488-
return provider.handleVertexEmbedding(ctx, request.Model, key, reqBody, request.Params)
489-
}
490-
491-
// handleVertexEmbedding handles embedding requests using Vertex's native embedding API
492-
// This is used for all Vertex AI embedding models as they all use the same response format
493-
func (provider *VertexProvider) handleVertexEmbedding(ctx context.Context, model string, key schemas.Key, vertexReq *vertex.VertexEmbeddingRequest, params *schemas.EmbeddingParameters) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
494491
// Use the typed request directly
495-
jsonBody, err := sonic.Marshal(vertexReq)
492+
jsonBody, err := sonic.Marshal(reqBody)
496493
if err != nil {
497494
return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Vertex)
498495
}
499496

500497
// Build the native Vertex embedding API endpoint
501498
url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict",
502-
key.VertexKeyConfig.Region, key.VertexKeyConfig.ProjectID, key.VertexKeyConfig.Region, model)
499+
key.VertexKeyConfig.Region, key.VertexKeyConfig.ProjectID, key.VertexKeyConfig.Region, request.Model)
503500

504501
// Create request
505502
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody))
@@ -532,6 +529,8 @@ func (provider *VertexProvider) handleVertexEmbedding(ctx context.Context, model
532529
return nil, newBifrostOperationError("error creating auth client", err, schemas.Vertex)
533530
}
534531

532+
startTime := time.Now()
533+
535534
// Make request
536535
resp, err := client.Do(req)
537536
if err != nil {
@@ -554,6 +553,8 @@ func (provider *VertexProvider) handleVertexEmbedding(ctx context.Context, model
554553
}
555554
defer resp.Body.Close()
556555

556+
latency := time.Since(startTime)
557+
557558
// Handle error response
558559
body, err := io.ReadAll(resp.Body)
559560
if err != nil {
@@ -598,8 +599,9 @@ func (provider *VertexProvider) handleVertexEmbedding(ctx context.Context, model
598599

599600
// Set ExtraFields
600601
bifrostResponse.ExtraFields.Provider = schemas.Vertex
601-
bifrostResponse.ExtraFields.ModelRequested = model
602+
bifrostResponse.ExtraFields.ModelRequested = request.Model
602603
bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest
604+
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
603605

604606
// Set raw response if enabled
605607
if provider.sendBackRawResponse {

core/schemas/chatcompletions.go

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,21 @@ type ContentLogProb struct {
545545

546546
// BifrostLLMUsage represents token usage information
547547
type BifrostLLMUsage struct {
548-
PromptTokens int `json:"prompt_tokens,omitempty"`
549-
CompletionTokens int `json:"completion_tokens,omitempty"`
550-
TotalTokens int `json:"total_tokens"`
548+
PromptTokens int `json:"prompt_tokens,omitempty"`
549+
PromptTokensDetails *ChatPromptTokensDetails `json:"prompt_tokens_details,omitempty"`
550+
CompletionTokens int `json:"completion_tokens,omitempty"`
551+
CompletionTokensDetails *ChatCompletionTokensDetails `json:"completion_tokens_details,omitempty"`
552+
TotalTokens int `json:"total_tokens"`
553+
}
554+
555+
type ChatPromptTokensDetails struct {
556+
AudioTokens int `json:"audio_tokens,omitempty"`
557+
CachedTokens int `json:"cached_tokens,omitempty"`
558+
}
559+
560+
type ChatCompletionTokensDetails struct {
561+
AcceptedPredictionTokens int `json:"accepted_prediction_tokens,omitempty"`
562+
AudioTokens int `json:"audio_tokens,omitempty"`
563+
ReasoningTokens int `json:"reasoning_tokens,omitempty"`
564+
RejectedPredictionTokens int `json:"rejected_prediction_tokens,omitempty"`
551565
}

core/schemas/mux.go

Lines changed: 53 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,56 @@ func ToChatMessages(rms []ResponsesMessage) []ChatMessage {
615615
return chatMessages
616616
}
617617

618+
func (cu *BifrostLLMUsage) ToResponsesResponseUsage() *ResponsesResponseUsage {
619+
if cu == nil {
620+
return nil
621+
}
622+
623+
usage := &ResponsesResponseUsage{
624+
InputTokens: cu.PromptTokens,
625+
OutputTokens: cu.CompletionTokens,
626+
TotalTokens: cu.TotalTokens,
627+
}
628+
629+
if cu.PromptTokensDetails != nil {
630+
usage.InputTokensDetails = &ResponsesResponseInputTokens{
631+
CachedTokens: cu.PromptTokensDetails.CachedTokens,
632+
}
633+
}
634+
if cu.CompletionTokensDetails != nil {
635+
usage.OutputTokensDetails = &ResponsesResponseOutputTokens{
636+
ReasoningTokens: cu.CompletionTokensDetails.ReasoningTokens,
637+
}
638+
}
639+
640+
return usage
641+
}
642+
643+
func (ru *ResponsesResponseUsage) ToBifrostLLMUsage() *BifrostLLMUsage {
644+
if ru == nil {
645+
return nil
646+
}
647+
648+
usage := &BifrostLLMUsage{
649+
PromptTokens: ru.InputTokens,
650+
CompletionTokens: ru.OutputTokens,
651+
TotalTokens: ru.TotalTokens,
652+
}
653+
654+
if ru.InputTokensDetails != nil {
655+
usage.PromptTokensDetails = &ChatPromptTokensDetails{
656+
CachedTokens: ru.InputTokensDetails.CachedTokens,
657+
}
658+
}
659+
if ru.OutputTokensDetails != nil {
660+
usage.CompletionTokensDetails = &ChatCompletionTokensDetails{
661+
ReasoningTokens: ru.OutputTokensDetails.ReasoningTokens,
662+
}
663+
}
664+
665+
return usage
666+
}
667+
618668
// =============================================================================
619669
// REQUEST CONVERSION METHODS
620670
// =============================================================================
@@ -805,15 +855,7 @@ func (cr *BifrostChatResponse) ToBifrostResponsesResponse() *BifrostResponsesRes
805855

806856
// Convert Usage if needed
807857
if cr.Usage != nil {
808-
responsesResp.Usage = &ResponsesResponseUsage{
809-
InputTokens: cr.Usage.PromptTokens,
810-
OutputTokens: cr.Usage.CompletionTokens,
811-
TotalTokens: cr.Usage.TotalTokens,
812-
}
813-
814-
if responsesResp.Usage.TotalTokens == 0 {
815-
responsesResp.Usage.TotalTokens = cr.Usage.PromptTokens + cr.Usage.CompletionTokens
816-
}
858+
responsesResp.Usage = cr.Usage.ToResponsesResponseUsage()
817859
}
818860

819861
// Copy other relevant fields
@@ -859,15 +901,7 @@ func (responsesResp *BifrostResponsesResponse) ToBifrostChatResponse() *BifrostC
859901
// Convert Usage if needed
860902
if responsesResp.Usage != nil {
861903
// Map Responses usage to Chat usage
862-
chatResp.Usage = &BifrostLLMUsage{
863-
PromptTokens: responsesResp.Usage.InputTokens,
864-
CompletionTokens: responsesResp.Usage.OutputTokens,
865-
TotalTokens: responsesResp.Usage.TotalTokens,
866-
}
867-
868-
if chatResp.Usage.TotalTokens == 0 {
869-
chatResp.Usage.TotalTokens = chatResp.Usage.PromptTokens + chatResp.Usage.CompletionTokens
870-
}
904+
chatResp.Usage = responsesResp.Usage.ToBifrostLLMUsage()
871905
}
872906

873907
// Copy other relevant fields
@@ -976,11 +1010,7 @@ func (cr *BifrostChatResponse) ToBifrostResponsesStreamResponse() *BifrostRespon
9761010
// Add usage information if present in the response
9771011
if cr.Usage != nil {
9781012
streamResp.Response = &BifrostResponsesResponse{
979-
Usage: &ResponsesResponseUsage{
980-
InputTokens: cr.Usage.PromptTokens,
981-
OutputTokens: cr.Usage.CompletionTokens,
982-
TotalTokens: cr.Usage.TotalTokens,
983-
},
1013+
Usage: cr.Usage.ToResponsesResponseUsage(),
9841014
}
9851015
}
9861016
} else {

core/schemas/providers/anthropic/chat.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"encoding/json"
55
"fmt"
66
"time"
7+
78
"github.com/maximhq/bifrost/core/schemas"
89
)
910

@@ -350,7 +351,10 @@ func (response *AnthropicMessageResponse) ToBifrostChatResponse() *schemas.Bifro
350351
// Convert usage information
351352
if response.Usage != nil {
352353
bifrostResponse.Usage = &schemas.BifrostLLMUsage{
353-
PromptTokens: response.Usage.InputTokens,
354+
PromptTokens: response.Usage.InputTokens,
355+
PromptTokensDetails: &schemas.ChatPromptTokensDetails{
356+
CachedTokens: response.Usage.CacheCreationInputTokens + response.Usage.CacheReadInputTokens,
357+
},
354358
CompletionTokens: response.Usage.OutputTokens,
355359
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
356360
}
@@ -613,6 +617,11 @@ func ToAnthropicChatCompletionResponse(bifrostResp *schemas.BifrostChatResponse)
613617
InputTokens: bifrostResp.Usage.PromptTokens,
614618
OutputTokens: bifrostResp.Usage.CompletionTokens,
615619
}
620+
621+
// NOTE: We cannot distinguish between cache creation and cache read tokens, so we report the total cached tokens as the cache read tokens
622+
if bifrostResp.Usage.PromptTokensDetails != nil && bifrostResp.Usage.PromptTokensDetails.CachedTokens > 0 {
623+
anthropicResp.Usage.CacheReadInputTokens = bifrostResp.Usage.PromptTokensDetails.CachedTokens
624+
}
616625
}
617626

618627
// Convert choices to content

core/schemas/providers/cohere/chat.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,11 @@ func (response *CohereChatResponse) ToBifrostChatResponse() *schemas.BifrostChat
298298
if response.Usage.Tokens.OutputTokens != nil {
299299
usage.CompletionTokens = int(*response.Usage.Tokens.OutputTokens)
300300
}
301+
if response.Usage.CachedTokens != nil {
302+
usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{
303+
CachedTokens: int(*response.Usage.CachedTokens),
304+
}
305+
}
301306
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
302307
}
303308

core/schemas/providers/cohere/types.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -500,14 +500,13 @@ func (c *CohereStreamCitationStruct) UnmarshalJSON(data []byte) error {
500500
return fmt.Errorf("citations field is neither array nor object")
501501
}
502502

503-
504503
// CohereStreamMessage represents the message part of streaming deltas
505504
type CohereStreamMessage struct {
506-
Role *string `json:"role,omitempty"` // For message-start
507-
Content *CohereStreamContentStruct `json:"content,omitempty"` // For content events (object)
508-
ToolPlan *string `json:"tool_plan,omitempty"` // For tool-plan-delta
509-
ToolCalls *CohereStreamToolCallStruct `json:"tool_calls,omitempty"` // For tool-call events (flexible)
510-
Citations *CohereStreamCitationStruct `json:"citations,omitempty"` // For citation events
505+
Role *string `json:"role,omitempty"` // For message-start
506+
Content *CohereStreamContentStruct `json:"content,omitempty"` // For content events (object)
507+
ToolPlan *string `json:"tool_plan,omitempty"` // For tool-plan-delta
508+
ToolCalls *CohereStreamToolCallStruct `json:"tool_calls,omitempty"` // For tool-call events (flexible)
509+
Citations *CohereStreamCitationStruct `json:"citations,omitempty"` // For citation events
511510
}
512511

513512
// CohereStreamContent represents content in streaming events

core/schemas/providers/gemini/chat.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ func (response *GenerateContentResponse) ToBifrostChatResponse() *schemas.Bifros
335335
}
336336

337337
// Extract usage metadata
338-
inputTokens, outputTokens, totalTokens := response.extractUsageMetadata()
338+
inputTokens, outputTokens, totalTokens, cachedTokens, reasoningTokens := response.extractUsageMetadata()
339339

340340
// Process candidates to extract text content
341341
if len(response.Candidates) > 0 {
@@ -380,6 +380,12 @@ func (response *GenerateContentResponse) ToBifrostChatResponse() *schemas.Bifros
380380
PromptTokens: inputTokens,
381381
CompletionTokens: outputTokens,
382382
TotalTokens: totalTokens,
383+
PromptTokensDetails: &schemas.ChatPromptTokensDetails{
384+
CachedTokens: cachedTokens,
385+
},
386+
CompletionTokensDetails: &schemas.ChatCompletionTokensDetails{
387+
ReasoningTokens: reasoningTokens,
388+
},
383389
}
384390

385391
return bifrostResp
@@ -469,6 +475,12 @@ func ToGeminiChatResponse(bifrostResp *schemas.BifrostChatResponse) *GenerateCon
469475
CandidatesTokenCount: int32(bifrostResp.Usage.CompletionTokens),
470476
TotalTokenCount: int32(bifrostResp.Usage.TotalTokens),
471477
}
478+
if bifrostResp.Usage.PromptTokensDetails != nil {
479+
genaiResp.UsageMetadata.CachedContentTokenCount = int32(bifrostResp.Usage.PromptTokensDetails.CachedTokens)
480+
}
481+
if bifrostResp.Usage.CompletionTokensDetails != nil {
482+
genaiResp.UsageMetadata.ThoughtsTokenCount = int32(bifrostResp.Usage.CompletionTokensDetails.ReasoningTokens)
483+
}
472484
}
473485

474486
return genaiResp

core/schemas/providers/gemini/transcription.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ func (response *GenerateContentResponse) ToBifrostTranscriptionResponse() *schem
7777
bifrostResp := &schemas.BifrostTranscriptionResponse{}
7878

7979
// Extract usage metadata
80-
inputTokens, outputTokens, totalTokens := response.extractUsageMetadata()
80+
inputTokens, outputTokens, totalTokens, _, _ := response.extractUsageMetadata()
8181

8282
// Process candidates to extract text content
8383
if len(response.Candidates) > 0 {

core/schemas/providers/gemini/utils.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,14 +153,16 @@ func ensureExtraParams(bifrostReq *schemas.BifrostChatRequest) {
153153
}
154154

155155
// extractUsageMetadata extracts usage metadata from the Gemini response
156-
func (r *GenerateContentResponse) extractUsageMetadata() (int, int, int) {
157-
var inputTokens, outputTokens, totalTokens int
156+
func (r *GenerateContentResponse) extractUsageMetadata() (int, int, int, int, int) {
157+
var inputTokens, outputTokens, totalTokens, cachedTokens, reasoningTokens int
158158
if r.UsageMetadata != nil {
159159
inputTokens = int(r.UsageMetadata.PromptTokenCount)
160160
outputTokens = int(r.UsageMetadata.CandidatesTokenCount)
161161
totalTokens = int(r.UsageMetadata.TotalTokenCount)
162+
cachedTokens = int(r.UsageMetadata.CachedContentTokenCount)
163+
reasoningTokens = int(r.UsageMetadata.ThoughtsTokenCount)
162164
}
163-
return inputTokens, outputTokens, totalTokens
165+
return inputTokens, outputTokens, totalTokens, cachedTokens, reasoningTokens
164166
}
165167

166168
// convertParamsToGenerationConfig converts Bifrost parameters to Gemini GenerationConfig

framework/changelog.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
<!-- The pattern we follow here is to keep the changelog for the latest version -->
22
<!-- Old changelogs are automatically attached to the GitHub releases -->
33

4-
- chore: version update core to 1.2.12
4+
- chore: version update core to 1.2.12
5+
- feat: added support for vertex provider/model format in pricing lookup

0 commit comments

Comments
 (0)