diff --git a/.github/workflows/scripts/release-core.sh b/.github/workflows/scripts/release-core.sh index ed612fbef..a1027e309 100755 --- a/.github/workflows/scripts/release-core.sh +++ b/.github/workflows/scripts/release-core.sh @@ -32,10 +32,14 @@ fi # Building core go mod download go build ./... -go test ./... cd .. echo "βœ… Core build validation successful" +# Run core provider tests +echo "πŸ”§ Running core provider tests..." +cd tests/core-providers +go test -v ./... +cd ../.. # Capturing changelog CHANGELOG_BODY=$(cat core/changelog.md) diff --git a/.gitignore b/.gitignore index 09ff78d06..994ddf71c 100644 --- a/.gitignore +++ b/.gitignore @@ -33,4 +33,6 @@ go.work.sum # Sqlite DBs *.db *.db-shm -*.db-wal \ No newline at end of file +*.db-wal + +.claude \ No newline at end of file diff --git a/core/bifrost.go b/core/bifrost.go index 88e239433..66a835130 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -1229,7 +1229,11 @@ func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostR primaryResult, primaryErr := bifrost.tryRequest(ctx, req) if primaryErr != nil { - bifrost.logger.Debug(fmt.Sprintf("Primary provider %s with model %s returned error: %v", provider, model, primaryErr)) + if primaryErr.Error != nil { + bifrost.logger.Debug(fmt.Sprintf("Primary provider %s with model %s returned error: %s", provider, model, primaryErr.Error.Message)) + } else { + bifrost.logger.Debug(fmt.Sprintf("Primary provider %s with model %s returned error: %v", provider, model, primaryErr)) + } if len(fallbacks) > 0 { bifrost.logger.Debug(fmt.Sprintf("Check if we should try %d fallbacks", len(fallbacks))) } @@ -1629,7 +1633,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas time.Sleep(backoff) } - bifrost.logger.Debug("attempting request for provider %s", provider.GetProviderKey()) + bifrost.logger.Debug("attempting %s request for provider %s", req.RequestType, provider.GetProviderKey()) // Attempt the request if 
IsStreamRequestType(req.RequestType) { @@ -1644,7 +1648,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas } } - bifrost.logger.Debug("request for provider %s completed", provider.GetProviderKey()) + bifrost.logger.Debug("request %s for provider %s completed", req.RequestType, provider.GetProviderKey()) // Check if successful or if we should retry if bifrostError == nil || diff --git a/core/changelog.md b/core/changelog.md index a3dd9a0d1..184a81a5c 100644 --- a/core/changelog.md +++ b/core/changelog.md @@ -1,6 +1,9 @@ -- fix: openai specific parameters filtered for openai compatibile providers -- fix: error response unmarshalling for gemini provider -- BREAKING FIX: json_schema field correctly renamed to schema; ResponsesTextConfigFormatJSONSchema restructured \ No newline at end of file +- bug: fixed embedding request not being handled in `GetExtraFields()` method of `BifrostResponse` +- fix: added latency calculation for vertex native requests +- feat: added cached tokens and reasoning tokens to the usage metadata for chat completions +- feat: added global region support for vertex API +- fix: added filter for extra fields in chat completions request for Mistral provider +- fix: fixed ResponsesComputerToolCallPendingSafetyCheck code field \ No newline at end of file diff --git a/core/providers/gemini.go b/core/providers/gemini.go index 04783efc6..fb0463782 100644 --- a/core/providers/gemini.go +++ b/core/providers/gemini.go @@ -381,8 +381,8 @@ func (provider *GeminiProvider) SpeechStream(ctx context.Context, postHookRunner scanner := bufio.NewScanner(resp.Body) // Increase buffer size to handle large chunks (especially for audio data) - buf := make([]byte, 0, 256*1024) // 256KB buffer - scanner.Buffer(buf, 1024*1024) // Allow up to 1MB tokens + buf := make([]byte, 0, 1024*1024) // 1MB initial buffer + scanner.Buffer(buf, 10*1024*1024) // Allow up to 10MB tokens chunkIndex := -1 usage := &schemas.SpeechUsage{} startTime := 
time.Now() @@ -658,6 +658,9 @@ func (provider *GeminiProvider) TranscriptionStream(ctx context.Context, postHoo defer resp.Body.Close() scanner := bufio.NewScanner(resp.Body) + // Increase buffer size to handle large chunks (especially for audio data) + buf := make([]byte, 0, 1024*1024) // 1MB initial buffer + scanner.Buffer(buf, 10*1024*1024) // Allow up to 10MB tokens chunkIndex := -1 usage := &schemas.TranscriptionUsage{} startTime := time.Now() @@ -674,8 +677,8 @@ func (provider *GeminiProvider) TranscriptionStream(ctx context.Context, postHoo } var jsonData string // Parse SSE data - if strings.HasPrefix(line, "data: ") { - jsonData = strings.TrimPrefix(line, "data: ") + if after, ok := strings.CutPrefix(line, "data: "); ok { + jsonData = after } else { // Handle raw JSON errors (without "data: " prefix) jsonData = line diff --git a/core/providers/groq.go b/core/providers/groq.go index e2bfa5c8d..368fa1fbe 100644 --- a/core/providers/groq.go +++ b/core/providers/groq.go @@ -124,13 +124,18 @@ func (provider *GroqProvider) TextCompletionStream(ctx context.Context, postHook responseChan <- response continue } - response.ToTextCompletionResponse() - if response.BifrostTextCompletionResponse != nil { - response.BifrostTextCompletionResponse.ExtraFields.RequestType = schemas.TextCompletionRequest - response.BifrostTextCompletionResponse.ExtraFields.Provider = provider.GetProviderKey() - response.BifrostTextCompletionResponse.ExtraFields.ModelRequested = request.Model + if response.BifrostChatResponse != nil { + textCompletionResponse := response.BifrostChatResponse.ToTextCompletionResponse() + if textCompletionResponse != nil { + textCompletionResponse.ExtraFields.RequestType = schemas.TextCompletionRequest + textCompletionResponse.ExtraFields.Provider = provider.GetProviderKey() + textCompletionResponse.ExtraFields.ModelRequested = request.Model + + responseChan <- &schemas.BifrostStream{ + BifrostTextCompletionResponse: textCompletionResponse, + } + } } - 
responseChan <- response } }() return responseChan, nil diff --git a/core/providers/vertex.go b/core/providers/vertex.go index 015dcd7e6..0edb91b17 100644 --- a/core/providers/vertex.go +++ b/core/providers/vertex.go @@ -203,10 +203,19 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas. return nil, newConfigurationError("region is not set in key config", schemas.Vertex) } - url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", region, projectID, region) - + var url string if strings.Contains(request.Model, "claude") { - url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, request.Model) + if region == "global" { + url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:rawPredict", projectID, request.Model) + } else { + url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, request.Model) + } + } else { + if region == "global" { + url = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/openapi/chat/completions", projectID) + } else { + url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", region, projectID, region) + } } // Create request @@ -286,12 +295,19 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas. 
} var openAIErr schemas.BifrostError - var vertexErr []VertexError + var vertexErr []VertexError if err := sonic.Unmarshal(body, &openAIErr); err != nil { // Try Vertex error format if OpenAI format fails if err := sonic.Unmarshal(body, &vertexErr); err != nil { - return nil, newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.Vertex) + + //try with single Vertex error format + var vertexErr VertexError + if err := sonic.Unmarshal(body, &vertexErr); err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.Vertex) + } + + return nil, newProviderAPIError(vertexErr.Error.Message, nil, resp.StatusCode, schemas.Vertex, nil, nil) } if len(vertexErr) > 0 { @@ -395,7 +411,12 @@ func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHo delete(requestBody, "model") delete(requestBody, "region") - url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, request.Model) + var url string + if region == "global" { + url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:streamRawPredict", projectID, request.Model) + } else { + url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, request.Model) + } // Prepare headers for Vertex Anthropic headers := map[string]string{ @@ -418,7 +439,12 @@ func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHo provider.logger, ) } else { - url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", region, projectID, region) + var url string + if region == "global" { + url = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/openapi/chat/completions", 
projectID) + } else { + url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", region, projectID, region) + } authHeader := map[string]string{} if key.Value != "" { authHeader["Authorization"] = "Bearer " + key.Value diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index da53de874..e639a0182 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -232,6 +232,8 @@ func (r *BifrostResponse) GetExtraFields() *BifrostResponseExtraFields { return &r.ResponsesResponse.ExtraFields case r.ResponsesStreamResponse != nil: return &r.ResponsesStreamResponse.ExtraFields + case r.EmbeddingResponse != nil: + return &r.EmbeddingResponse.ExtraFields case r.SpeechResponse != nil: return &r.SpeechResponse.ExtraFields case r.SpeechStreamResponse != nil: diff --git a/core/schemas/providers/gemini/chat.go b/core/schemas/providers/gemini/chat.go index f43d5c2be..42ee936e4 100644 --- a/core/schemas/providers/gemini/chat.go +++ b/core/schemas/providers/gemini/chat.go @@ -31,10 +31,10 @@ func (request *GeminiGenerationRequest) ToBifrostChatRequest() *schemas.BifrostC allGenAiMessages := []Content{} if request.SystemInstruction != nil { - allGenAiMessages = append(allGenAiMessages, request.SystemInstruction.ToGenAIContent()) + allGenAiMessages = append(allGenAiMessages, *request.SystemInstruction) } for _, content := range request.Contents { - allGenAiMessages = append(allGenAiMessages, content.ToGenAIContent()) + allGenAiMessages = append(allGenAiMessages, content) } for _, content := range allGenAiMessages { diff --git a/core/schemas/providers/gemini/embedding.go b/core/schemas/providers/gemini/embedding.go index f258b1a65..69e80a6fa 100644 --- a/core/schemas/providers/gemini/embedding.go +++ b/core/schemas/providers/gemini/embedding.go @@ -26,8 +26,8 @@ func ToGeminiEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *Gemi // Create the Gemini embedding request request := 
&GeminiEmbeddingRequest{ Model: bifrostReq.Model, - Content: &CustomContent{ - Parts: []*CustomPart{ + Content: &Content{ + Parts: []*Part{ { Text: text, }, diff --git a/core/schemas/providers/gemini/responses.go b/core/schemas/providers/gemini/responses.go index f0a3cecf9..8f7940ecb 100644 --- a/core/schemas/providers/gemini/responses.go +++ b/core/schemas/providers/gemini/responses.go @@ -482,21 +482,21 @@ func convertPropertyToGeminiSchema(prop interface{}) *Schema { } // convertResponsesMessagesToGeminiContents converts Responses messages to Gemini contents -func convertResponsesMessagesToGeminiContents(messages []schemas.ResponsesMessage) ([]CustomContent, *CustomContent, error) { - var contents []CustomContent - var systemInstruction *CustomContent +func convertResponsesMessagesToGeminiContents(messages []schemas.ResponsesMessage) ([]Content, *Content, error) { + var contents []Content + var systemInstruction *Content for _, msg := range messages { // Handle system messages separately if msg.Role != nil && *msg.Role == schemas.ResponsesInputMessageRoleSystem { if systemInstruction == nil { - systemInstruction = &CustomContent{} + systemInstruction = &Content{} } // Convert system message content if msg.Content != nil { if msg.Content.ContentStr != nil { - systemInstruction.Parts = append(systemInstruction.Parts, &CustomPart{ + systemInstruction.Parts = append(systemInstruction.Parts, &Part{ Text: *msg.Content.ContentStr, }) } @@ -517,7 +517,7 @@ func convertResponsesMessagesToGeminiContents(messages []schemas.ResponsesMessag } // Handle regular messages - content := CustomContent{} + content := Content{} if msg.Role != nil { content.Role = string(*msg.Role) @@ -528,7 +528,7 @@ func convertResponsesMessagesToGeminiContents(messages []schemas.ResponsesMessag // Convert message content if msg.Content != nil { if msg.Content.ContentStr != nil { - content.Parts = append(content.Parts, &CustomPart{ + content.Parts = append(content.Parts, &Part{ Text: 
*msg.Content.ContentStr, }) } @@ -559,7 +559,7 @@ func convertResponsesMessagesToGeminiContents(messages []schemas.ResponsesMessag } } - part := &CustomPart{ + part := &Part{ FunctionCall: &FunctionCall{ Name: *msg.ResponsesToolMessage.Name, Args: argsMap, @@ -586,7 +586,7 @@ func convertResponsesMessagesToGeminiContents(messages []schemas.ResponsesMessag funcName = *msg.ResponsesToolMessage.CallID } - part := &CustomPart{ + part := &Part{ FunctionResponse: &FunctionResponse{ Name: funcName, Response: responseMap, @@ -608,11 +608,11 @@ func convertResponsesMessagesToGeminiContents(messages []schemas.ResponsesMessag } // convertContentBlockToGeminiPart converts a content block to Gemini part -func convertContentBlockToGeminiPart(block schemas.ResponsesMessageContentBlock) (*CustomPart, error) { +func convertContentBlockToGeminiPart(block schemas.ResponsesMessageContentBlock) (*Part, error) { switch block.Type { case schemas.ResponsesInputMessageContentBlockTypeText: if block.Text != nil { - return &CustomPart{ + return &Part{ Text: *block.Text, }, nil } @@ -645,14 +645,14 @@ func convertContentBlockToGeminiPart(block schemas.ResponsesMessageContentBlock) return nil, fmt.Errorf("failed to decode base64 image data: %w", err) } - return &CustomPart{ - InlineData: &CustomBlob{ + return &Part{ + InlineData: &Blob{ MIMEType: mimeType, Data: decodedData, }, }, nil } else { - return &CustomPart{ + return &Part{ FileData: &FileData{ MIMEType: mimeType, FileURI: sanitizedURL, @@ -669,8 +669,8 @@ func convertContentBlockToGeminiPart(block schemas.ResponsesMessageContentBlock) return nil, fmt.Errorf("failed to decode base64 audio data: %w", err) } - return &CustomPart{ - InlineData: &CustomBlob{ + return &Part{ + InlineData: &Blob{ MIMEType: func() string { f := strings.ToLower(strings.TrimSpace(block.Audio.Format)) if f == "" { @@ -689,15 +689,15 @@ func convertContentBlockToGeminiPart(block schemas.ResponsesMessageContentBlock) case 
schemas.ResponsesInputMessageContentBlockTypeFile: if block.ResponsesInputMessageContentBlockFile != nil { if block.ResponsesInputMessageContentBlockFile.FileURL != nil { - return &CustomPart{ + return &Part{ FileData: &FileData{ MIMEType: "application/octet-stream", // default FileURI: *block.ResponsesInputMessageContentBlockFile.FileURL, }, }, nil } else if block.ResponsesInputMessageContentBlockFile.FileData != nil { - return &CustomPart{ - InlineData: &CustomBlob{ + return &Part{ + InlineData: &Blob{ MIMEType: "application/octet-stream", // default Data: []byte(*block.ResponsesInputMessageContentBlockFile.FileData), }, diff --git a/core/schemas/providers/gemini/speech.go b/core/schemas/providers/gemini/speech.go index e0be68e00..0ef9f7f62 100644 --- a/core/schemas/providers/gemini/speech.go +++ b/core/schemas/providers/gemini/speech.go @@ -27,9 +27,9 @@ func ToGeminiSpeechRequest(bifrostReq *schemas.BifrostSpeechRequest, responseMod // Convert speech input to Gemini format if bifrostReq.Input.Input != "" { - geminiReq.Contents = []CustomContent{ + geminiReq.Contents = []Content{ { - Parts: []*CustomPart{ + Parts: []*Part{ { Text: bifrostReq.Input.Input, }, diff --git a/core/schemas/providers/gemini/transcription.go b/core/schemas/providers/gemini/transcription.go index 7df9c6905..b627a4f20 100644 --- a/core/schemas/providers/gemini/transcription.go +++ b/core/schemas/providers/gemini/transcription.go @@ -47,7 +47,7 @@ func ToGeminiTranscriptionRequest(bifrostReq *schemas.BifrostTranscriptionReques } // Create parts for the transcription request - parts := []*CustomPart{ + parts := []*Part{ { Text: prompt, }, @@ -55,15 +55,15 @@ func ToGeminiTranscriptionRequest(bifrostReq *schemas.BifrostTranscriptionReques // Add audio file if present if len(bifrostReq.Input.File) > 0 { - parts = append(parts, &CustomPart{ - InlineData: &CustomBlob{ + parts = append(parts, &Part{ + InlineData: &Blob{ MIMEType: detectAudioMimeType(bifrostReq.Input.File), Data: 
bifrostReq.Input.File, }, }) } - geminiReq.Contents = []CustomContent{ + geminiReq.Contents = []Content{ { Parts: parts, }, diff --git a/core/schemas/providers/gemini/types.go b/core/schemas/providers/gemini/types.go index f97a3ea56..3ac9e0580 100644 --- a/core/schemas/providers/gemini/types.go +++ b/core/schemas/providers/gemini/types.go @@ -1,11 +1,8 @@ package gemini import ( - "encoding/base64" "encoding/json" - "fmt" "reflect" - "strings" "time" ) @@ -53,9 +50,9 @@ const ( type GeminiGenerationRequest struct { Model string `json:"model,omitempty"` // Model field for explicit model specification - Contents []CustomContent `json:"contents,omitempty"` // For chat completion requests + Contents []Content `json:"contents,omitempty"` // For chat completion requests Requests []GeminiEmbeddingRequest `json:"requests,omitempty"` // For batch embedding requests - SystemInstruction *CustomContent `json:"systemInstruction,omitempty"` + SystemInstruction *Content `json:"systemInstruction,omitempty"` GenerationConfig GenerationConfig `json:"generationConfig,omitempty"` SafetySettings []SafetySetting `json:"safetySettings,omitempty"` Tools []Tool `json:"tools,omitempty"` @@ -860,89 +857,11 @@ type GenerationConfigThinkingConfig struct { // EmbeddingRequest represents a single embedding request in a batch type GeminiEmbeddingRequest struct { - Content *CustomContent `json:"content,omitempty"` - TaskType *string `json:"taskType,omitempty"` - Title *string `json:"title,omitempty"` - OutputDimensionality *int `json:"outputDimensionality,omitempty"` - Model string `json:"model,omitempty"` -} - -// CustomBlob handles URL-safe base64 decoding for Google GenAI requests -type CustomBlob struct { - Data []byte `json:"data,omitempty"` - MIMEType string `json:"mimeType,omitempty"` -} - -// UnmarshalJSON custom unmarshalling to handle URL-safe base64 encoding -func (b *CustomBlob) UnmarshalJSON(data []byte) error { - // First unmarshal into a temporary struct with string data - var temp 
struct { - Data string `json:"data,omitempty"` - MIMEType string `json:"mimeType,omitempty"` - } - - if err := json.Unmarshal(data, &temp); err != nil { - return err - } - - b.MIMEType = temp.MIMEType - - if temp.Data != "" { - // Convert URL-safe base64 to standard base64 - standardBase64 := strings.ReplaceAll(strings.ReplaceAll(temp.Data, "_", "/"), "-", "+") - - // Add padding if necessary - switch len(standardBase64) % 4 { - case 2: - standardBase64 += "==" - case 3: - standardBase64 += "=" - } - - decoded, err := base64.StdEncoding.DecodeString(standardBase64) - if err != nil { - return fmt.Errorf("failed to decode base64 data: %v", err) - } - b.Data = decoded - } - - return nil -} - -// CustomPart handles Google GenAI Part with custom Blob unmarshalling -type CustomPart struct { - VideoMetadata *VideoMetadata `json:"videoMetadata,omitempty"` - Thought bool `json:"thought,omitempty"` - CodeExecutionResult *CodeExecutionResult `json:"codeExecutionResult,omitempty"` - ExecutableCode *ExecutableCode `json:"executableCode,omitempty"` - FileData *FileData `json:"fileData,omitempty"` - FunctionCall *FunctionCall `json:"functionCall,omitempty"` - FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"` - InlineData *CustomBlob `json:"inlineData,omitempty"` - Text string `json:"text,omitempty"` -} - -// ToGenAIPart converts CustomPart to Part -func (p *CustomPart) ToGenAIPart() *Part { - part := &Part{ - VideoMetadata: p.VideoMetadata, - Thought: p.Thought, - CodeExecutionResult: p.CodeExecutionResult, - ExecutableCode: p.ExecutableCode, - FileData: p.FileData, - FunctionCall: p.FunctionCall, - FunctionResponse: p.FunctionResponse, - Text: p.Text, - } - - if p.InlineData != nil { - part.InlineData = &Blob{ - Data: p.InlineData.Data, - MIMEType: p.InlineData.MIMEType, - } - } - - return part + Content *Content `json:"content,omitempty"` + TaskType *string `json:"taskType,omitempty"` + Title *string `json:"title,omitempty"` + OutputDimensionality *int 
`json:"outputDimensionality,omitempty"` + Model string `json:"model,omitempty"` } // Contains the multi-part content of a message. @@ -956,25 +875,6 @@ type Content struct { Role string `json:"role,omitempty"` } -// CustomContent handles Google GenAI Content with custom Part unmarshalling -type CustomContent struct { - Parts []*CustomPart `json:"parts,omitempty"` - Role string `json:"role,omitempty"` -} - -// ToGenAIContent converts CustomContent to genai_sdk.Content -func (c *CustomContent) ToGenAIContent() Content { - parts := make([]*Part, len(c.Parts)) - for i, part := range c.Parts { - parts[i] = part.ToGenAIPart() - } - - return Content{ - Parts: parts, - Role: c.Role, - } -} - // A datatype containing media content. // Exactly one field within a Part should be set, representing the specific type // of content being conveyed. Using multiple fields within the same `Part` diff --git a/core/schemas/providers/gemini/utils.go b/core/schemas/providers/gemini/utils.go index e6078c2ae..8af92a7e7 100644 --- a/core/schemas/providers/gemini/utils.go +++ b/core/schemas/providers/gemini/utils.go @@ -352,21 +352,21 @@ func addSpeechConfigToGenerationConfig(config *GenerationConfig, voiceConfig *sc } // convertBifrostMessagesToGemini converts Bifrost messages to Gemini format -func convertBifrostMessagesToGemini(messages []schemas.ChatMessage) []CustomContent { - var contents []CustomContent +func convertBifrostMessagesToGemini(messages []schemas.ChatMessage) []Content { + var contents []Content for _, message := range messages { - var parts []*CustomPart + var parts []*Part // Handle content if message.Content.ContentStr != nil && *message.Content.ContentStr != "" { - parts = append(parts, &CustomPart{ + parts = append(parts, &Part{ Text: *message.Content.ContentStr, }) } else if message.Content.ContentBlocks != nil { for _, block := range message.Content.ContentBlocks { if block.Text != nil { - parts = append(parts, &CustomPart{ + parts = append(parts, &Part{ Text: 
*block.Text, }) } @@ -389,7 +389,7 @@ func convertBifrostMessagesToGemini(messages []schemas.ChatMessage) []CustomCont if toolCall.ID != nil && strings.TrimSpace(*toolCall.ID) != "" { callID = *toolCall.ID } - parts = append(parts, &CustomPart{ + parts = append(parts, &Part{ FunctionCall: &FunctionCall{ ID: callID, Name: *toolCall.Function.Name, @@ -442,7 +442,7 @@ func convertBifrostMessagesToGemini(messages []schemas.ChatMessage) []CustomCont callID = *message.ChatToolMessage.ToolCallID } - parts = append(parts, &CustomPart{ + parts = append(parts, &Part{ FunctionResponse: &FunctionResponse{ ID: callID, Name: callID, // Gemini uses name for correlation @@ -452,7 +452,7 @@ func convertBifrostMessagesToGemini(messages []schemas.ChatMessage) []CustomCont } if len(parts) > 0 { - content := CustomContent{ + content := Content{ Parts: parts, Role: string(message.Role), } diff --git a/core/schemas/providers/openai/chat.go b/core/schemas/providers/openai/chat.go index edd4f0a8b..203878416 100644 --- a/core/schemas/providers/openai/chat.go +++ b/core/schemas/providers/openai/chat.go @@ -39,6 +39,21 @@ func ToOpenAIChatRequest(bifrostReq *schemas.BifrostChatRequest) *OpenAIChatRequ // Removing extra parameters that are not supported by Gemini openaiReq.ServiceTier = nil return openaiReq + case schemas.Mistral: + openaiReq.filterOpenAISpecificParameters() + + // Remove max_completion_tokens and replace with max_tokens + if openaiReq.MaxCompletionTokens != nil { + openaiReq.MaxTokens = openaiReq.MaxCompletionTokens + openaiReq.MaxCompletionTokens = nil + } + + // Mistral does not support ToolChoiceStruct, only simple tool choice strings are supported. 
+ if openaiReq.ToolChoice != nil && openaiReq.ToolChoice.ChatToolChoiceStruct != nil { + openaiReq.ToolChoice.ChatToolChoiceStr = schemas.Ptr("required") + openaiReq.ToolChoice.ChatToolChoiceStruct = nil + } + return openaiReq default: openaiReq.filterOpenAISpecificParameters() return openaiReq diff --git a/core/schemas/providers/openai/types.go b/core/schemas/providers/openai/types.go index 7ad13b276..5fe1ba647 100644 --- a/core/schemas/providers/openai/types.go +++ b/core/schemas/providers/openai/types.go @@ -38,6 +38,10 @@ type OpenAIChatRequest struct { schemas.ChatParameters Stream *bool `json:"stream,omitempty"` + + //NOTE: MaxCompletionTokens is a new replacement for max_tokens but some providers still use max_tokens. + // This Field is populated only for such providers and is NOT to be used externally. + MaxTokens *int `json:"max_tokens,omitempty"` } // IsStreamingRequested implements the StreamingRequest interface diff --git a/core/schemas/responses.go b/core/schemas/responses.go index f64a27e82..46fe63816 100644 --- a/core/schemas/responses.go +++ b/core/schemas/responses.go @@ -567,7 +567,7 @@ type ResponsesComputerToolCall struct { // ResponsesComputerToolCallPendingSafetyCheck represents a pending safety check type ResponsesComputerToolCallPendingSafetyCheck struct { ID string `json:"id"` - Context string `json:"context"` + Code string `json:"code"` Message string `json:"message"` } diff --git a/core/version b/core/version index f2ae0b4a2..0b1f1edf1 100644 --- a/core/version +++ b/core/version @@ -1 +1 @@ -1.2.12 +1.2.13 diff --git a/docs/architecture/core/concurrency.mdx b/docs/architecture/core/concurrency.mdx index 660390628..83c9aa806 100644 --- a/docs/architecture/core/concurrency.mdx +++ b/docs/architecture/core/concurrency.mdx @@ -268,29 +268,32 @@ The backpressure system protects Bifrost from being overwhelmed while maintainin ### **Thread-Safe Object Pools** ```mermaid -graph TB - subgraph "sync.Pool Architecture" - GetObject[Get Object
sync.Pool.Get()] +graph TD + subgraph "sync.Pool Lifecycle" + direction LR + GetObject[Get Object
sync.Pool.Get] + PoolCheck{Is Pool Empty?} NewObject[New Object
Factory Function] UseObject[Use Object
Application Logic] ResetObject[Reset Object
Clear State] - ReturnObject[Return Object
sync.Pool.Put()] + ReturnObject[Return Object
sync.Pool.Put] + + GetObject --> PoolCheck + PoolCheck -- Yes --> NewObject + PoolCheck -- No --> UseObject + NewObject --> UseObject + UseObject --> ResetObject + ResetObject --> ReturnObject + ReturnObject --> GetObject end - subgraph "GC Integration" + subgraph "GC Interaction" + direction TB GCRun[GC Runs] - PoolCleanup[Pool Cleanup
Automatic] - Reallocation[Object Reallocation
as Needed] + PoolCleanup[Pool Cleanup
Removes idle objects] + + GCRun --> PoolCleanup end - - GetObject --> NewObject - NewObject --> UseObject - UseObject --> ResetObject - ResetObject --> ReturnObject - ReturnObject --> GetObject - - GCRun --> PoolCleanup - PoolCleanup --> Reallocation ``` **Thread-Safe Pool Architecture:** diff --git a/docs/features/plugins/mocker.mdx b/docs/features/plugins/mocker.mdx index a6c6ddebe..ebbc949e4 100644 --- a/docs/features/plugins/mocker.mdx +++ b/docs/features/plugins/mocker.mdx @@ -39,8 +39,10 @@ func main() { } defer client.Shutdown() - // All requests will now return: "This is a mock response from the Mocker plugin" - response, _ := client.ChatCompletionRequest(context.Background(), &schemas.BifrostChatRequest{ + // All chat and responses requests will now return: "This is a mock response from the Mocker plugin" + + // Chat completion request + chatResponse, _ := client.ChatCompletionRequest(context.Background(), &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4", Input: []schemas.ChatMessage{ @@ -52,6 +54,20 @@ func main() { }, }, }) + + // Responses request + responsesResponse, _ := client.ResponsesRequest(context.Background(), &schemas.BifrostResponsesRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o", + Input: []schemas.ResponsesMessage{ + { + Role: bifrost.Ptr(schemas.ResponsesInputMessageRoleUser), + Content: &schemas.ResponsesMessageContent{ + ContentStr: bifrost.Ptr("Hello!"), + }, + }, + }, + }) } ``` @@ -86,6 +102,30 @@ plugin, err := mocker.NewMockerPlugin(mocker.MockerConfig{ }) ``` +### Responses Request Example + +The mocker plugin automatically handles both chat completion and responses requests with the same configuration: + +```go +// This rule will work for both ChatCompletionRequest and ResponsesRequest +{ + Name: "universal-mock", + Enabled: true, + Probability: 1.0, + Conditions: mocker.Conditions{ + MessageRegex: stringPtr("(?i).*hello.*"), + }, + Responses: []mocker.Response{ + { + Type: mocker.ResponseTypeSuccess, + 
Content: &mocker.SuccessResponse{ + Message: "Hello! I'm a mock response that works for both request types.", + }, + }, + }, +} +``` + ## Installation Add the plugin to your project: @@ -137,6 +177,29 @@ config := mocker.MockerConfig{ } ``` +## Supported Request Types + +The Mocker plugin supports the following Bifrost request types: + +- **Chat Completion Requests** (`ChatCompletionRequest`) - Standard chat-based interactions +- **Responses Requests** (`ResponsesRequest`) - OpenAI-compatible responses API format +- **Skip Context Key** - Use `"skip-mocker"` context key to bypass mocking per request + +### Skip Mocker for Specific Requests + +You can skip the mocker plugin for specific requests by adding a context key: + +```go +import "github.com/maximhq/bifrost/core/schemas" + +// Create context that skips mocker +ctx := context.WithValue(context.Background(), + schemas.BifrostContextKey("skip-mocker"), true) + +// This request will bypass the mocker and go to the real provider +response, err := client.ChatCompletionRequest(ctx, request) +``` + ## Key Features ### Template Variables @@ -469,6 +532,27 @@ Response{ } ``` +### Skip Mocker Not Working + +Ensure you're using the correct context key format: + +```go +// βœ… Correct +ctx := context.WithValue(context.Background(), + schemas.BifrostContextKey("skip-mocker"), true) + +// ❌ Wrong +ctx := context.WithValue(context.Background(), "skip-mocker", true) +``` + +### Responses Request Issues + +If responses requests aren't being mocked: + +1. Verify the plugin supports `ResponsesRequest` (version 1.2.13+) +2. Check that your regex patterns match the message content +3. 
Ensure the request type is `schemas.ResponsesRequest` + ### Debug Mode Enable debug logging to troubleshoot: diff --git a/framework/changelog.md b/framework/changelog.md index 0e4795521..d539ca829 100644 --- a/framework/changelog.md +++ b/framework/changelog.md @@ -1,5 +1,5 @@ -- chore: version update core to 1.2.12 +- chore: version update core to 1.2.13 - feat: added support for vertex provider/model format in pricing lookup \ No newline at end of file diff --git a/framework/version b/framework/version index e9bc14996..645377eea 100644 --- a/framework/version +++ b/framework/version @@ -1 +1 @@ -1.1.14 +1.1.15 diff --git a/plugins/governance/changelog.md b/plugins/governance/changelog.md index a434134ee..51f9eb708 100644 --- a/plugins/governance/changelog.md +++ b/plugins/governance/changelog.md @@ -1,4 +1,4 @@ -- chore: version update core to 1.2.12 and framework to 1.1.14 +- chore: version update core to 1.2.13 and framework to 1.1.15 diff --git a/plugins/governance/version b/plugins/governance/version index 5bdcf5c39..25b22e060 100644 --- a/plugins/governance/version +++ b/plugins/governance/version @@ -1 +1 @@ -1.3.15 +1.3.16 diff --git a/plugins/jsonparser/changelog.md b/plugins/jsonparser/changelog.md index a434134ee..51f9eb708 100644 --- a/plugins/jsonparser/changelog.md +++ b/plugins/jsonparser/changelog.md @@ -1,4 +1,4 @@ -- chore: version update core to 1.2.12 and framework to 1.1.14 +- chore: version update core to 1.2.13 and framework to 1.1.15 diff --git a/plugins/jsonparser/version b/plugins/jsonparser/version index 22122dbf4..92ee6ac2f 100644 --- a/plugins/jsonparser/version +++ b/plugins/jsonparser/version @@ -1 +1 @@ -1.3.14 \ No newline at end of file +1.3.15 \ No newline at end of file diff --git a/plugins/logging/changelog.md b/plugins/logging/changelog.md index a434134ee..51f9eb708 100644 --- a/plugins/logging/changelog.md +++ b/plugins/logging/changelog.md @@ -1,4 +1,4 @@ -- chore: version update core to 1.2.12 and framework to 1.1.14 +- 
chore: version update core to 1.2.13 and framework to 1.1.15 diff --git a/plugins/logging/version b/plugins/logging/version index 085c0f266..5bdcf5c39 100644 --- a/plugins/logging/version +++ b/plugins/logging/version @@ -1 +1 @@ -1.3.14 +1.3.15 diff --git a/plugins/maxim/changelog.md b/plugins/maxim/changelog.md index a434134ee..51f9eb708 100644 --- a/plugins/maxim/changelog.md +++ b/plugins/maxim/changelog.md @@ -1,4 +1,4 @@ -- chore: version update core to 1.2.12 and framework to 1.1.14 +- chore: version update core to 1.2.13 and framework to 1.1.15 diff --git a/plugins/maxim/version b/plugins/maxim/version index 323afbcd2..8a3b8ac20 100644 --- a/plugins/maxim/version +++ b/plugins/maxim/version @@ -1 +1 @@ -1.4.14 +1.4.15 diff --git a/plugins/mocker/changelog.md b/plugins/mocker/changelog.md index a434134ee..c7113783b 100644 --- a/plugins/mocker/changelog.md +++ b/plugins/mocker/changelog.md @@ -1,4 +1,6 @@ -- chore: version update core to 1.2.12 and framework to 1.1.14 +- chore: version update core to 1.2.13 and framework to 1.1.15 +- feat: added support for responses request +- feat: added "skip-mocker" context key to skip mocker plugin per request diff --git a/plugins/mocker/main.go b/plugins/mocker/main.go index e9740a2da..5069ddefa 100644 --- a/plugins/mocker/main.go +++ b/plugins/mocker/main.go @@ -6,6 +6,7 @@ import ( "maps" "math/rand" "regexp" + "slices" "sort" "strings" "sync" @@ -491,6 +492,17 @@ func (p *MockerPlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest return req, nil, nil } + skipMocker, ok := (*ctx).Value(schemas.BifrostContextKey("skip-mocker")).(bool) + if ok && skipMocker { + return req, nil, nil + } + + if req.RequestType != schemas.ChatCompletionRequest && req.RequestType != schemas.ResponsesRequest { + return req, nil, nil + } + + startTime := time.Now() + // Track total request count using atomic operation (no lock needed) atomic.AddInt64(&p.totalRequests, 1) @@ -530,7 +542,7 @@ func (p *MockerPlugin) PreHook(ctx 
*context.Context, req *schemas.BifrostRequest // Generate appropriate mock response based on type if response.Type == ResponseTypeSuccess { - return p.generateSuccessShortCircuit(req, response) + return p.generateSuccessShortCircuit(req, response, startTime) } else if response.Type == ResponseTypeError { return p.generateErrorShortCircuit(req, response) } @@ -586,18 +598,12 @@ func (p *MockerPlugin) findMatchingCompiledRule(req *schemas.BifrostRequest) *co // matchesConditionsFast checks if request matches rule conditions with optimized performance func (p *MockerPlugin) matchesConditionsFast(req *schemas.BifrostRequest, conditions *Conditions, compiledRegex *regexp.Regexp) bool { - provider, _, _ := req.GetRequestFields() + provider, model, _ := req.GetRequestFields() // Check providers - optimized string comparison if len(conditions.Providers) > 0 { providerStr := string(provider) - found := false - for _, provider := range conditions.Providers { - if providerStr == provider { - found = true - break - } - } + found := slices.Contains(conditions.Providers, providerStr) if !found { return false } @@ -606,8 +612,8 @@ func (p *MockerPlugin) matchesConditionsFast(req *schemas.BifrostRequest, condit // Check models - direct string comparison if len(conditions.Models) > 0 { found := false - for _, model := range conditions.Models { - if model == model { + for _, conditionModel := range conditions.Models { + if model == conditionModel { found = true break } @@ -756,7 +762,7 @@ func (p *MockerPlugin) calculateRequestSizeFast(req *schemas.BifrostRequest) int } // generateSuccessShortCircuit creates a success response short-circuit with optimized allocations -func (p *MockerPlugin) generateSuccessShortCircuit(req *schemas.BifrostRequest, response *Response) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { +func (p *MockerPlugin) generateSuccessShortCircuit(req *schemas.BifrostRequest, response *Response, startTime time.Time) (*schemas.BifrostRequest, 
*schemas.PluginShortCircuit, error) { if response.Content == nil { return req, nil, nil } @@ -799,8 +805,10 @@ func (p *MockerPlugin) generateSuccessShortCircuit(req *schemas.BifrostRequest, provider, model, _ := req.GetRequestFields() // Create mock response with proper structure - mockResponse := &schemas.BifrostResponse{ - ChatResponse: &schemas.BifrostChatResponse{ + mockResponse := &schemas.BifrostResponse{} + + if req.RequestType == schemas.ChatCompletionRequest { + mockResponse.ChatResponse = &schemas.BifrostChatResponse{ Model: model, Usage: &usage, Choices: []schemas.BifrostResponseChoice{ @@ -821,8 +829,33 @@ func (p *MockerPlugin) generateSuccessShortCircuit(req *schemas.BifrostRequest, RequestType: schemas.ChatCompletionRequest, Provider: provider, ModelRequested: model, + Latency: int64(time.Since(startTime).Milliseconds()), }, - }, + } + } else if req.RequestType == schemas.ResponsesRequest { + mockResponse.ResponsesResponse = &schemas.BifrostResponsesResponse{ + CreatedAt: int(time.Now().Unix()), + Output: []schemas.ResponsesMessage{ + { + Role: bifrost.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Content: &schemas.ResponsesMessageContent{ + ContentStr: &message, + }, + Type: bifrost.Ptr(schemas.ResponsesMessageTypeMessage), + }, + }, + Usage: &schemas.ResponsesResponseUsage{ + InputTokens: usage.PromptTokens, + OutputTokens: usage.CompletionTokens, + TotalTokens: usage.TotalTokens, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.ResponsesRequest, + Provider: provider, + ModelRequested: model, + Latency: int64(time.Since(startTime).Milliseconds()), + }, + } } // Override model if specified @@ -859,6 +892,8 @@ func (p *MockerPlugin) generateErrorShortCircuit(req *schemas.BifrostRequest, re return req, nil, nil } + provider, model, _ := req.GetRequestFields() + errorContent := response.Error allowFallbacks := response.AllowFallbacks @@ -868,6 +903,11 @@ func (p *MockerPlugin) generateErrorShortCircuit(req 
*schemas.BifrostRequest, re Message: errorContent.Message, }, AllowFallbacks: allowFallbacks, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: req.RequestType, + Provider: provider, + ModelRequested: model, + }, } // Set error type diff --git a/plugins/mocker/version b/plugins/mocker/version index 22122dbf4..92ee6ac2f 100644 --- a/plugins/mocker/version +++ b/plugins/mocker/version @@ -1 +1 @@ -1.3.14 \ No newline at end of file +1.3.15 \ No newline at end of file diff --git a/plugins/otel/changelog.md b/plugins/otel/changelog.md index a434134ee..bc74baa33 100644 --- a/plugins/otel/changelog.md +++ b/plugins/otel/changelog.md @@ -1,4 +1,6 @@ -- chore: version update core to 1.2.12 and framework to 1.1.14 +- chore: version update core to 1.2.13 and framework to 1.1.15 +- feat: added headers support for OTel configuration. Value prefixed with env will be fetched from environment variables (env.) +- feat: emission of OTel resource spans is completely async - this brings down inference overhead to < 1Β΅second \ No newline at end of file diff --git a/plugins/otel/version b/plugins/otel/version index 97bceaaf6..758a46e9b 100644 --- a/plugins/otel/version +++ b/plugins/otel/version @@ -1 +1 @@ -1.0.14 \ No newline at end of file +1.0.15 \ No newline at end of file diff --git a/plugins/semanticcache/changelog.md b/plugins/semanticcache/changelog.md index a434134ee..5781539e6 100644 --- a/plugins/semanticcache/changelog.md +++ b/plugins/semanticcache/changelog.md @@ -1,4 +1,5 @@ -- chore: version update core to 1.2.12 and framework to 1.1.14 +- chore: version update core to 1.2.13 and framework to 1.1.15 +- tests: added mocker plugin to all chat/responses tests diff --git a/plugins/semanticcache/plugin_edge_cases_test.go b/plugins/semanticcache/plugin_edge_cases_test.go index b641b519e..64f12bbf5 100644 --- a/plugins/semanticcache/plugin_edge_cases_test.go +++ b/plugins/semanticcache/plugin_edge_cases_test.go @@ -227,25 +227,6 @@ func TestContentVariations(t 
*testing.T) { name string request *schemas.BifrostChatRequest }{ - { - name: "Unicode Content", - request: &schemas.BifrostChatRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: []schemas.ChatMessage{ - { - Role: schemas.ChatMessageRoleUser, - Content: &schemas.ChatMessageContent{ - ContentStr: bifrost.Ptr("🌟 Unicode test: Hello, δΈ–η•Œ! Ω…Ψ±Ψ­Ψ¨Ψ§ 🌍"), - }, - }, - }, - Params: &schemas.ChatParameters{ - MaxCompletionTokens: bifrost.Ptr(50), - Temperature: bifrost.Ptr(0.1), - }, - }, - }, { name: "Image URL Content", request: &schemas.BifrostChatRequest{ diff --git a/plugins/semanticcache/plugin_responses_test.go b/plugins/semanticcache/plugin_responses_test.go index eaa70265e..a578c0c5e 100644 --- a/plugins/semanticcache/plugin_responses_test.go +++ b/plugins/semanticcache/plugin_responses_test.go @@ -38,6 +38,13 @@ func TestResponsesAPIBasicFunctionality(t *testing.T) { t.Logf("First request completed in %v", duration1) t.Logf("Response contains %d output messages", len(response1.ResponsesResponse.Output)) + if c := response1.ResponsesResponse.Output[0].Content; c != nil && c.ContentStr != nil { + t.Logf("Response: %s", *c.ContentStr) + } else if c != nil && len(c.ContentBlocks) > 0 && c.ContentBlocks[0].Text != nil { + t.Logf("Response: %s", *c.ContentBlocks[0].Text) + } else { + t.Log("Response: ") + } // Wait for cache to be written WaitForCache() @@ -56,6 +63,11 @@ func TestResponsesAPIBasicFunctionality(t *testing.T) { if response2 == nil || len(response2.Output) == 0 { t.Fatal("Second Responses response is invalid") } + if response2.Output[0].Content.ContentStr != nil { + t.Logf("Response: %s", *response2.Output[0].Content.ContentStr) + } else { + t.Logf("Response: %v", *response2.Output[0].Content.ContentBlocks[0].Text) + } t.Logf("Second request completed in %v", duration2) @@ -334,7 +346,7 @@ func TestResponsesAPINoStoreFlag(t *testing.T) { // TestResponsesAPIStreaming tests streaming Responses API requests func 
TestResponsesAPIStreaming(t *testing.T) { t.Log("Responses streaming not supported yet") - + setup := NewTestSetup(t) defer setup.Cleanup() diff --git a/plugins/semanticcache/test_utils.go b/plugins/semanticcache/test_utils.go index b33ab4a35..8934a2b0d 100644 --- a/plugins/semanticcache/test_utils.go +++ b/plugins/semanticcache/test_utils.go @@ -11,6 +11,7 @@ import ( bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/vectorstore" + mocker "github.com/maximhq/bifrost/plugins/mocker" ) // getWeaviateConfigFromEnv retrieves Weaviate configuration from environment variables @@ -110,6 +111,217 @@ func (baseAccount *BaseAccount) GetConfigForProvider(providerKey schemas.ModelPr }, nil } +// getMockRules returns a list of mock rules for the semantic cache tests +func getMockRules() []mocker.MockRule { + return []mocker.MockRule{ + // Core test prompts + { + Name: "bifrost-definition", + Enabled: true, + Conditions: mocker.Conditions{MessageRegex: bifrost.Ptr("(?i)What is Bifrost.*")}, + Probability: 1.0, + Responses: []mocker.Response{ + {Type: mocker.ResponseTypeSuccess, Content: &mocker.SuccessResponse{Message: "Bifrost is a unified API for interacting with multiple AI providers."}}, + }, + }, + { + Name: "machine-learning-explanation", + Enabled: true, + Conditions: mocker.Conditions{MessageRegex: bifrost.Ptr("(?i)what is machine learning\\?|explain machine learning|machine learning concepts|can you explain machine learning|explain the basics of machine learning")}, + Probability: 1.0, + Responses: []mocker.Response{ + {Type: mocker.ResponseTypeSuccess, Content: &mocker.SuccessResponse{Message: "Machine learning is a field of AI that uses statistical techniques to give computer systems the ability to learn from data."}}, + }, + }, + { + Name: "ai-explanation", + Enabled: true, + Conditions: mocker.Conditions{MessageRegex: bifrost.Ptr("(?i)what is artificial intelligence\\?|can you explain what ai 
is\\?|define artificial intelligence")}, + Probability: 1.0, + Responses: []mocker.Response{ + {Type: mocker.ResponseTypeSuccess, Content: &mocker.SuccessResponse{Message: "Artificial intelligence is the simulation of human intelligence in machines."}}, + }, + }, + { + Name: "capital-of-france", + Enabled: true, + Conditions: mocker.Conditions{MessageRegex: bifrost.Ptr("What is the capital of France\\?")}, + Probability: 1.0, + Responses: []mocker.Response{ + {Type: mocker.ResponseTypeSuccess, Content: &mocker.SuccessResponse{Message: "The capital of France is Paris."}}, + }, + }, + { + Name: "newton-laws", + Enabled: true, + Conditions: mocker.Conditions{MessageRegex: bifrost.Ptr("(?i)describe.*newton.*three laws|describe.*three laws.*newton")}, + Probability: 1.0, + Responses: []mocker.Response{ + {Type: mocker.ResponseTypeSuccess, Content: &mocker.SuccessResponse{Message: "Newton's three laws of motion are: 1. An object at rest stays at rest and an object in motion stays in motion with the same speed and in the same direction unless acted upon by an unbalanced force. 2. The acceleration of an object as produced by a net force is directly proportional to the magnitude of the net force, in the same direction as the net force, and inversely proportional to the mass of the object. 3. 
For every action, there is an equal and opposite reaction."}}, + }, + }, + // Weather-related prompts + { + Name: "weather-question", + Enabled: true, + Conditions: mocker.Conditions{MessageRegex: bifrost.Ptr("(?i)what.*weather|weather.*like")}, + Probability: 1.0, + Responses: []mocker.Response{ + {Type: mocker.ResponseTypeSuccess, Content: &mocker.SuccessResponse{Message: "It's sunny today with a temperature of 72Β°F."}}, + }, + }, + // Blockchain and deep learning + { + Name: "blockchain-definition", + Enabled: true, + Conditions: mocker.Conditions{MessageRegex: bifrost.Ptr("(?i)define blockchain|blockchain technology")}, + Probability: 1.0, + Responses: []mocker.Response{ + {Type: mocker.ResponseTypeSuccess, Content: &mocker.SuccessResponse{Message: "Blockchain is a distributed ledger technology that maintains a continuously growing list of records."}}, + }, + }, + { + Name: "deep-learning", + Enabled: true, + Conditions: mocker.Conditions{MessageRegex: bifrost.Ptr("(?i)what is deep learning")}, + Probability: 1.0, + Responses: []mocker.Response{ + {Type: mocker.ResponseTypeSuccess, Content: &mocker.SuccessResponse{Message: "Deep learning is a subset of machine learning that uses neural networks with multiple layers."}}, + }, + }, + // Quantum computing + { + Name: "quantum-computing", + Enabled: true, + Conditions: mocker.Conditions{MessageRegex: bifrost.Ptr("(?i)quantum computing|explain quantum")}, + Probability: 1.0, + Responses: []mocker.Response{ + {Type: mocker.ResponseTypeSuccess, Content: &mocker.SuccessResponse{Message: "Quantum computing uses quantum mechanical phenomena to process information in ways that classical computers cannot."}}, + }, + }, + // Conversation prompts + { + Name: "hello-greeting", + Enabled: true, + Conditions: mocker.Conditions{MessageRegex: bifrost.Ptr("(?i)^hello$|^hi$|hello.*world")}, + Probability: 1.0, + Responses: []mocker.Response{ + {Type: mocker.ResponseTypeSuccess, Content: &mocker.SuccessResponse{Message: "Hello! 
How can I help you today?"}}, + }, + }, + { + Name: "how-are-you", + Enabled: true, + Conditions: mocker.Conditions{MessageRegex: bifrost.Ptr("(?i)how are you")}, + Probability: 1.0, + Responses: []mocker.Response{ + {Type: mocker.ResponseTypeSuccess, Content: &mocker.SuccessResponse{Message: "I'm doing well, thank you for asking!"}}, + }, + }, + { + Name: "meaning-of-life", + Enabled: true, + Conditions: mocker.Conditions{MessageRegex: bifrost.Ptr("(?i)meaning of life")}, + Probability: 1.0, + Responses: []mocker.Response{ + {Type: mocker.ResponseTypeSuccess, Content: &mocker.SuccessResponse{Message: "The meaning of life is a philosophical question that has been pondered for centuries. Some say it's 42!"}}, + }, + }, + { + Name: "short-story", + Enabled: true, + Conditions: mocker.Conditions{MessageRegex: bifrost.Ptr("(?i)tell me.*short story")}, + Probability: 1.0, + Responses: []mocker.Response{ + {Type: mocker.ResponseTypeSuccess, Content: &mocker.SuccessResponse{Message: "Once upon a time, there was a brave knight who saved the day."}}, + }, + }, + // Test-specific prompts + { + Name: "test-configuration", + Enabled: true, + Conditions: mocker.Conditions{MessageRegex: bifrost.Ptr("(?i)test configuration")}, + Probability: 1.0, + Responses: []mocker.Response{ + {Type: mocker.ResponseTypeSuccess, Content: &mocker.SuccessResponse{Message: "This is a test configuration response."}}, + }, + }, + { + Name: "test-messages", + Enabled: true, + Conditions: mocker.Conditions{MessageRegex: bifrost.Ptr("(?i)test.*message|test.*no-store|test.*cache|test.*error|ttl test|threshold test|provider.*test|edge case test")}, + Probability: 1.0, + Responses: []mocker.Response{ + {Type: mocker.ResponseTypeSuccess, Content: &mocker.SuccessResponse{Message: "This is a test response for various test scenarios."}}, + }, + }, + { + Name: "long-prompt", + Enabled: true, + Conditions: mocker.Conditions{MessageRegex: bifrost.Ptr("(?i)very long prompt")}, + Probability: 1.0, + Responses: 
[]mocker.Response{ + {Type: mocker.ResponseTypeSuccess, Content: &mocker.SuccessResponse{Message: "This is a response to a very long prompt."}}, + }, + }, + { + Name: "parameter-tests", + Enabled: true, + Conditions: mocker.Conditions{MessageRegex: bifrost.Ptr("(?i)test.*parameters|performance test")}, + Probability: 1.0, + Responses: []mocker.Response{ + {Type: mocker.ResponseTypeSuccess, Content: &mocker.SuccessResponse{Message: "Parameter test response with various settings."}}, + }, + }, + // Dynamic message patterns (for conversation tests) + { + Name: "message-pattern", + Enabled: true, + Conditions: mocker.Conditions{MessageRegex: bifrost.Ptr("(?i)message \\d+")}, + Probability: 1.0, + Responses: []mocker.Response{ + {Type: mocker.ResponseTypeSuccess, Content: &mocker.SuccessResponse{Message: "Response to numbered message."}}, + }, + }, + // Default catch-all rule (lowest priority) + { + Name: "default-mock", + Enabled: true, + Priority: -1, // Lower priority + Conditions: mocker.Conditions{}, + Probability: 1.0, + Responses: []mocker.Response{ + {Type: mocker.ResponseTypeSuccess, Content: &mocker.SuccessResponse{Message: "This is a generic mocked response."}}, + }, + }, + } +} + +// getMockedBifrostClient creates a Bifrost client with a mocker plugin for testing +func getMockedBifrostClient(t *testing.T, ctx context.Context, logger schemas.Logger, semanticCachePlugin schemas.Plugin) *bifrost.Bifrost { + mockerCfg := mocker.MockerConfig{ + Enabled: true, + Rules: getMockRules(), + } + + mockerPlugin, err := mocker.Init(mockerCfg) + if err != nil { + t.Fatalf("Failed to initialize mocker plugin: %v", err) + } + + account := &BaseAccount{} + client, err := bifrost.Init(ctx, schemas.BifrostConfig{ + Account: account, + Plugins: []schemas.Plugin{semanticCachePlugin, mockerPlugin}, + Logger: logger, + }) + if err != nil { + t.Fatalf("Error initializing Bifrost with mocker: %v", err) + } + + return client +} + // TestSetup contains common test setup components 
type TestSetup struct { Logger schemas.Logger @@ -121,10 +333,6 @@ type TestSetup struct { // NewTestSetup creates a new test setup with default configuration func NewTestSetup(t *testing.T) *TestSetup { - // if os.Getenv("OPENAI_API_KEY") == "" { - // t.Skip("OPENAI_API_KEY is not set, skipping test") - // } - return NewTestSetupWithConfig(t, &Config{ Provider: schemas.OpenAI, EmbeddingModel: "text-embedding-3-small", @@ -145,6 +353,7 @@ func NewTestSetupWithConfig(t *testing.T, config *Config) *TestSetup { ctx := context.Background() logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug) + // Keep Weaviate for embeddings, as mocker only affects chat completions store, err := vectorstore.NewVectorStore(context.Background(), &vectorstore.Config{ Type: vectorstore.VectorStoreTypeWeaviate, Config: getWeaviateConfigFromEnv(), @@ -163,15 +372,8 @@ func NewTestSetupWithConfig(t *testing.T, config *Config) *TestSetup { pluginImpl := plugin.(*Plugin) clearTestKeysWithStore(t, pluginImpl.store) - account := &BaseAccount{} - client, err := bifrost.Init(ctx, schemas.BifrostConfig{ - Account: account, - Plugins: []schemas.Plugin{plugin}, - Logger: logger, - }) - if err != nil { - t.Fatalf("Error initializing Bifrost: %v", err) - } + // Get a mocked Bifrost client + client := getMockedBifrostClient(t, ctx, logger, plugin) return &TestSetup{ Logger: logger, @@ -371,10 +573,6 @@ func CreateContextWithCacheKeyAndNoStore(value string, noStore bool) context.Con // CreateTestSetupWithConversationThreshold creates a test setup with custom conversation history threshold func CreateTestSetupWithConversationThreshold(t *testing.T, threshold int) *TestSetup { - if os.Getenv("OPENAI_API_KEY") == "" { - t.Skip("OPENAI_API_KEY is not set, skipping test") - } - config := &Config{ Provider: schemas.OpenAI, EmbeddingModel: "text-embedding-3-small", @@ -395,10 +593,6 @@ func CreateTestSetupWithConversationThreshold(t *testing.T, threshold int) *Test // CreateTestSetupWithExcludeSystemPrompt 
creates a test setup with ExcludeSystemPrompt setting func CreateTestSetupWithExcludeSystemPrompt(t *testing.T, excludeSystem bool) *TestSetup { - if os.Getenv("OPENAI_API_KEY") == "" { - t.Skip("OPENAI_API_KEY is not set, skipping test") - } - config := &Config{ Provider: schemas.OpenAI, EmbeddingModel: "text-embedding-3-small", @@ -419,10 +613,6 @@ func CreateTestSetupWithExcludeSystemPrompt(t *testing.T, excludeSystem bool) *T // CreateTestSetupWithThresholdAndExcludeSystem creates a test setup with both conversation threshold and exclude system prompt settings func CreateTestSetupWithThresholdAndExcludeSystem(t *testing.T, threshold int, excludeSystem bool) *TestSetup { - if os.Getenv("OPENAI_API_KEY") == "" { - t.Skip("OPENAI_API_KEY is not set, skipping test") - } - config := &Config{ Provider: schemas.OpenAI, EmbeddingModel: "text-embedding-3-small", diff --git a/plugins/semanticcache/version b/plugins/semanticcache/version index 085c0f266..5bdcf5c39 100644 --- a/plugins/semanticcache/version +++ b/plugins/semanticcache/version @@ -1 +1 @@ -1.3.14 +1.3.15 diff --git a/plugins/telemetry/changelog.md b/plugins/telemetry/changelog.md index a434134ee..51f9eb708 100644 --- a/plugins/telemetry/changelog.md +++ b/plugins/telemetry/changelog.md @@ -1,4 +1,4 @@ -- chore: version update core to 1.2.12 and framework to 1.1.14 +- chore: version update core to 1.2.13 and framework to 1.1.15 diff --git a/plugins/telemetry/version b/plugins/telemetry/version index 22122dbf4..92ee6ac2f 100644 --- a/plugins/telemetry/version +++ b/plugins/telemetry/version @@ -1 +1 @@ -1.3.14 \ No newline at end of file +1.3.15 \ No newline at end of file diff --git a/tests/core-providers/anthropic_test.go b/tests/core-providers/anthropic_test.go index a80bb155c..3229593bb 100644 --- a/tests/core-providers/anthropic_test.go +++ b/tests/core-providers/anthropic_test.go @@ -1,6 +1,7 @@ package tests import ( + "os" "testing" "github.com/maximhq/bifrost/tests/core-providers/config" @@ -9,19 
+10,24 @@ import ( ) func TestAnthropic(t *testing.T) { + if os.Getenv("ANTHROPIC_API_KEY") == "" { + t.Skip("Skipping Anthropic tests because ANTHROPIC_API_KEY is not set") + } + client, ctx, cancel, err := config.SetupTest() if err != nil { t.Fatalf("Error initializing test setup: %v", err) } defer cancel() - defer client.Shutdown() testConfig := config.ComprehensiveTestConfig{ - Provider: schemas.Anthropic, - ChatModel: "claude-sonnet-4-20250514", - VisionModel: "claude-3-7-sonnet-20250219", // Same model supports vision - TextModel: "", // Anthropic doesn't support text completion - EmbeddingModel: "", // Anthropic doesn't support embedding + Provider: schemas.Anthropic, + ChatModel: "claude-sonnet-4-20250514", + Fallbacks: []schemas.Fallback{ + {Provider: schemas.Anthropic, Model: "claude-3-7-sonnet-20250219"}, + {Provider: schemas.Anthropic, Model: "claude-sonnet-4-20250514"}, + }, + VisionModel: "claude-3-7-sonnet-20250219", // Same model supports vision Scenarios: config.TestScenarios{ TextCompletion: false, // Not supported SimpleChat: true, @@ -40,5 +46,8 @@ func TestAnthropic(t *testing.T) { }, } - runAllComprehensiveTests(t, client, ctx, testConfig) + t.Run("AnthropicTests", func(t *testing.T) { + runAllComprehensiveTests(t, client, ctx, testConfig) + }) + client.Shutdown() } diff --git a/tests/core-providers/azure_test.go b/tests/core-providers/azure_test.go index f30f454a5..2d702dc03 100644 --- a/tests/core-providers/azure_test.go +++ b/tests/core-providers/azure_test.go @@ -1,6 +1,7 @@ package tests import ( + "os" "testing" "github.com/maximhq/bifrost/tests/core-providers/config" @@ -9,17 +10,24 @@ import ( ) func TestAzure(t *testing.T) { + if os.Getenv("AZURE_API_KEY") == "" { + t.Skip("Skipping Azure tests because AZURE_API_KEY is not set") + } + client, ctx, cancel, err := config.SetupTest() if err != nil { t.Fatalf("Error initializing test setup: %v", err) } defer cancel() - defer client.Shutdown() testConfig := config.ComprehensiveTestConfig{ 
- Provider: schemas.Azure, - ChatModel: "gpt-4o", - VisionModel: "gpt-4o", + Provider: schemas.Azure, + ChatModel: "gpt-4o", + VisionModel: "gpt-4o", + Fallbacks: []schemas.Fallback{ + {Provider: schemas.Azure, Model: "gpt-4o-mini"}, + {Provider: schemas.Azure, Model: "gpt-4.1"}, + }, TextModel: "", // Azure OpenAI doesn't support text completion in newer models EmbeddingModel: "text-embedding-ada-002", Scenarios: config.TestScenarios{ @@ -39,5 +47,15 @@ func TestAzure(t *testing.T) { }, } - runAllComprehensiveTests(t, client, ctx, testConfig) + // Disable embedding if embeddings key is not provided + if os.Getenv("AZURE_EMB_API_KEY") == "" { + t.Logf("AZURE_EMB_API_KEY not set; disabling Azure embedding tests") + testConfig.EmbeddingModel = "" + testConfig.Scenarios.Embedding = false + } + + t.Run("AzureTests", func(t *testing.T) { + runAllComprehensiveTests(t, client, ctx, testConfig) + }) + client.Shutdown() } diff --git a/tests/core-providers/bedrock_test.go b/tests/core-providers/bedrock_test.go index 51cd72aab..82b68853e 100644 --- a/tests/core-providers/bedrock_test.go +++ b/tests/core-providers/bedrock_test.go @@ -19,14 +19,17 @@ func TestBedrock(t *testing.T) { t.Fatalf("Error initializing test setup: %v", err) } defer cancel() - defer client.Shutdown() testConfig := config.ComprehensiveTestConfig{ - Provider: schemas.Bedrock, - ChatModel: "claude-sonnet-4", - VisionModel: "claude-sonnet-4", + Provider: schemas.Bedrock, + ChatModel: "anthropic.claude-3-5-sonnet-20240620-v1:0", + VisionModel: "claude-sonnet-4", + Fallbacks: []schemas.Fallback{ + {Provider: schemas.Bedrock, Model: "claude-3.7-sonnet"}, + }, TextModel: "mistral.mistral-7b-instruct-v0:2", // Bedrock Claude doesn't support text completion - EmbeddingModel: "amazon.titan-embed-text-v2:0", + EmbeddingModel: "cohere.embed-v4:0", + ReasoningModel: "claude-sonnet-4", Scenarios: config.TestScenarios{ TextCompletion: false, // Not supported for Claude SimpleChat: true, @@ -36,13 +39,17 @@ func 
TestBedrock(t *testing.T) { MultipleToolCalls: true, End2EndToolCalling: true, AutomaticFunctionCall: true, - ImageURL: false, + ImageURL: false, // Direct Image URL is not supported for Bedrock ImageBase64: true, - MultipleImages: false, + MultipleImages: false, // Direct Image URL is not supported for Bedrock CompleteEnd2End: true, Embedding: true, + Reasoning: true, }, } - runAllComprehensiveTests(t, client, ctx, testConfig) + t.Run("BedrockTests", func(t *testing.T) { + runAllComprehensiveTests(t, client, ctx, testConfig) + }) + client.Shutdown() } diff --git a/tests/core-providers/cerebras_test.go b/tests/core-providers/cerebras_test.go index d3513d477..d5c9fbffe 100644 --- a/tests/core-providers/cerebras_test.go +++ b/tests/core-providers/cerebras_test.go @@ -1,6 +1,7 @@ package tests import ( + "os" "testing" "github.com/maximhq/bifrost/tests/core-providers/config" @@ -9,16 +10,23 @@ import ( ) func TestCerebras(t *testing.T) { + if os.Getenv("CEREBRAS_API_KEY") == "" { + t.Skip("Skipping Cerebras tests because CEREBRAS_API_KEY is not set") + } + client, ctx, cancel, err := config.SetupTest() if err != nil { t.Fatalf("Error initializing test setup: %v", err) } defer cancel() - defer client.Shutdown() testConfig := config.ComprehensiveTestConfig{ - Provider: schemas.Cerebras, - ChatModel: "llama-3.3-70b", + Provider: schemas.Cerebras, + ChatModel: "llama-3.3-70b", + Fallbacks: []schemas.Fallback{ + {Provider: schemas.Cerebras, Model: "llama3.1-8b"}, + {Provider: schemas.Cerebras, Model: "gpt-oss-120b"}, + }, TextModel: "llama3.1-8b", EmbeddingModel: "", // Cerebras doesn't support embedding Scenarios: config.TestScenarios{ @@ -39,5 +47,8 @@ func TestCerebras(t *testing.T) { }, } - runAllComprehensiveTests(t, client, ctx, testConfig) + t.Run("CerebrasTests", func(t *testing.T) { + runAllComprehensiveTests(t, client, ctx, testConfig) + }) + client.Shutdown() } diff --git a/tests/core-providers/cohere_test.go b/tests/core-providers/cohere_test.go index 
e5ae253d5..7fcaec002 100644 --- a/tests/core-providers/cohere_test.go +++ b/tests/core-providers/cohere_test.go @@ -1,6 +1,7 @@ package tests import ( + "os" "testing" "github.com/maximhq/bifrost/tests/core-providers/config" @@ -9,12 +10,15 @@ ) func TestCohere(t *testing.T) { + if os.Getenv("COHERE_API_KEY") == "" { + t.Skip("Skipping Cohere tests because COHERE_API_KEY is not set") + } + client, ctx, cancel, err := config.SetupTest() if err != nil { t.Fatalf("Error initializing test setup: %v", err) } defer cancel() - defer client.Shutdown() testConfig := config.ComprehensiveTestConfig{ Provider: schemas.Cohere, @@ -33,12 +37,15 @@ func TestCohere(t *testing.T) { AutomaticFunctionCall: true, // May not support automatic ImageURL: false, // Supported by c4ai-aya-vision-8b model ImageBase64: true, // Supported by c4ai-aya-vision-8b model - MultipleImages: true, // Supported by c4ai-aya-vision-8b model + MultipleImages: false, // Disabled for c4ai-aya-vision-8b model CompleteEnd2End: false, Embedding: true, Reasoning: true, }, } - runAllComprehensiveTests(t, client, ctx, testConfig) + t.Run("CohereTests", func(t *testing.T) { + runAllComprehensiveTests(t, client, ctx, testConfig) + }) + client.Shutdown() } diff --git a/tests/core-providers/config/account.go b/tests/core-providers/config/account.go index fc4d97363..0ab40f146 100644 --- a/tests/core-providers/config/account.go +++ b/tests/core-providers/config/account.go @@ -41,17 +41,21 @@ type TestScenarios struct { // ComprehensiveTestConfig extends TestConfig with additional scenarios type ComprehensiveTestConfig struct { - Provider schemas.ModelProvider - TextModel string - ChatModel string - VisionModel string - ReasoningModel string - EmbeddingModel string - TranscriptionModel string - SpeechSynthesisModel string - Scenarios TestScenarios - Fallbacks []schemas.Fallback - SkipReason string // Reason to skip certain tests + Provider schemas.ModelProvider + TextModel string + ChatModel string +
VisionModel string + ReasoningModel string + EmbeddingModel string + TranscriptionModel string + SpeechSynthesisModel string + Scenarios TestScenarios + Fallbacks []schemas.Fallback // for chat, responses, image and reasoning tests + TextCompletionFallbacks []schemas.Fallback // for text completion tests + TranscriptionFallbacks []schemas.Fallback // for transcription tests + SpeechSynthesisFallbacks []schemas.Fallback // for speech synthesis tests + EmbeddingFallbacks []schemas.Fallback // for embedding tests + SkipReason string // Reason to skip certain tests } // ComprehensiveTestAccount provides a test implementation of the Account interface for comprehensive testing. @@ -100,7 +104,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx *context.Context case ProviderOpenAICustom: return []schemas.Key{ { - Value: os.Getenv("GROQ_API_KEY"), // Use GROQ API key for OpenAI-compatible endpoint + Value: os.Getenv("OPENAI_API_KEY"), // Use OpenAI API key for the OpenAI-compatible endpoint Models: []string{}, Weight: 1.0, }, @@ -125,12 +129,13 @@ Region: bifrost.Ptr(getEnvWithDefault("AWS_REGION", "us-east-1")), ARN: bifrost.Ptr(os.Getenv("AWS_ARN")), Deployments: map[string]string{ - "claude-sonnet-4": "global.anthropic.claude-sonnet-4-20250514-v1:0", + "claude-sonnet-4": "global.anthropic.claude-sonnet-4-20250514-v1:0", + "claude-3.7-sonnet": "us.anthropic.claude-3-7-sonnet-20250219-v1:0", }, }, }, { - Models: []string{"amazon.titan-embed-text-v2:0"}, + Models: []string{"anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.embed-v4:0"}, Weight: 1.0, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: os.Getenv("AWS_ACCESS_KEY_ID"), @@ -152,12 +157,26 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx *context.Context return []schemas.Key{ { Value: os.Getenv("AZURE_API_KEY"), - Models: []string{"gpt-4o", "text-embedding-ada-002"}, + Models: []string{"gpt-4o"},
Weight: 1.0, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: os.Getenv("AZURE_ENDPOINT"), Deployments: map[string]string{ - "gpt-4o": "gpt-4o-aug", + "gpt-4o": "gpt-4o-aug", + }, + // Use environment variable for API version with fallback to current preview version + // Note: This is a preview API version that may change over time. Update as needed. + // Set AZURE_API_VERSION environment variable to override the default. + APIVersion: bifrost.Ptr(getEnvWithDefault("AZURE_API_VERSION", "2024-08-01-preview")), + }, + }, + { + Value: os.Getenv("AZURE_EMB_API_KEY"), + Models: []string{}, + Weight: 1.0, + AzureKeyConfig: &schemas.AzureKeyConfig{ + Endpoint: os.Getenv("AZURE_EMB_ENDPOINT"), + Deployments: map[string]string{ "text-embedding-ada-002": "text-embedding-ada-002", }, // Use environment variable for API version with fallback to current preview version @@ -184,7 +203,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx *context.Context return []schemas.Key{ { Value: os.Getenv("MISTRAL_API_KEY"), - Models: []string{"mistral-large-2411", "pixtral-12b-latest", "mistral-embed"}, + Models: []string{}, Weight: 1.0, }, }, nil @@ -239,27 +258,27 @@ func (account *ComprehensiveTestAccount) GetConfigForProvider(providerKey schema case schemas.OpenAI: return &schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 60, + DefaultRequestTimeoutInSeconds: 120, MaxRetries: 3, // Higher retries for production-grade provider RetryBackoffInitial: 500 * time.Millisecond, RetryBackoffMax: 8 * time.Second, }, ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, + Concurrency: 10, BufferSize: 10, }, }, nil case ProviderOpenAICustom: return &schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ - BaseURL: getEnvWithDefault("GROQ_OPENAI_BASE_URL", "https://api.groq.com/openai"), - DefaultRequestTimeoutInSeconds: 60, + BaseURL: "https://api.openai.com", + DefaultRequestTimeoutInSeconds: 120, 
MaxRetries: 4, // Higher retries for Groq (can be flaky) RetryBackoffInitial: 1 * time.Second, RetryBackoffMax: 10 * time.Second, }, ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, + Concurrency: 10, BufferSize: 10, }, CustomProviderConfig: &schemas.CustomProviderConfig{ @@ -279,86 +298,86 @@ func (account *ComprehensiveTestAccount) GetConfigForProvider(providerKey schema case schemas.Anthropic: return &schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 60, + DefaultRequestTimeoutInSeconds: 120, MaxRetries: 3, // Claude is generally reliable RetryBackoffInitial: 500 * time.Millisecond, RetryBackoffMax: 8 * time.Second, }, ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, + Concurrency: 10, BufferSize: 10, }, }, nil case schemas.Bedrock: return &schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 60, + DefaultRequestTimeoutInSeconds: 120, MaxRetries: 5, // AWS services can have occasional issues RetryBackoffInitial: 5 * time.Second, RetryBackoffMax: 20 * time.Second, }, ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, + Concurrency: 10, BufferSize: 10, }, }, nil case schemas.Cohere: return &schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 60, + DefaultRequestTimeoutInSeconds: 120, MaxRetries: 4, // Cohere can be variable RetryBackoffInitial: 750 * time.Millisecond, RetryBackoffMax: 10 * time.Second, }, ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, + Concurrency: 10, BufferSize: 10, }, }, nil case schemas.Azure: return &schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 60, + DefaultRequestTimeoutInSeconds: 120, MaxRetries: 3, // Azure OpenAI is generally reliable RetryBackoffInitial: 500 * time.Millisecond, RetryBackoffMax: 8 * time.Second, }, ConcurrencyAndBufferSize: 
schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, + Concurrency: 10, BufferSize: 10, }, }, nil case schemas.Vertex: return &schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 60, + DefaultRequestTimeoutInSeconds: 120, MaxRetries: 3, // Google Cloud is generally reliable RetryBackoffInitial: 500 * time.Millisecond, RetryBackoffMax: 8 * time.Second, }, ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, + Concurrency: 10, BufferSize: 10, }, }, nil case schemas.Ollama: return &schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 60, + DefaultRequestTimeoutInSeconds: 120, MaxRetries: 2, // Local service, fewer retries needed RetryBackoffInitial: 250 * time.Millisecond, RetryBackoffMax: 4 * time.Second, - BaseURL: getEnvWithDefault("OLLAMA_BASE_URL", "http://localhost:11434"), + BaseURL: os.Getenv("OLLAMA_BASE_URL"), }, ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, + Concurrency: 10, BufferSize: 10, }, }, nil case schemas.Mistral: return &schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 60, + DefaultRequestTimeoutInSeconds: 120, MaxRetries: 4, // Mistral can be variable RetryBackoffInitial: 750 * time.Millisecond, RetryBackoffMax: 10 * time.Second, @@ -371,13 +390,13 @@ func (account *ComprehensiveTestAccount) GetConfigForProvider(providerKey schema case schemas.Groq: return &schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 60, + DefaultRequestTimeoutInSeconds: 120, MaxRetries: 5, // Groq can be flaky at times RetryBackoffInitial: 1 * time.Second, RetryBackoffMax: 15 * time.Second, }, ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, + Concurrency: 2, BufferSize: 10, }, }, nil @@ -385,65 +404,65 @@ func (account *ComprehensiveTestAccount) GetConfigForProvider(providerKey schema return 
&schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ BaseURL: os.Getenv("SGL_BASE_URL"), - DefaultRequestTimeoutInSeconds: 60, + DefaultRequestTimeoutInSeconds: 120, MaxRetries: 5, // SGL (self-hosted) can be variable RetryBackoffInitial: 1 * time.Second, RetryBackoffMax: 15 * time.Second, }, ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, + Concurrency: 10, BufferSize: 10, }, }, nil case schemas.Parasail: return &schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 60, + DefaultRequestTimeoutInSeconds: 120, MaxRetries: 5, // Parasail can be variable RetryBackoffInitial: 1 * time.Second, RetryBackoffMax: 12 * time.Second, }, ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, + Concurrency: 10, BufferSize: 10, }, }, nil case schemas.Cerebras: return &schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 60, + DefaultRequestTimeoutInSeconds: 120, MaxRetries: 4, // Cerebras is reasonably stable RetryBackoffInitial: 750 * time.Millisecond, RetryBackoffMax: 10 * time.Second, }, ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, + Concurrency: 2, BufferSize: 10, }, }, nil case schemas.Gemini: return &schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 60, + DefaultRequestTimeoutInSeconds: 120, MaxRetries: 4, // Gemini can be variable RetryBackoffInitial: 750 * time.Millisecond, RetryBackoffMax: 12 * time.Second, }, ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, - BufferSize: 10, + Concurrency: 20, + BufferSize: 20, }, }, nil case schemas.OpenRouter: return &schemas.ProviderConfig{ NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 60, + DefaultRequestTimeoutInSeconds: 120, MaxRetries: 4, // OpenRouter can be variable (proxy service) RetryBackoffInitial: 1 * time.Second, RetryBackoffMax: 12 * 
time.Second, }, ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, + Concurrency: 10, BufferSize: 10, }, }, nil diff --git a/tests/core-providers/cross_provider_test.go b/tests/core-providers/cross_provider_test.go index 96baf31eb..3bdf36ecb 100644 --- a/tests/core-providers/cross_provider_test.go +++ b/tests/core-providers/cross_provider_test.go @@ -10,6 +10,9 @@ import ( ) func TestCrossProviderScenarios(t *testing.T) { + t.Skip("Skipping cross provider scenarios test") + return + client, ctx, cancel, err := config.SetupTest() if err != nil { t.Fatalf("Error initializing test setup: %v", err) @@ -110,6 +113,9 @@ func TestCrossProviderScenarios(t *testing.T) { } func TestCrossProviderConsistency(t *testing.T) { + t.Skip("Skipping cross provider consistency test") + return + client, ctx, cancel, err := config.SetupTest() if err != nil { t.Fatalf("Error initializing test setup: %v", err) diff --git a/tests/core-providers/custom_test.go b/tests/core-providers/custom_test.go deleted file mode 100644 index c0246d2e0..000000000 --- a/tests/core-providers/custom_test.go +++ /dev/null @@ -1,129 +0,0 @@ -package tests - -import ( - "os" - "strings" - "testing" - - "github.com/maximhq/bifrost/tests/core-providers/config" - - bifrost "github.com/maximhq/bifrost/core" - "github.com/maximhq/bifrost/core/schemas" - "github.com/stretchr/testify/assert" -) - -func TestCustomProvider(t *testing.T) { - client, ctx, cancel, err := config.SetupTest() - if err != nil { - t.Fatalf("Error initializing test setup: %v", err) - } - defer cancel() - defer client.Shutdown() - - testConfig := config.ComprehensiveTestConfig{ - Provider: config.ProviderOpenAICustom, - ChatModel: "llama-3.3-70b-versatile", - TextModel: "", // OpenAI doesn't support text completion in newer models - EmbeddingModel: "", // groq custom base: embeddings not supported - Scenarios: config.TestScenarios{ - TextCompletion: false, // Not supported - SimpleChat: true, - CompletionStream: true, - 
MultiTurnConversation: true, - ToolCalls: true, - MultipleToolCalls: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: false, - ImageBase64: false, - MultipleImages: false, - CompleteEnd2End: true, - Embedding: false, - }, - } - - runAllComprehensiveTests(t, client, ctx, testConfig) -} - -func TestCustomProvider_DisallowedOperation(t *testing.T) { - // Skip test if required API key is not available - if os.Getenv("GROQ_API_KEY") == "" { - t.Skipf("skipping test: GROQ_API_KEY not set") - } - - client, ctx, cancel, err := config.SetupTest() - if err != nil { - t.Fatalf("Error initializing test setup: %v", err) - } - defer cancel() - defer client.Shutdown() - - // Create a speech request to the custom provider - prompt := "The future of artificial intelligence is" - request := &schemas.BifrostSpeechRequest{ - Provider: config.ProviderOpenAICustom, // Use the custom provider - Model: "llama-3.3-70b-versatile", // Use a model that exists for this provider - Input: &schemas.SpeechInput{ - Input: prompt, - }, - Params: &schemas.SpeechParameters{ - VoiceConfig: &schemas.SpeechVoiceInput{ - Voice: bifrost.Ptr("alloy"), - }, - ResponseFormat: "mp3", - }, - } - - // Attempt to make a speech stream request - response, bifrostErr := client.SpeechStreamRequest(ctx, request) - - // Assert that the request failed with an error - assert.NotNil(t, bifrostErr, "Expected error for disallowed speech stream operation") - assert.Nil(t, response, "Expected no response for disallowed operation") - - // Assert that the error message contains "not supported" or "not supported by openai-custom" - msg := strings.ToLower(bifrostErr.Error.Message) - assert.Contains(t, msg, "not supported", "error should indicate operation is not supported") - assert.Contains(t, msg, string(config.ProviderOpenAICustom), "error should mention refusing provider") - assert.Equal(t, config.ProviderOpenAICustom, bifrostErr.ExtraFields.Provider, "error should be attributed to the custom 
provider") -} - -func TestCustomProvider_MismatchedIdentity(t *testing.T) { - client, ctx, cancel, err := config.SetupTest() - if err != nil { - t.Fatalf("Error initializing test setup: %v", err) - } - defer cancel() - defer client.Shutdown() - - // Use a provider that doesn't exist - wrongProvider := schemas.ModelProvider("wrong-provider") - - request := &schemas.BifrostChatRequest{ - Provider: wrongProvider, - Model: "llama-3.3-70b-versatile", - Input: []schemas.ChatMessage{ - { - Role: schemas.ChatMessageRoleUser, - Content: &schemas.ChatMessageContent{ - ContentStr: bifrost.Ptr("Hello! What's the capital of France?"), - }, - }, - }, - Params: &schemas.ChatParameters{ - MaxCompletionTokens: bifrost.Ptr(100), - }, - } - - // Attempt to make a chat completion request - response, bifrostErr := client.ChatCompletionRequest(ctx, request) - - // Assert that the request failed with an error - assert.NotNil(t, bifrostErr, "Expected error for mismatched identity") - assert.Nil(t, response, "Expected no response for mismatched identity") - - msg := strings.ToLower(bifrostErr.Error.Message) - assert.Contains(t, msg, "unsupported provider", "error should mention unsupported provider") - assert.Contains(t, msg, strings.ToLower(string(wrongProvider)), "error should mention the wrong provider") - assert.Equal(t, wrongProvider, bifrostErr.ExtraFields.Provider, "error should include the unsupported provider identity") -} diff --git a/tests/core-providers/gemini_test.go b/tests/core-providers/gemini_test.go index c0d08498e..a9c731567 100644 --- a/tests/core-providers/gemini_test.go +++ b/tests/core-providers/gemini_test.go @@ -11,7 +11,7 @@ import ( func TestGemini(t *testing.T) { if os.Getenv("GEMINI_API_KEY") == "" { - t.Skip("GEMINI_API_KEY not set; skipping Gemini tests") + t.Skip("Skipping Gemini tests because GEMINI_API_KEY is not set") } client, ctx, cancel, err := config.SetupTest() @@ -19,16 +19,18 @@ func TestGemini(t *testing.T) { t.Fatalf("Error initializing test 
setup: %v", err) } defer cancel() - defer client.Shutdown() testConfig := config.ComprehensiveTestConfig{ Provider: schemas.Gemini, ChatModel: "gemini-2.0-flash", VisionModel: "gemini-2.0-flash", - TextModel: "", // Gemini doesn't support text completion EmbeddingModel: "text-embedding-004", TranscriptionModel: "gemini-2.5-flash", SpeechSynthesisModel: "gemini-2.5-flash-preview-tts", + SpeechSynthesisFallbacks: []schemas.Fallback{ + {Provider: schemas.Gemini, Model: "gemini-2.5-pro-preview-tts"}, + }, + ReasoningModel: "gemini-2.5-pro", Scenarios: config.TestScenarios{ TextCompletion: false, // Not supported SimpleChat: true, @@ -40,15 +42,19 @@ func TestGemini(t *testing.T) { AutomaticFunctionCall: true, ImageURL: false, ImageBase64: true, - MultipleImages: true, + MultipleImages: false, CompleteEnd2End: true, Embedding: true, - Transcription: true, - TranscriptionStream: true, + Transcription: false, + TranscriptionStream: false, SpeechSynthesis: true, SpeechSynthesisStream: true, + Reasoning: false, //TODO: Supported but lost since we map Gemini's responses via chat completions, fix is a native Gemini handler or reasoning support in chat completions }, } - runAllComprehensiveTests(t, client, ctx, testConfig) + t.Run("GeminiTests", func(t *testing.T) { + runAllComprehensiveTests(t, client, ctx, testConfig) + }) + client.Shutdown() } diff --git a/tests/core-providers/groq_test.go b/tests/core-providers/groq_test.go index 7313d8f2f..5cfd8662c 100644 --- a/tests/core-providers/groq_test.go +++ b/tests/core-providers/groq_test.go @@ -1,6 +1,8 @@ package tests import ( + "context" + "os" "testing" "github.com/maximhq/bifrost/tests/core-providers/config" @@ -9,18 +11,27 @@ import ( ) func TestGroq(t *testing.T) { + if os.Getenv("GROQ_API_KEY") == "" { + t.Skip("Skipping Groq tests because GROQ_API_KEY is not set") + } + client, ctx, cancel, err := config.SetupTest() if err != nil { t.Fatalf("Error initializing test setup: %v", err) } defer cancel() - defer 
client.Shutdown() testConfig := config.ComprehensiveTestConfig{ - Provider: schemas.Groq, - ChatModel: "llama-3.3-70b-versatile", - TextModel: "llama-3.3-70b-versatile", // Use same model for text completion (via conversion) - EmbeddingModel: "", // Groq doesn't support embedding + Provider: schemas.Groq, + ChatModel: "llama-3.3-70b-versatile", + Fallbacks: []schemas.Fallback{ + {Provider: schemas.Groq, Model: "openai/gpt-oss-120b"}, + }, + TextModel: "llama-3.3-70b-versatile", // Use same model for text completion (via conversion) + TextCompletionFallbacks: []schemas.Fallback{ + {Provider: schemas.Groq, Model: "openai/gpt-oss-20b"}, + }, + EmbeddingModel: "", // Groq doesn't support embedding Scenarios: config.TestScenarios{ TextCompletion: true, // Supported via chat completion conversion TextCompletionStream: true, // Supported via chat completion streaming conversion @@ -39,5 +50,10 @@ func TestGroq(t *testing.T) { }, } - runAllComprehensiveTests(t, client, ctx, testConfig) + ctx = context.WithValue(ctx, schemas.BifrostContextKey("x-litellm-fallback"), "true") + + t.Run("GroqTests", func(t *testing.T) { + runAllComprehensiveTests(t, client, ctx, testConfig) + }) + client.Shutdown() } diff --git a/tests/core-providers/mistral_test.go b/tests/core-providers/mistral_test.go index 8a05ca69d..17c241fdb 100644 --- a/tests/core-providers/mistral_test.go +++ b/tests/core-providers/mistral_test.go @@ -1,6 +1,7 @@ package tests import ( + "os" "testing" "github.com/maximhq/bifrost/tests/core-providers/config" @@ -9,19 +10,24 @@ import ( ) func TestMistral(t *testing.T) { + if os.Getenv("MISTRAL_API_KEY") == "" { + t.Skip("Skipping Mistral tests because MISTRAL_API_KEY is not set") + } + client, ctx, cancel, err := config.SetupTest() if err != nil { t.Fatalf("Error initializing test setup: %v", err) } defer cancel() - defer client.Shutdown() testConfig := config.ComprehensiveTestConfig{ - Provider: schemas.Mistral, - ChatModel: "pixtral-12b-latest", + Provider: 
schemas.Mistral, + ChatModel: "mistral-medium-2508", + Fallbacks: []schemas.Fallback{ + {Provider: schemas.Mistral, Model: "mistral-small-2503"}, + }, VisionModel: "pixtral-12b-latest", - TextModel: "", // Mistral doesn't support text completion in newer models - EmbeddingModel: "mistral-embed", + EmbeddingModel: "codestral-embed", Scenarios: config.TestScenarios{ TextCompletion: false, // Not supported SimpleChat: true, @@ -39,5 +45,8 @@ func TestMistral(t *testing.T) { }, } - runAllComprehensiveTests(t, client, ctx, testConfig) + t.Run("MistralTests", func(t *testing.T) { + runAllComprehensiveTests(t, client, ctx, testConfig) + }) + client.Shutdown() } diff --git a/tests/core-providers/ollama_test.go b/tests/core-providers/ollama_test.go index 31d2ff3ef..c43383988 100644 --- a/tests/core-providers/ollama_test.go +++ b/tests/core-providers/ollama_test.go @@ -1,6 +1,7 @@ package tests import ( + "os" "testing" "github.com/maximhq/bifrost/tests/core-providers/config" @@ -9,12 +10,15 @@ import ( ) func TestOllama(t *testing.T) { + if os.Getenv("OLLAMA_BASE_URL") == "" { + t.Skip("Skipping Ollama tests because OLLAMA_BASE_URL is not set") + } + client, ctx, cancel, err := config.SetupTest() if err != nil { t.Fatalf("Error initializing test setup: %v", err) } defer cancel() - defer client.Shutdown() testConfig := config.ComprehensiveTestConfig{ Provider: schemas.Ollama, @@ -38,5 +42,8 @@ func TestOllama(t *testing.T) { }, } - runAllComprehensiveTests(t, client, ctx, testConfig) + t.Run("OllamaTests", func(t *testing.T) { + runAllComprehensiveTests(t, client, ctx, testConfig) + }) + client.Shutdown() } diff --git a/tests/core-providers/openai_test.go b/tests/core-providers/openai_test.go index 27cc6f4f6..8bb6491f6 100644 --- a/tests/core-providers/openai_test.go +++ b/tests/core-providers/openai_test.go @@ -1,6 +1,7 @@ package tests import ( + "os" "testing" "github.com/maximhq/bifrost/tests/core-providers/config" @@ -9,20 +10,29 @@ import ( ) func TestOpenAI(t 
*testing.T) { + if os.Getenv("OPENAI_API_KEY") == "" { + t.Skip("Skipping OpenAI tests because OPENAI_API_KEY is not set") + } + client, ctx, cancel, err := config.SetupTest() if err != nil { t.Fatalf("Error initializing test setup: %v", err) } defer cancel() - defer client.Shutdown() testConfig := config.ComprehensiveTestConfig{ - Provider: schemas.OpenAI, - TextModel: "gpt-3.5-turbo-instruct", - ChatModel: "gpt-4o-mini", - VisionModel: "gpt-4o", - EmbeddingModel: "text-embedding-3-small", - TranscriptionModel: "whisper-1", + Provider: schemas.OpenAI, + TextModel: "gpt-3.5-turbo-instruct", + ChatModel: "gpt-4o-mini", + Fallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "gpt-4o"}, + }, + VisionModel: "gpt-4o", + EmbeddingModel: "text-embedding-3-small", + TranscriptionModel: "gpt-4o-transcribe", + TranscriptionFallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "whisper-1"}, + }, SpeechSynthesisModel: "gpt-4o-mini-tts", ReasoningModel: "gpt-5", Scenarios: config.TestScenarios{ @@ -48,5 +58,8 @@ func TestOpenAI(t *testing.T) { }, } - runAllComprehensiveTests(t, client, ctx, testConfig) + t.Run("OpenAITests", func(t *testing.T) { + runAllComprehensiveTests(t, client, ctx, testConfig) + }) + client.Shutdown() } diff --git a/tests/core-providers/openrouter_test.go b/tests/core-providers/openrouter_test.go index 038aeeb21..bb4b805aa 100644 --- a/tests/core-providers/openrouter_test.go +++ b/tests/core-providers/openrouter_test.go @@ -11,14 +11,14 @@ import ( func TestOpenRouter(t *testing.T) { if os.Getenv("OPENROUTER_API_KEY") == "" { - t.Skip("OPENROUTER_API_KEY not set; skipping OpenRouter tests") + t.Skip("Skipping OpenRouter tests because OPENROUTER_API_KEY is not set") } + client, ctx, cancel, err := config.SetupTest() if err != nil { t.Fatalf("Error initializing test setup: %v", err) } defer cancel() - defer client.Shutdown() testConfig := config.ComprehensiveTestConfig{ Provider: schemas.OpenRouter, @@ -33,14 +33,17 @@ func 
TestOpenRouter(t *testing.T) { MultiTurnConversation: true, ToolCalls: true, MultipleToolCalls: true, - End2EndToolCalling: true, + End2EndToolCalling: false, // OpenRouter's responses API is in Beta AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, + ImageURL: false, // OpenRouter's responses API is in Beta + ImageBase64: false, // OpenRouter's responses API is in Beta + MultipleImages: false, // OpenRouter's responses API is in Beta + CompleteEnd2End: false, // OpenRouter's responses API is in Beta }, } - runAllComprehensiveTests(t, client, ctx, testConfig) + t.Run("OpenRouterTests", func(t *testing.T) { + runAllComprehensiveTests(t, client, ctx, testConfig) + }) + client.Shutdown() } diff --git a/tests/core-providers/parasail_test.go b/tests/core-providers/parasail_test.go index 931c527bc..e8f03ee4e 100644 --- a/tests/core-providers/parasail_test.go +++ b/tests/core-providers/parasail_test.go @@ -1,6 +1,7 @@ package tests import ( + "os" "testing" "github.com/maximhq/bifrost/tests/core-providers/config" @@ -9,16 +10,19 @@ import ( ) func TestParasail(t *testing.T) { + if os.Getenv("PARASAIL_API_KEY") == "" { + t.Skip("Skipping Parasail tests because PARASAIL_API_KEY is not set") + } + client, ctx, cancel, err := config.SetupTest() if err != nil { t.Fatalf("Error initializing test setup: %v", err) } defer cancel() - defer client.Shutdown() testConfig := config.ComprehensiveTestConfig{ Provider: schemas.Parasail, - ChatModel: "parasail-deepseek-r1", + ChatModel: "Qwen/Qwen3-VL-30B-A3B-Instruct-FP8", TextModel: "", // Parasail doesn't support text completion EmbeddingModel: "", // Parasail doesn't support embedding Scenarios: config.TestScenarios{ @@ -38,5 +42,8 @@ func TestParasail(t *testing.T) { }, } - runAllComprehensiveTests(t, client, ctx, testConfig) + t.Run("ParasailTests", func(t *testing.T) { + runAllComprehensiveTests(t, client, ctx, testConfig) + }) + client.Shutdown() } diff --git 
a/tests/core-providers/scenarios/automatic_function_calling.go b/tests/core-providers/scenarios/automatic_function_calling.go index ffc29b338..fcac87c4e 100644 --- a/tests/core-providers/scenarios/automatic_function_calling.go +++ b/tests/core-providers/scenarios/automatic_function_calling.go @@ -2,6 +2,7 @@ package scenarios import ( "context" + "os" "strings" "testing" @@ -19,6 +20,10 @@ func RunAutomaticFunctionCallingTest(t *testing.T, client *bifrost.Bifrost, ctx } t.Run("AutomaticFunctionCalling", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + chatMessages := []schemas.ChatMessage{ CreateBasicChatMessage("Get the current time in UTC timezone"), } @@ -79,6 +84,7 @@ func RunAutomaticFunctionCallingTest(t *testing.T, client *bifrost.Bifrost, ctx }, }, }, + Fallbacks: testConfig.Fallbacks, } return client.ChatCompletionRequest(ctx, chatReq) @@ -100,6 +106,7 @@ func RunAutomaticFunctionCallingTest(t *testing.T, client *bifrost.Bifrost, ctx }, }, }, + Fallbacks: testConfig.Fallbacks, } return client.ResponsesRequest(ctx, responsesReq) diff --git a/tests/core-providers/scenarios/chat_completion_stream.go b/tests/core-providers/scenarios/chat_completion_stream.go index 8023b3b6b..68f41640a 100644 --- a/tests/core-providers/scenarios/chat_completion_stream.go +++ b/tests/core-providers/scenarios/chat_completion_stream.go @@ -2,6 +2,7 @@ package scenarios import ( "context" + "os" "strings" "testing" "time" @@ -20,8 +21,12 @@ func RunChatCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont } t.Run("ChatCompletionStream", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + messages := []schemas.ChatMessage{ - CreateBasicChatMessage("Tell me a short story about a robot learning to paint the city which has the eiffel tower. Keep it under 200 words."), + CreateBasicChatMessage("Tell me a short story about a robot learning to paint the city which has the eiffel tower. 
Keep it under 200 words and include the city's name."), } request := &schemas.BifrostChatRequest{ @@ -83,7 +88,7 @@ func RunChatCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont if response == nil { t.Fatal("Streaming response should not be nil") } - lastResponse = response + lastResponse = DeepCopyBifrostStream(response) // Basic validation of streaming response structure if response.BifrostChatResponse != nil { @@ -94,6 +99,9 @@ func RunChatCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont t.Logf("⚠️ Warning: Response ID is empty") } + // Log latency for each chunk (can be 0 for inter-chunks) + t.Logf("πŸ“Š Chunk %d latency: %d ms", responseCount+1, response.BifrostChatResponse.ExtraFields.Latency) + // Process each choice in the response for _, choice := range response.BifrostChatResponse.Choices { // Validate that this is a stream response @@ -106,7 +114,7 @@ func RunChatCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont } // Get content from delta - if choice.ChatStreamResponseChoice != nil && choice.ChatStreamResponseChoice.Delta != nil { + if choice.ChatStreamResponseChoice != nil && choice.ChatStreamResponseChoice.Delta != nil { delta := choice.ChatStreamResponseChoice.Delta if delta.Content != nil { fullContent.WriteString(*delta.Content) @@ -172,14 +180,15 @@ func RunChatCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont if len(lastResponse.BifrostChatResponse.Choices) > 0 && lastResponse.BifrostChatResponse.Choices[0].FinishReason != nil { consolidatedResponse.Choices[0].FinishReason = lastResponse.BifrostChatResponse.Choices[0].FinishReason } + consolidatedResponse.ExtraFields.Latency = lastResponse.BifrostChatResponse.ExtraFields.Latency } // Enhanced validation expectations for streaming expectations := GetExpectationsForScenario("ChatCompletionStream", testConfig, map[string]interface{}{}) expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider) - 
expectations.ShouldContainKeywords = append(expectations.ShouldContainKeywords, []string{"paris"}...) // Should include story elements - expectations.MinContentLength = 50 // Should be substantial story - expectations.MaxContentLength = 2000 // Reasonable upper bound + expectations.ShouldContainAnyOf = append(expectations.ShouldContainAnyOf, []string{"paris"}...) // Should include story elements + expectations.MinContentLength = 50 // Should be substantial story + expectations.MaxContentLength = 2000 // Reasonable upper bound // Validate the consolidated streaming response validationResult := ValidateChatResponse(t, consolidatedResponse, nil, expectations, "ChatCompletionStream") @@ -198,7 +207,7 @@ func RunChatCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont } if !validationResult.Passed { - t.Logf("⚠️ Streaming validation warnings: %v", validationResult.Errors) + t.Errorf("❌ Streaming validation failed: %v", validationResult.Errors) } t.Logf("πŸ“Š Streaming metrics: %d chunks, %d chars", responseCount, len(finalContent)) @@ -210,8 +219,12 @@ func RunChatCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont // Test streaming with tool calls if supported if testConfig.Scenarios.ToolCalls { t.Run("ChatCompletionStreamWithTools", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + messages := []schemas.ChatMessage{ - CreateBasicChatMessage("What's the weather like in San Francisco? Please use the get_weather function."), + CreateBasicChatMessage("What's the weather like in San Francisco in celsius? 
Please use the get_weather function."), } tool := GetSampleChatTool(SampleToolTypeWeather) diff --git a/tests/core-providers/scenarios/complete_end_to_end.go b/tests/core-providers/scenarios/complete_end_to_end.go index 1137b8921..fc2db4acb 100644 --- a/tests/core-providers/scenarios/complete_end_to_end.go +++ b/tests/core-providers/scenarios/complete_end_to_end.go @@ -2,6 +2,7 @@ package scenarios import ( "context" + "os" "strings" "testing" @@ -19,6 +20,10 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C } t.Run("CompleteEnd2End", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + // ============================================================================= // STEP 1: Multi-step conversation with tools - Test both APIs in parallel // ============================================================================= @@ -129,9 +134,7 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C // Add all output messages to Responses API conversation history if result1.ResponsesAPIResponse != nil && result1.ResponsesAPIResponse.Output != nil { - for _, output := range result1.ResponsesAPIResponse.Output { - responsesConversationHistory = append(responsesConversationHistory, output) - } + responsesConversationHistory = append(responsesConversationHistory, result1.ResponsesAPIResponse.Output...) 
} // Extract tool calls from both APIs @@ -164,7 +167,104 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C } // ============================================================================= - // STEP 2: Continue with follow-up (multimodal if supported) - Test both APIs + // STEP 2: Send this tool call result to the model again + // ============================================================================= + + // Use retry framework for step 2 (processing tool results) + retryConfig2 := GetTestRetryConfigForScenario("CompleteEnd2End_ToolResult", testConfig) + retryContext2 := TestRetryContext{ + ScenarioName: "CompleteEnd2End_Step2", + ExpectedBehavior: map[string]interface{}{ + "process_tool_result": true, + "acknowledge_weather": true, + "continue_conversation": true, + }, + TestMetadata: map[string]interface{}{ + "provider": testConfig.Provider, + "model": testConfig.ChatModel, + "step": "process_tool_result", + "scenario": "complete_end_to_end", + "chat_conversation_length": len(chatConversationHistory), + "responses_conversation_length": len(responsesConversationHistory), + }, + } + + // Enhanced validation for step 2 - should acknowledge tool results + expectations2 := ConversationExpectations([]string{"weather", "temperature"}) + expectations2 = ModifyExpectationsForProvider(expectations2, testConfig.Provider) + expectations2.MinContentLength = 15 // Should provide meaningful response to tool result + expectations2.MaxContentLength = 500 // Reasonable upper bound for tool result processing + expectations2.ShouldNotContainWords = []string{ + "cannot help", "don't understand", "no information", + "unable to process", "invalid tool result", + } // Should not indicate confusion about tool results + + // Create operations for both APIs - Step 2 (processing tool results) + chatOperation2 := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { + chatReq := &schemas.BifrostChatRequest{ + Provider: testConfig.Provider, + 
Model: testConfig.ChatModel, + Input: chatConversationHistory, + Params: &schemas.ChatParameters{ + MaxCompletionTokens: bifrost.Ptr(200), + }, + Fallbacks: testConfig.Fallbacks, + } + return client.ChatCompletionRequest(ctx, chatReq) + } + + responsesOperation2 := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + responsesReq := &schemas.BifrostResponsesRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: responsesConversationHistory, + Params: &schemas.ResponsesParameters{ + MaxOutputTokens: bifrost.Ptr(200), + }, + } + return client.ResponsesRequest(ctx, responsesReq) + } + + // Execute dual API test for Step 2 (processing tool results) + result2 := WithDualAPITestRetry(t, + retryConfig2, + retryContext2, + expectations2, + "CompleteEnd2End_Step2", + chatOperation2, + responsesOperation2) + + // Validate both APIs succeeded + if !result2.BothSucceeded { + var errors []string + if result2.ChatCompletionsError != nil { + errors = append(errors, "Chat Completions: "+GetErrorMessage(result2.ChatCompletionsError)) + } + if result2.ResponsesAPIError != nil { + errors = append(errors, "Responses API: "+GetErrorMessage(result2.ResponsesAPIError)) + } + if len(errors) == 0 { + errors = append(errors, "One or both APIs failed validation (see logs above)") + } + t.Fatalf("❌ CompleteEnd2End_Step2 dual API test failed: %v", errors) + } + + t.Logf("βœ… Chat Completions API tool result response: %s", GetChatContent(result2.ChatCompletionsResponse)) + t.Logf("βœ… Responses API tool result response: %s", GetResponsesContent(result2.ResponsesAPIResponse)) + + // Add Step 2 responses to conversation histories for Step 3 + if result2.ChatCompletionsResponse.Choices != nil { + for _, choice := range result2.ChatCompletionsResponse.Choices { + chatConversationHistory = append(chatConversationHistory, *choice.Message) + } + } + + if result2.ResponsesAPIResponse != nil && result2.ResponsesAPIResponse.Output != nil { + 
responsesConversationHistory = append(responsesConversationHistory, result2.ResponsesAPIResponse.Output...) + } + + // ============================================================================= + // STEP 3: Continue with follow-up (multimodal if supported) - Test both APIs // ============================================================================= // Determine if we're doing a vision step @@ -191,15 +291,15 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C } // Use appropriate retry config for final step - var retryConfig2 TestRetryConfig - var expectations2 ResponseExpectations + var retryConfig3 TestRetryConfig + var expectations3 ResponseExpectations if isVisionStep { - retryConfig2 = GetTestRetryConfigForScenario("CompleteEnd2End_Vision", testConfig) - expectations2 = VisionExpectations([]string{"paris", "river"}) + retryConfig3 = GetTestRetryConfigForScenario("CompleteEnd2End_Vision", testConfig) + expectations3 = VisionExpectations([]string{"paris", "river"}) } else { - retryConfig2 = GetTestRetryConfigForScenario("CompleteEnd2End_Chat", testConfig) - expectations2 = ConversationExpectations([]string{"paris", "cloudy"}) + retryConfig3 = GetTestRetryConfigForScenario("CompleteEnd2End_Chat", testConfig) + expectations3 = ConversationExpectations([]string{"paris", "cloudy"}) } // Prepare expected keywords to match expectations exactly @@ -210,8 +310,8 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C expectedKeywords = []string{"paris", "cloudy"} // Must match ConversationExpectations exactly } - retryContext2 := TestRetryContext{ - ScenarioName: "CompleteEnd2End_Step2", + retryContext3 := TestRetryContext{ + ScenarioName: "CompleteEnd2End_Step3", ExpectedBehavior: map[string]interface{}{ "continue_conversation": true, "acknowledge_context": true, @@ -229,16 +329,16 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C } // Enhanced validation for final 
response - expectations2 = ModifyExpectationsForProvider(expectations2, testConfig.Provider) - expectations2.MinContentLength = 20 // Should provide some meaningful response - expectations2.MaxContentLength = 800 // End-to-end can be verbose - expectations2.ShouldNotContainWords = []string{ + expectations3 = ModifyExpectationsForProvider(expectations3, testConfig.Provider) + expectations3.MinContentLength = 20 // Should provide some meaningful response + expectations3.MaxContentLength = 800 // End-to-end can be verbose + expectations3.ShouldNotContainWords = []string{ "cannot help", "don't understand", "confused", "start over", "reset conversation", } // Context loss indicators - // Create operations for both APIs - Step 2 - chatOperation2 := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { + // Create operations for both APIs - Step 3 + chatOperation3 := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { chatReq := &schemas.BifrostChatRequest{ Provider: testConfig.Provider, Model: model, @@ -251,7 +351,7 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C return client.ChatCompletionRequest(ctx, chatReq) } - responsesOperation2 := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + responsesOperation3 := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { responsesReq := &schemas.BifrostResponsesRequest{ Provider: testConfig.Provider, Model: model, @@ -263,33 +363,33 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C return client.ResponsesRequest(ctx, responsesReq) } - // Execute dual API test for Step 2 - result2 := WithDualAPITestRetry(t, - retryConfig2, - retryContext2, - expectations2, - "CompleteEnd2End_Step2", - chatOperation2, - responsesOperation2) + // Execute dual API test for Step 3 + result3 := WithDualAPITestRetry(t, + retryConfig3, + retryContext3, + expectations3, + "CompleteEnd2End_Step3", + chatOperation3, + 
responsesOperation3) // Validate both APIs succeeded - if !result2.BothSucceeded { + if !result3.BothSucceeded { var errors []string - if result2.ChatCompletionsError != nil { - errors = append(errors, "Chat Completions: "+GetErrorMessage(result2.ChatCompletionsError)) + if result3.ChatCompletionsError != nil { + errors = append(errors, "Chat Completions: "+GetErrorMessage(result3.ChatCompletionsError)) } - if result2.ResponsesAPIError != nil { - errors = append(errors, "Responses API: "+GetErrorMessage(result2.ResponsesAPIError)) + if result3.ResponsesAPIError != nil { + errors = append(errors, "Responses API: "+GetErrorMessage(result3.ResponsesAPIError)) } if len(errors) == 0 { errors = append(errors, "One or both APIs failed validation (see logs above)") } - t.Fatalf("❌ CompleteEnd2End_Step2 dual API test failed: %v", errors) + t.Fatalf("❌ CompleteEnd2End_Step3 dual API test failed: %v", errors) } // Log and validate results from both APIs - if result2.ChatCompletionsResponse != nil { - chatFinalContent := GetChatContent(result2.ChatCompletionsResponse) + if result3.ChatCompletionsResponse != nil { + chatFinalContent := GetChatContent(result3.ChatCompletionsResponse) // Additional validation for conversation context if len(chatToolCalls) > 0 && strings.Contains(strings.ToLower(chatFinalContent), "weather") { @@ -303,8 +403,8 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C t.Logf("βœ… Chat Completions API final result: %s", chatFinalContent) } - if result2.ResponsesAPIResponse != nil { - responsesFinalContent := GetResponsesContent(result2.ResponsesAPIResponse) + if result3.ResponsesAPIResponse != nil { + responsesFinalContent := GetResponsesContent(result3.ResponsesAPIResponse) // Additional validation for conversation context if len(responsesToolCalls) > 0 && strings.Contains(strings.ToLower(responsesFinalContent), "weather") { diff --git a/tests/core-providers/scenarios/cross_provider_scenarios.go 
b/tests/core-providers/scenarios/cross_provider_scenarios.go index 55b2beb97..70c5504c1 100644 --- a/tests/core-providers/scenarios/cross_provider_scenarios.go +++ b/tests/core-providers/scenarios/cross_provider_scenarios.go @@ -697,7 +697,7 @@ func RunCrossProviderScenarioTest(t *testing.T, client *bifrost.Bifrost, ctx con responseContent = GetResponsesContent(response.ResponsesResponse) } else { if response.ChatResponse != nil { - // Use Chat API choices + // Use Chat API choices for _, choice := range response.ChatResponse.Choices { if choice.Message != nil { conversationHistory = append(conversationHistory, *choice.Message) @@ -798,7 +798,6 @@ func RunCrossProviderConsistencyTest(t *testing.T, client *bifrost.Bifrost, ctx t.Logf("Testing %s...", provider.Provider) - var err *schemas.BifrostError var content string if useResponsesAPI { @@ -841,11 +840,6 @@ func RunCrossProviderConsistencyTest(t *testing.T, client *bifrost.Bifrost, ctx content = GetChatContent(chatResponse) } - if err != nil { - t.Logf("❌ %s failed: %v", provider.Provider, GetErrorMessage(err)) - continue - } - sentences := strings.Split(strings.TrimSpace(content), ".") result := ConsistencyResult{ diff --git a/tests/core-providers/scenarios/embedding.go b/tests/core-providers/scenarios/embedding.go index ac2bf1eee..2949dbef0 100644 --- a/tests/core-providers/scenarios/embedding.go +++ b/tests/core-providers/scenarios/embedding.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "math" + "os" "strings" "testing" @@ -48,6 +49,10 @@ func RunEmbeddingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context } t.Run("Embedding", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + // Test texts with expected semantic relationships testTexts := []string{ "Hello, world!", @@ -64,7 +69,7 @@ func RunEmbeddingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context Params: &schemas.EmbeddingParameters{ EncodingFormat: bifrost.Ptr("float"), }, - Fallbacks: 
testConfig.Fallbacks, + Fallbacks: testConfig.EmbeddingFallbacks, } // Enhanced embedding validation @@ -96,11 +101,17 @@ func validateEmbeddingSemantics(t *testing.T, response *schemas.BifrostEmbedding // Extract and validate embeddings embeddings := make([][]float32, len(testTexts)) - if len(response.Data) != len(testTexts) { - t.Fatalf("Expected %d embedding results, got %d", len(testTexts), len(response.Data)) + responseDataLength := len(response.Data) + if responseDataLength != len(testTexts) { + if responseDataLength > 0 && response.Data[0].Embedding.Embedding2DArray != nil { + responseDataLength = len(response.Data[0].Embedding.Embedding2DArray) + } + if responseDataLength != len(testTexts) { + t.Fatalf("Expected %d embedding results, got %d", len(testTexts), responseDataLength) + } } - for i := range response.Data { + for i := range responseDataLength { vec, extractErr := getEmbeddingVector(response.Data[i]) if extractErr != nil { t.Fatalf("Failed to extract embedding vector for text '%s': %v", testTexts[i], extractErr) diff --git a/tests/core-providers/scenarios/end_to_end_tool_calling.go b/tests/core-providers/scenarios/end_to_end_tool_calling.go index fb566cdab..a0f304b9b 100644 --- a/tests/core-providers/scenarios/end_to_end_tool_calling.go +++ b/tests/core-providers/scenarios/end_to_end_tool_calling.go @@ -2,6 +2,7 @@ package scenarios import ( "context" + "os" "strings" "testing" @@ -19,6 +20,10 @@ func RunEnd2EndToolCallingTest(t *testing.T, client *bifrost.Bifrost, ctx contex } t.Run("End2EndToolCalling", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + // ============================================================================= // STEP 1: User asks for weather - Test both APIs in parallel // ============================================================================= @@ -164,11 +169,11 @@ func RunEnd2EndToolCallingTest(t *testing.T, client *bifrost.Bifrost, ctx contex } // Enhanced validation for 
final response - expectations2 := ConversationExpectations([]string{"san francisco", "22", "sunny"}) + expectations2 := ConversationExpectations([]string{"francisco", "22", "sunny"}) expectations2 = ModifyExpectationsForProvider(expectations2, testConfig.Provider) - expectations2.ShouldContainKeywords = []string{"san francisco", "22", "sunny"} // Should reference tool results - expectations2.ShouldNotContainWords = []string{"error", "failed", "cannot"} // Should not contain error terms - expectations2.MinContentLength = 30 // Should be a substantial response + expectations2.ShouldContainKeywords = []string{"francisco", "22", "sunny"} // Should reference tool results (using "francisco" to match both "San Francisco" and "san francisco") + expectations2.ShouldNotContainWords = []string{"error", "failed", "cannot"} // Should not contain error terms + expectations2.MinContentLength = 30 // Should be a substantial response // Create operations for both APIs - Step 2 chatOperation2 := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { diff --git a/tests/core-providers/scenarios/image_base64.go b/tests/core-providers/scenarios/image_base64.go index d8b12fec0..025af5fd7 100644 --- a/tests/core-providers/scenarios/image_base64.go +++ b/tests/core-providers/scenarios/image_base64.go @@ -2,6 +2,7 @@ package scenarios import ( "context" + "os" "strings" "testing" @@ -19,6 +20,10 @@ func RunImageBase64Test(t *testing.T, client *bifrost.Bifrost, ctx context.Conte } t.Run("ImageBase64", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + // Load lion base64 image for testing lionBase64, err := GetLionBase64Image() if err != nil { @@ -142,7 +147,7 @@ func validateBase64ImageContent(t *testing.T, content string, apiName string) { lowerContent := strings.ToLower(content) foundAnimal := strings.Contains(lowerContent, "lion") || strings.Contains(lowerContent, "animal") || strings.Contains(lowerContent, "cat") || 
strings.Contains(lowerContent, "feline") - + if len(content) < 10 { t.Logf("⚠️ %s response seems quite short for image description: %s", apiName, content) } else if foundAnimal { diff --git a/tests/core-providers/scenarios/image_url.go b/tests/core-providers/scenarios/image_url.go index 591c22d9f..7c3abd201 100644 --- a/tests/core-providers/scenarios/image_url.go +++ b/tests/core-providers/scenarios/image_url.go @@ -2,6 +2,7 @@ package scenarios import ( "context" + "os" "strings" "testing" @@ -19,6 +20,10 @@ func RunImageURLTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, } t.Run("ImageURL", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + // Create messages for both APIs using the isResponsesAPI flag chatMessages := []schemas.ChatMessage{ CreateImageChatMessage("What do you see in this image?", TestImageURL), diff --git a/tests/core-providers/scenarios/multi_turn_conversation.go b/tests/core-providers/scenarios/multi_turn_conversation.go index 9fc75dd66..ca013b9f6 100644 --- a/tests/core-providers/scenarios/multi_turn_conversation.go +++ b/tests/core-providers/scenarios/multi_turn_conversation.go @@ -2,6 +2,7 @@ package scenarios import ( "context" + "os" "strings" "testing" @@ -19,6 +20,10 @@ func RunMultiTurnConversationTest(t *testing.T, client *bifrost.Bifrost, ctx con } t.Run("MultiTurnConversation", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + // First message - introduction userMessage1 := CreateBasicChatMessage("Hello, my name is Alice.") messages1 := []schemas.ChatMessage{ diff --git a/tests/core-providers/scenarios/multiple_images.go b/tests/core-providers/scenarios/multiple_images.go index 5165c4b03..841f21d7b 100644 --- a/tests/core-providers/scenarios/multiple_images.go +++ b/tests/core-providers/scenarios/multiple_images.go @@ -2,6 +2,7 @@ package scenarios import ( "context" + "os" "strings" "testing" @@ -19,6 +20,10 @@ func 
RunMultipleImagesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Co } t.Run("MultipleImages", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + // Load lion base64 image for comparison lionBase64, err := GetLionBase64Image() if err != nil { diff --git a/tests/core-providers/scenarios/multiple_tool_calls.go b/tests/core-providers/scenarios/multiple_tool_calls.go index a52787856..be20090bd 100644 --- a/tests/core-providers/scenarios/multiple_tool_calls.go +++ b/tests/core-providers/scenarios/multiple_tool_calls.go @@ -2,6 +2,7 @@ package scenarios import ( "context" + "os" "testing" "github.com/maximhq/bifrost/tests/core-providers/config" @@ -27,6 +28,10 @@ func RunMultipleToolCallsTest(t *testing.T, client *bifrost.Bifrost, ctx context } t.Run("MultipleToolCalls", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + chatMessages := []schemas.ChatMessage{ CreateBasicChatMessage("I need to know the weather in London and also calculate 15 * 23. 
Can you help with both?"), } @@ -77,6 +82,7 @@ func RunMultipleToolCallsTest(t *testing.T, client *bifrost.Bifrost, ctx context Params: &schemas.ChatParameters{ Tools: []schemas.ChatTool{*chatWeatherTool, *chatCalculatorTool}, }, + Fallbacks: testConfig.Fallbacks, } chatReq.Input = chatMessages return client.ChatCompletionRequest(ctx, chatReq) @@ -89,6 +95,7 @@ func RunMultipleToolCallsTest(t *testing.T, client *bifrost.Bifrost, ctx context Params: &schemas.ResponsesParameters{ Tools: []schemas.ResponsesTool{*responsesWeatherTool, *responsesCalculatorTool}, }, + Fallbacks: testConfig.Fallbacks, } responsesReq.Input = responsesMessages return client.ResponsesRequest(ctx, responsesReq) diff --git a/tests/core-providers/scenarios/reasoning.go b/tests/core-providers/scenarios/reasoning.go index 797f79c8a..c35322fcb 100644 --- a/tests/core-providers/scenarios/reasoning.go +++ b/tests/core-providers/scenarios/reasoning.go @@ -2,6 +2,7 @@ package scenarios import ( "context" + "os" "testing" "github.com/maximhq/bifrost/tests/core-providers/config" @@ -24,6 +25,10 @@ func RunReasoningTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context } t.Run("Reasoning", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + // Create a complex problem that requires step-by-step reasoning problemPrompt := "A farmer has 100 chickens and 50 cows. Each chicken lays 5 eggs per week, and each cow produces 20 liters of milk per day. If the farmer sells eggs for $0.25 each and milk for $1.50 per liter, and it costs $2 per week to feed each chicken and $15 per week to feed each cow, what is the farmer's weekly profit? Please show your step-by-step reasoning." 
@@ -46,6 +51,7 @@ func RunReasoningTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context // Include reasoning content in response Include: []string{"reasoning.encrypted_content"}, }, + Fallbacks: testConfig.Fallbacks, } // Use retry framework with enhanced validation for reasoning @@ -54,7 +60,6 @@ func RunReasoningTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context ScenarioName: "Reasoning", ExpectedBehavior: map[string]interface{}{ "should_show_reasoning": true, - "should_calculate": true, "mathematical_problem": true, "step_by_step": true, }, @@ -83,9 +88,8 @@ func RunReasoningTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider) expectations.MinContentLength = 50 // Reasoning requires substantial content expectations.MaxContentLength = 2000 // Reasoning can be verbose - expectations.ShouldContainKeywords = []string{"weekly", "profit", "$"} expectations.ShouldNotContainWords = append(expectations.ShouldNotContainWords, []string{ - "cannot solve", "unable to calculate", "need more information", + "cannot solve", "unable to calculate", }...) 
response, responsesError := WithResponsesTestRetry(t, responsesRetryConfig, retryContext, expectations, "Reasoning", func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { diff --git a/tests/core-providers/scenarios/response_validation.go b/tests/core-providers/scenarios/response_validation.go index 729971cb2..408aa5501 100644 --- a/tests/core-providers/scenarios/response_validation.go +++ b/tests/core-providers/scenarios/response_validation.go @@ -38,6 +38,7 @@ type ResponseExpectations struct { ShouldHaveUsageStats bool // Should have token usage information ShouldHaveTimestamps bool // Should have created timestamp ShouldHaveModel bool // Should have model field + ShouldHaveLatency bool // Should have latency information in ExtraFields // Provider-specific expectations ProviderSpecific map[string]interface{} // Provider-specific validation data @@ -488,6 +489,16 @@ func validateChatTechnicalFields(t *testing.T, response *schemas.BifrostChatResp result.Warnings = append(result.Warnings, "Expected model field but not present or empty") } } + + // Check latency field + if expectations.ShouldHaveLatency { + if response.ExtraFields.Latency <= 0 { + result.Passed = false + result.Errors = append(result.Errors, "Expected latency information but not present or invalid") + } else { + result.MetricsCollected["latency_ms"] = response.ExtraFields.Latency + } + } } // collectChatResponseMetrics collects metrics from the chat response for analysis @@ -651,6 +662,16 @@ func validateTextCompletionTechnicalFields(t *testing.T, response *schemas.Bifro result.Warnings = append(result.Warnings, "Expected model field but not present or empty") } } + + // Check latency field + if expectations.ShouldHaveLatency { + if response.ExtraFields.Latency <= 0 { + result.Passed = false + result.Errors = append(result.Errors, "Expected latency information but not present or invalid") + } else { + result.MetricsCollected["latency_ms"] = response.ExtraFields.Latency + } + } } // 
collectTextCompletionResponseMetrics collects metrics from the text completion response for analysis @@ -818,6 +839,16 @@ func validateResponsesTechnicalFields(t *testing.T, response *schemas.BifrostRes result.Warnings = append(result.Warnings, "Expected created timestamp but not present") } } + + // Check latency field + if expectations.ShouldHaveLatency { + if response.ExtraFields.Latency <= 0 { + result.Passed = false + result.Errors = append(result.Errors, "Expected latency information but not present or invalid") + } else { + result.MetricsCollected["latency_ms"] = response.ExtraFields.Latency + } + } } // collectResponsesResponseMetrics collects metrics from the Responses API response for analysis @@ -875,6 +906,16 @@ func validateSpeechSynthesisResponse(t *testing.T, response *schemas.BifrostSpee result.MetricsCollected["expected_audio_format"] = expectedFormat } + // Check latency field + if expectations.ShouldHaveLatency { + if response.ExtraFields.Latency <= 0 { + result.Passed = false + result.Errors = append(result.Errors, "Expected latency information but not present or invalid") + } else { + result.MetricsCollected["latency_ms"] = response.ExtraFields.Latency + } + } + result.MetricsCollected["speech_validation"] = "completed" } @@ -930,6 +971,16 @@ func validateTranscriptionFields(t *testing.T, response *schemas.BifrostTranscri result.MetricsCollected["audio_duration"] = *response.Duration } + // Check latency field + if expectations.ShouldHaveLatency { + if response.ExtraFields.Latency <= 0 { + result.Passed = false + result.Errors = append(result.Errors, "Expected latency information but not present or invalid") + } else { + result.MetricsCollected["latency_ms"] = response.ExtraFields.Latency + } + } + result.MetricsCollected["transcription_validation"] = "completed" } @@ -948,7 +999,7 @@ func collectTranscriptionResponseMetrics(response *schemas.BifrostTranscriptionR // validateEmbeddingFields validates embedding responses func 
validateEmbeddingFields(t *testing.T, response *schemas.BifrostEmbeddingResponse, expectations ResponseExpectations, result *ValidationResult) { // Check if response has embedding data - if response.Data == nil || len(response.Data) == 0 { + if len(response.Data) == 0 { result.Passed = false result.Errors = append(result.Errors, "Embedding response missing data") return @@ -973,6 +1024,16 @@ func validateEmbeddingFields(t *testing.T, response *schemas.BifrostEmbeddingRes } } + // Check latency field + if expectations.ShouldHaveLatency { + if response.ExtraFields.Latency <= 0 { + result.Passed = false + result.Errors = append(result.Errors, "Expected latency information but not present or invalid") + } else { + result.MetricsCollected["latency_ms"] = response.ExtraFields.Latency + } + } + result.MetricsCollected["embedding_validation"] = "completed" } @@ -981,11 +1042,11 @@ func collectEmbeddingResponseMetrics(response *schemas.BifrostEmbeddingResponse, result.MetricsCollected["has_data"] = response.Data != nil result.MetricsCollected["embedding_count"] = len(response.Data) result.MetricsCollected["has_usage"] = response.Usage != nil - if response.Data != nil && len(response.Data) > 0 { + if len(response.Data) > 0 { var dimensions int if response.Data[0].Embedding.EmbeddingArray != nil { dimensions = len(response.Data[0].Embedding.EmbeddingArray) - } else if response.Data[0].Embedding.Embedding2DArray != nil && len(response.Data[0].Embedding.Embedding2DArray) > 0 { + } else if len(response.Data[0].Embedding.Embedding2DArray) > 0 { dimensions = len(response.Data[0].Embedding.Embedding2DArray[0]) } result.MetricsCollected["embedding_dimensions"] = dimensions @@ -1224,7 +1285,7 @@ func logValidationResults(t *testing.T, result ValidationResult, scenarioName st if result.Passed { t.Logf("βœ… Validation passed for %s", scenarioName) } else { - t.Logf("❌ Validation failed for %s with %d errors", scenarioName, len(result.Errors)) + t.Errorf("❌ Validation failed for %s with 
%d errors", scenarioName, len(result.Errors)) for _, err := range result.Errors { t.Logf(" Error: %s", err) } diff --git a/tests/core-providers/scenarios/responses_stream.go b/tests/core-providers/scenarios/responses_stream.go index 3404a1772..eef7f73ae 100644 --- a/tests/core-providers/scenarios/responses_stream.go +++ b/tests/core-providers/scenarios/responses_stream.go @@ -3,6 +3,7 @@ package scenarios import ( "context" "fmt" + "os" "strings" "testing" "time" @@ -21,6 +22,10 @@ func RunResponsesStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.C } t.Run("ResponsesStream", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + messages := []schemas.ResponsesMessage{ { Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), @@ -97,7 +102,7 @@ func RunResponsesStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.C if response == nil { t.Fatal("Streaming response should not be nil") } - lastResponse = response + lastResponse = DeepCopyBifrostStream(response) // Basic validation of streaming response structure if response.BifrostResponsesStreamResponse != nil { @@ -105,6 +110,9 @@ func RunResponsesStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.C t.Logf("⚠️ Warning: Provider mismatch - expected %s, got %s", testConfig.Provider, response.BifrostResponsesStreamResponse.ExtraFields.Provider) } + // Log latency for each chunk (can be 0 for inter-chunks) + t.Logf("πŸ“Š Chunk %d latency: %d ms", responseCount+1, response.BifrostResponsesStreamResponse.ExtraFields.Latency) + // Process the streaming response streamResp := response.BifrostResponsesStreamResponse @@ -229,11 +237,15 @@ func RunResponsesStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.C // Test responses streaming with tool calls if supported if testConfig.Scenarios.ToolCalls { t.Run("ResponsesStreamWithTools", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + messages := 
[]schemas.ResponsesMessage{ { Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), Content: &schemas.ResponsesMessageContent{ - ContentStr: schemas.Ptr("What's the weather like in San Francisco? Please use the get_weather function."), + ContentStr: schemas.Ptr("What's the weather like in San Francisco in celsius? Please use the get_weather function."), }, }, } @@ -355,6 +367,10 @@ func RunResponsesStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.C // Test responses streaming with reasoning if supported if testConfig.Scenarios.Reasoning && testConfig.ReasoningModel != "" { t.Run("ResponsesStreamWithReasoning", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + problemPrompt := "Solve this step by step: If a train leaves station A at 2 PM traveling at 60 mph, and another train leaves station B at 3 PM traveling at 80 mph toward station A, and the stations are 420 miles apart, when will they meet?" messages := []schemas.ResponsesMessage{ @@ -503,7 +519,6 @@ func validateResponsesStreamingStructure(t *testing.T, eventTypes map[schemas.Re } } - // StreamingValidationResult represents the result of streaming validation type StreamingValidationResult struct { Passed bool @@ -593,6 +608,15 @@ func validateResponsesStreamingResponse(t *testing.T, eventTypes map[schemas.Res } } + // Validate latency is present in the last chunk (total latency) + if lastResponse != nil && lastResponse.BifrostResponsesStreamResponse != nil { + if lastResponse.BifrostResponsesStreamResponse.ExtraFields.Latency <= 0 { + errors = append(errors, fmt.Sprintf("Last streaming chunk missing latency information (got %d ms)", lastResponse.BifrostResponsesStreamResponse.ExtraFields.Latency)) + } else { + t.Logf("βœ… Total streaming latency: %d ms", lastResponse.BifrostResponsesStreamResponse.ExtraFields.Latency) + } + } + return StreamingValidationResult{ Passed: len(errors) == 0, Errors: errors, diff --git 
a/tests/core-providers/scenarios/simple_chat.go b/tests/core-providers/scenarios/simple_chat.go index b8a8d2acc..17d601b82 100644 --- a/tests/core-providers/scenarios/simple_chat.go +++ b/tests/core-providers/scenarios/simple_chat.go @@ -2,6 +2,7 @@ package scenarios import ( "context" + "os" "testing" "github.com/maximhq/bifrost/tests/core-providers/config" @@ -18,6 +19,10 @@ func RunSimpleChatTest(t *testing.T, client *bifrost.Bifrost, ctx context.Contex } t.Run("SimpleChat", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + chatMessages := []schemas.ChatMessage{ CreateBasicChatMessage("Hello! What's the capital of France?"), } @@ -74,6 +79,7 @@ func RunSimpleChatTest(t *testing.T, client *bifrost.Bifrost, ctx context.Contex Params: &schemas.ChatParameters{ MaxCompletionTokens: bifrost.Ptr(150), }, + Fallbacks: testConfig.Fallbacks, } response, err := client.ChatCompletionRequest(ctx, chatReq) if err != nil { @@ -101,6 +107,7 @@ func RunSimpleChatTest(t *testing.T, client *bifrost.Bifrost, ctx context.Contex Params: &schemas.ResponsesParameters{ MaxOutputTokens: bifrost.Ptr(150), }, + Fallbacks: testConfig.Fallbacks, } response, err := client.ResponsesRequest(ctx, responsesReq) if err != nil { diff --git a/tests/core-providers/scenarios/speech_synthesis.go b/tests/core-providers/scenarios/speech_synthesis.go index 530e3bff0..6c1cc7b45 100644 --- a/tests/core-providers/scenarios/speech_synthesis.go +++ b/tests/core-providers/scenarios/speech_synthesis.go @@ -21,6 +21,10 @@ func RunSpeechSynthesisTest(t *testing.T, client *bifrost.Bifrost, ctx context.C } t.Run("SpeechSynthesis", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + // Test with shared text constants for round-trip validation with transcription testCases := []struct { name string @@ -58,6 +62,10 @@ func RunSpeechSynthesisTest(t *testing.T, client *bifrost.Bifrost, ctx context.C for _, tc := range testCases { 
t.Run(tc.name, func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + voice := GetProviderVoice(testConfig.Provider, tc.voiceType) request := &schemas.BifrostSpeechRequest{ Provider: testConfig.Provider, @@ -71,14 +79,16 @@ func RunSpeechSynthesisTest(t *testing.T, client *bifrost.Bifrost, ctx context.C }, ResponseFormat: tc.format, }, - Fallbacks: testConfig.Fallbacks, + Fallbacks: testConfig.SpeechSynthesisFallbacks, } // Enhanced validation for speech synthesis expectations := SpeechExpectations(tc.expectMinBytes) expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider) - speechResponse, bifrostErr := client.SpeechRequest(ctx, request) + requestCtx := context.Background() + + speechResponse, bifrostErr := client.SpeechRequest(requestCtx, request) if bifrostErr != nil { t.Fatalf("❌ SpeechSynthesis_"+tc.name+" request failed: %v", GetErrorMessage(bifrostErr)) } @@ -123,7 +133,15 @@ func RunSpeechSynthesisAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx c } t.Run("SpeechSynthesisAdvanced", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + t.Run("LongText_HDModel", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + // Test with longer text and HD model longText := ` This is a comprehensive test of the text-to-speech functionality using a longer piece of text. 
@@ -146,7 +164,7 @@ func RunSpeechSynthesisAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx c ResponseFormat: "mp3", Instructions: "Speak slowly and clearly with natural intonation.", }, - Fallbacks: testConfig.Fallbacks, + Fallbacks: testConfig.SpeechSynthesisFallbacks, } retryConfig := GetTestRetryConfigForScenario("SpeechSynthesisHD", testConfig) @@ -167,8 +185,10 @@ func RunSpeechSynthesisAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx c expectations := SpeechExpectations(5000) // HD should produce substantial audio expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider) + requestCtx := context.Background() + response, bifrostErr := WithTestRetry(t, retryConfig, retryContext, expectations, "SpeechSynthesis_HD", func() (*schemas.BifrostResponse, *schemas.BifrostError) { - c, err := client.SpeechRequest(ctx, request) + c, err := client.SpeechRequest(requestCtx, request) if err != nil { return nil, err } @@ -195,12 +215,20 @@ func RunSpeechSynthesisAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx c }) t.Run("AllVoiceOptions", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + // Test provider-specific voice options voiceTypes := []string{"primary", "secondary", "tertiary"} testText := TTSTestTextBasic // Use shared constant for _, voiceType := range voiceTypes { t.Run("VoiceType_"+voiceType, func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + voice := GetProviderVoice(testConfig.Provider, voiceType) request := &schemas.BifrostSpeechRequest{ Provider: testConfig.Provider, @@ -214,13 +242,15 @@ func RunSpeechSynthesisAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx c }, ResponseFormat: "mp3", }, - Fallbacks: testConfig.Fallbacks, + Fallbacks: testConfig.SpeechSynthesisFallbacks, } expectations := SpeechExpectations(500) expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider) - speechResponse, bifrostErr := 
client.SpeechRequest(ctx, request) + requestCtx := context.Background() + + speechResponse, bifrostErr := client.SpeechRequest(requestCtx, request) if bifrostErr != nil { t.Fatalf("❌ SpeechSynthesis_Voice_"+voiceType+" request failed: %v", GetErrorMessage(bifrostErr)) } @@ -255,7 +285,7 @@ func validateSpeechSynthesisSpecific(t *testing.T, response *schemas.BifrostSpee if audioSize < expectMinBytes { t.Fatalf("Audio data too small: got %d bytes, expected at least %d", audioSize, expectMinBytes) } - + if expectedModel != "" && response.ExtraFields.ModelRequested != expectedModel { t.Logf("⚠️ Expected model, got: %s", response.ExtraFields.ModelRequested) } diff --git a/tests/core-providers/scenarios/speech_synthesis_stream.go b/tests/core-providers/scenarios/speech_synthesis_stream.go index 3d8fe634b..334bbf204 100644 --- a/tests/core-providers/scenarios/speech_synthesis_stream.go +++ b/tests/core-providers/scenarios/speech_synthesis_stream.go @@ -3,9 +3,9 @@ package scenarios import ( "context" "fmt" + "os" "strings" "testing" - "time" "github.com/maximhq/bifrost/tests/core-providers/config" @@ -21,6 +21,10 @@ func RunSpeechSynthesisStreamTest(t *testing.T, client *bifrost.Bifrost, ctx con } t.Run("SpeechSynthesisStream", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + // Test streaming with different text lengths testCases := []struct { name string @@ -29,6 +33,7 @@ func RunSpeechSynthesisStreamTest(t *testing.T, client *bifrost.Bifrost, ctx con format string expectMinChunks int expectMinBytes int + skip bool }{ { name: "ShortText_Streaming", @@ -37,6 +42,7 @@ func RunSpeechSynthesisStreamTest(t *testing.T, client *bifrost.Bifrost, ctx con format: "mp3", expectMinChunks: 1, expectMinBytes: 1000, + skip: false, }, { name: "LongText_Streaming", @@ -48,6 +54,7 @@ func RunSpeechSynthesisStreamTest(t *testing.T, client *bifrost.Bifrost, ctx con format: "mp3", expectMinChunks: 2, expectMinBytes: 3000, + skip: 
testConfig.Provider == schemas.Gemini, }, { name: "MediumText_Echo_WAV", @@ -56,11 +63,21 @@ func RunSpeechSynthesisStreamTest(t *testing.T, client *bifrost.Bifrost, ctx con format: "wav", expectMinChunks: 1, expectMinBytes: 2000, + skip: false, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + if tc.skip { + t.Skipf("Skipping %s test", tc.name) + return + } + voice := tc.voice request := &schemas.BifrostSpeechRequest{ Provider: testConfig.Provider, @@ -74,7 +91,7 @@ func RunSpeechSynthesisStreamTest(t *testing.T, client *bifrost.Bifrost, ctx con }, ResponseFormat: tc.format, }, - Fallbacks: testConfig.Fallbacks, + Fallbacks: testConfig.SpeechSynthesisFallbacks, } // Use retry framework for streaming speech synthesis @@ -97,8 +114,10 @@ func RunSpeechSynthesisStreamTest(t *testing.T, client *bifrost.Bifrost, ctx con }, } + requestCtx := context.Background() + responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return client.SpeechStreamRequest(ctx, request) + return client.SpeechStreamRequest(requestCtx, request) }) // Enhanced validation for streaming speech synthesis @@ -113,69 +132,61 @@ func RunSpeechSynthesisStreamTest(t *testing.T, client *bifrost.Bifrost, ctx con var chunkCount int var lastResponse *schemas.BifrostStream var streamErrors []string - - streamCtx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() + var lastTokenLatency int64 // Read streaming chunks with enhanced validation - for { - select { - case response, ok := <-responseChannel: - if !ok { - // Channel closed, streaming complete - goto streamComplete - } + for response := range responseChannel { + if response == nil { + streamErrors = append(streamErrors, "Received nil stream response") + continue + } - if response == nil { - streamErrors = append(streamErrors, "Received nil stream response") - 
continue - } + // Check for errors in stream + if response.BifrostError != nil { + streamErrors = append(streamErrors, FormatErrorConcise(ParseBifrostError(response.BifrostError))) + continue + } - // Check for errors in stream - if response.BifrostError != nil { - streamErrors = append(streamErrors, FormatErrorConcise(ParseBifrostError(response.BifrostError))) - continue - } + if response.BifrostSpeechStreamResponse != nil { + lastTokenLatency = response.BifrostSpeechStreamResponse.ExtraFields.Latency + } - if response.BifrostSpeechStreamResponse == nil { - streamErrors = append(streamErrors, "Stream response missing speech stream payload") - continue - } + if response.BifrostSpeechStreamResponse == nil { + streamErrors = append(streamErrors, "Stream response missing speech stream payload") + continue + } - if response.BifrostSpeechStreamResponse.Audio == nil { - streamErrors = append(streamErrors, "Stream response missing audio data") + if response.BifrostSpeechStreamResponse.Audio == nil { + streamErrors = append(streamErrors, "Stream response missing audio data") + continue + } + + // Log latency for each chunk (can be 0 for inter-chunks) + t.Logf("πŸ“Š Speech chunk %d latency: %d ms", chunkCount+1, response.BifrostSpeechStreamResponse.ExtraFields.Latency) + + // Collect audio chunks + if response.BifrostSpeechStreamResponse.Audio != nil { + chunkSize := len(response.BifrostSpeechStreamResponse.Audio) + if chunkSize == 0 { + t.Logf("⚠️ Skipping zero-length audio chunk") continue } + totalBytes += chunkSize + chunkCount++ + t.Logf("βœ… Received audio chunk %d: %d bytes", chunkCount, chunkSize) - // Collect audio chunks - if response.BifrostSpeechStreamResponse.Audio != nil { - chunkSize := len(response.BifrostSpeechStreamResponse.Audio) - if chunkSize == 0 { - t.Logf("⚠️ Skipping zero-length audio chunk") - continue - } - totalBytes += chunkSize - chunkCount++ - t.Logf("βœ… Received audio chunk %d: %d bytes", chunkCount, chunkSize) - - // Validate chunk 
structure - if response.BifrostSpeechStreamResponse.Type != "" && (response.BifrostSpeechStreamResponse.Type != schemas.SpeechStreamResponseTypeDelta && response.BifrostSpeechStreamResponse.Type != schemas.SpeechStreamResponseTypeDone) { - t.Logf("⚠️ Unexpected object type in stream: %s", response.BifrostSpeechStreamResponse.Type) - } - if response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested != "" && response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested != testConfig.SpeechSynthesisModel { - t.Logf("⚠️ Unexpected model in stream: %s", response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested) - } + // Validate chunk structure + if response.BifrostSpeechStreamResponse.Type != "" && (response.BifrostSpeechStreamResponse.Type != schemas.SpeechStreamResponseTypeDelta && response.BifrostSpeechStreamResponse.Type != schemas.SpeechStreamResponseTypeDone) { + t.Logf("⚠️ Unexpected object type in stream: %s", response.BifrostSpeechStreamResponse.Type) + } + if response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested != "" && response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested != testConfig.SpeechSynthesisModel { + t.Logf("⚠️ Unexpected model in stream: %s", response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested) } - - lastResponse = response - - case <-streamCtx.Done(): - streamErrors = append(streamErrors, "Stream reading timed out") - goto streamComplete } + + lastResponse = DeepCopyBifrostStream(response) } - streamComplete: // Enhanced validation of streaming results if len(streamErrors) > 0 { t.Logf("⚠️ Stream errors encountered: %v", streamErrors) @@ -203,6 +214,10 @@ func RunSpeechSynthesisStreamTest(t *testing.T, client *bifrost.Bifrost, ctx con t.Logf("⚠️ Average chunk size seems small: %d bytes", averageChunkSize) } + if lastTokenLatency == 0 { + t.Errorf("❌ Last token latency is 0") + } + t.Logf("βœ… Streaming speech synthesis successful: %d chunks, %d total bytes for voice '%s' in %s format", chunkCount, 
totalBytes, tc.voice, tc.format) }) @@ -219,6 +234,15 @@ func RunSpeechSynthesisStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, t.Run("SpeechSynthesisStreamAdvanced", func(t *testing.T) { t.Run("LongText_HDModel_Streaming", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + if testConfig.Provider == schemas.Gemini { + t.Skipf("Skipping %s test", "LongText_HDModel_Streaming") + return + } + // Test streaming with HD model and very long text finalText := "" for i := 1; i <= 20; i++ { @@ -239,7 +263,7 @@ func RunSpeechSynthesisStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, ResponseFormat: "mp3", Instructions: "Speak at a natural pace with clear pronunciation.", }, - Fallbacks: testConfig.Fallbacks, + Fallbacks: testConfig.SpeechSynthesisFallbacks, } retryConfig := GetTestRetryConfigForScenario("SpeechSynthesisStreamHD", testConfig) @@ -259,8 +283,10 @@ func RunSpeechSynthesisStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, }, } + requestCtx := context.Background() + responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return client.SpeechStreamRequest(ctx, request) + return client.SpeechStreamRequest(requestCtx, request) }) RequireNoError(t, err, "HD streaming speech synthesis failed") @@ -268,49 +294,39 @@ func RunSpeechSynthesisStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, var totalBytes int var chunkCount int var streamErrors []string + var lastTokenLatency int64 - streamCtx, cancel := context.WithTimeout(ctx, 60*time.Second) - defer cancel() + for response := range responseChannel { + if response == nil { + streamErrors = append(streamErrors, "Received nil HD stream response") + continue + } - for { - select { - case response, ok := <-responseChannel: - if !ok { - goto hdStreamComplete - } + if response.BifrostError != nil { + streamErrors = append(streamErrors, 
FormatErrorConcise(ParseBifrostError(response.BifrostError))) + continue + } - if response == nil { - streamErrors = append(streamErrors, "Received nil HD stream response") - continue - } + if response.BifrostSpeechStreamResponse != nil { + lastTokenLatency = response.BifrostSpeechStreamResponse.ExtraFields.Latency + } - if response.BifrostError != nil { - streamErrors = append(streamErrors, FormatErrorConcise(ParseBifrostError(response.BifrostError))) + if response.BifrostSpeechStreamResponse != nil && response.BifrostSpeechStreamResponse.Audio != nil { + chunkSize := len(response.BifrostSpeechStreamResponse.Audio) + if chunkSize == 0 { + t.Logf("⚠️ Skipping zero-length HD audio chunk") continue } + totalBytes += chunkSize + chunkCount++ + t.Logf("βœ… HD chunk %d: %d bytes", chunkCount, chunkSize) + } - if response.BifrostSpeechStreamResponse != nil && response.BifrostSpeechStreamResponse.Audio != nil { - chunkSize := len(response.BifrostSpeechStreamResponse.Audio) - if chunkSize == 0 { - t.Logf("⚠️ Skipping zero-length HD audio chunk") - continue - } - totalBytes += chunkSize - chunkCount++ - t.Logf("βœ… HD chunk %d: %d bytes", chunkCount, chunkSize) - } - - if response.BifrostSpeechStreamResponse != nil && response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested != "" && response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested != testConfig.SpeechSynthesisModel { - t.Logf("⚠️ Unexpected HD model: %s", response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested) - } - - case <-streamCtx.Done(): - streamErrors = append(streamErrors, "HD stream reading timed out") - goto hdStreamComplete + if response.BifrostSpeechStreamResponse != nil && response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested != "" && response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested != testConfig.SpeechSynthesisModel { + t.Logf("⚠️ Unexpected HD model: %s", response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested) } } - hdStreamComplete: if 
len(streamErrors) > 0 { t.Logf("⚠️ HD stream errors: %v", streamErrors) } @@ -323,16 +339,37 @@ func RunSpeechSynthesisStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, t.Fatalf("HD model should produce substantial audio data: got %d bytes, expected > 10000", totalBytes) } + if lastTokenLatency == 0 { + t.Errorf("❌ Last token latency is 0") + } + t.Logf("βœ… HD streaming successful: %d chunks, %d total bytes", chunkCount, totalBytes) }) t.Run("MultipleVoices_Streaming", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + voices := []string{} + // Test streaming with all available voices - voices := []string{"alloy", "echo", "fable", "onyx", "nova", "shimmer"} + openaiVoices := []string{"alloy", "echo", "fable", "onyx", "nova", "shimmer"} + geminiVoices := []string{"achernar", "achird", "algenib", "charon", "despina", "erinome"} testText := "Testing streaming speech synthesis with different voice options." + if testConfig.Provider == schemas.OpenAI { + voices = openaiVoices + } else if testConfig.Provider == schemas.Gemini { + voices = geminiVoices + } + for _, voice := range voices { t.Run("StreamingVoice_"+voice, func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + voiceCopy := voice request := &schemas.BifrostSpeechRequest{ Provider: testConfig.Provider, @@ -346,7 +383,7 @@ func RunSpeechSynthesisStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, }, ResponseFormat: "mp3", }, - Fallbacks: testConfig.Fallbacks, + Fallbacks: testConfig.SpeechSynthesisFallbacks, } retryConfig := GetTestRetryConfigForScenario("SpeechSynthesisStreamVoice", testConfig) @@ -362,53 +399,48 @@ func RunSpeechSynthesisStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, }, } + requestCtx := context.Background() + responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return client.SpeechStreamRequest(ctx, request) 
+ return client.SpeechStreamRequest(requestCtx, request) }) RequireNoError(t, err, fmt.Sprintf("Streaming failed for voice %s", voice)) var receivedData bool var streamErrors []string + var lastTokenLatency int64 - streamCtx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - - for { - select { - case response, ok := <-responseChannel: - if !ok { - goto voiceStreamComplete - } - - if response == nil { - streamErrors = append(streamErrors, fmt.Sprintf("Received nil stream response for voice %s", voice)) - continue - } - - if response.BifrostError != nil { - streamErrors = append(streamErrors, fmt.Sprintf("Error in stream for voice %s: %s", voice, FormatErrorConcise(ParseBifrostError(response.BifrostError)))) - continue - } - - if response.BifrostSpeechStreamResponse != nil && response.BifrostSpeechStreamResponse.Audio != nil && len(response.BifrostSpeechStreamResponse.Audio) > 0 { - receivedData = true - t.Logf("βœ… Received data for voice %s: %d bytes", voice, len(response.BifrostSpeechStreamResponse.Audio)) - } - - case <-streamCtx.Done(): - streamErrors = append(streamErrors, fmt.Sprintf("Stream timed out for voice %s", voice)) - goto voiceStreamComplete + for response := range responseChannel { + if response == nil { + streamErrors = append(streamErrors, fmt.Sprintf("Received nil stream response for voice %s", voice)) + continue + } + + if response.BifrostError != nil { + streamErrors = append(streamErrors, fmt.Sprintf("Error in stream for voice %s: %s", voice, FormatErrorConcise(ParseBifrostError(response.BifrostError)))) + continue + } + + if response.BifrostSpeechStreamResponse != nil { + lastTokenLatency = response.BifrostSpeechStreamResponse.ExtraFields.Latency + } + + if response.BifrostSpeechStreamResponse != nil && response.BifrostSpeechStreamResponse.Audio != nil && len(response.BifrostSpeechStreamResponse.Audio) > 0 { + receivedData = true + t.Logf("βœ… Received data for voice %s: %d bytes", voice, 
len(response.BifrostSpeechStreamResponse.Audio)) } } - voiceStreamComplete: if len(streamErrors) > 0 { - t.Logf("⚠️ Stream errors for voice %s: %v", voice, streamErrors) + t.Errorf("❌ Stream errors for voice %s: %v", voice, streamErrors) } if !receivedData { - t.Fatalf("Should receive audio data for voice %s", voice) + t.Errorf("❌ Should receive audio data for voice %s", voice) + } + if lastTokenLatency == 0 { + t.Errorf("❌ Last token latency is 0") } t.Logf("βœ… Streaming successful for voice: %s", voice) }) diff --git a/tests/core-providers/scenarios/test_retry_framework.go b/tests/core-providers/scenarios/test_retry_framework.go index 59d9a41cf..90f8a5f8f 100644 --- a/tests/core-providers/scenarios/test_retry_framework.go +++ b/tests/core-providers/scenarios/test_retry_framework.go @@ -3,6 +3,7 @@ package scenarios import ( "fmt" "math" + "reflect" "strings" "testing" "time" @@ -13,6 +14,85 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) +// DeepCopyBifrostStream creates a deep copy of a BifrostStream object to avoid pooling issues +func DeepCopyBifrostStream(original *schemas.BifrostStream) *schemas.BifrostStream { + if original == nil { + return nil + } + + // Use reflection to create a deep copy + return deepCopyReflect(original).(*schemas.BifrostStream) +} + +// deepCopyReflect performs a deep copy using reflection +func deepCopyReflect(original interface{}) interface{} { + if original == nil { + return nil + } + + originalValue := reflect.ValueOf(original) + return deepCopyValue(originalValue).Interface() +} + +// deepCopyValue recursively copies a reflect.Value +func deepCopyValue(original reflect.Value) reflect.Value { + switch original.Kind() { + case reflect.Ptr: + if original.IsNil() { + return reflect.Zero(original.Type()) + } + // Create a new pointer and recursively copy the value it points to + newPtr := reflect.New(original.Type().Elem()) + newPtr.Elem().Set(deepCopyValue(original.Elem())) + return newPtr + + case reflect.Struct: + // 
Create a new struct and copy each field + newStruct := reflect.New(original.Type()).Elem() + for i := 0; i < original.NumField(); i++ { + field := original.Field(i) + destField := newStruct.Field(i) + if destField.CanSet() { + destField.Set(deepCopyValue(field)) + } + } + return newStruct + + case reflect.Slice: + if original.IsNil() { + return reflect.Zero(original.Type()) + } + // Create a new slice and copy each element + newSlice := reflect.MakeSlice(original.Type(), original.Len(), original.Cap()) + for i := 0; i < original.Len(); i++ { + newSlice.Index(i).Set(deepCopyValue(original.Index(i))) + } + return newSlice + + case reflect.Map: + if original.IsNil() { + return reflect.Zero(original.Type()) + } + // Create a new map and copy each key-value pair + newMap := reflect.MakeMap(original.Type()) + for _, key := range original.MapKeys() { + newMap.SetMapIndex(deepCopyValue(key), deepCopyValue(original.MapIndex(key))) + } + return newMap + + case reflect.Interface: + if original.IsNil() { + return reflect.Zero(original.Type()) + } + // Copy the concrete value inside the interface + return deepCopyValue(original.Elem()) + + default: + // For basic types (int, string, bool, etc.), just return the value + return original + } +} + // TestRetryCondition defines an interface for checking if a test operation should be retried // This focuses specifically on LLM behavior inconsistencies, not HTTP errors (handled by Bifrost core) type TestRetryCondition interface { @@ -686,6 +766,18 @@ func DefaultTranscriptionRetryConfig() TestRetryConfig { } } +// ReasoningRetryConfig creates a retry config for reasoning tests +func ReasoningRetryConfig() TestRetryConfig { + return TestRetryConfig{ + MaxAttempts: 5, + BaseDelay: 750 * time.Millisecond, + MaxDelay: 8 * time.Second, + Conditions: []TestRetryCondition{ + &EmptyResponseCondition{}, + }, + } +} + // DefaultEmbeddingRetryConfig creates a retry config for embedding tests func DefaultEmbeddingRetryConfig() TestRetryConfig { 
return TestRetryConfig{ @@ -849,6 +941,8 @@ func GetTestRetryConfigForScenario(scenarioName string, testConfig config.Compre return SpeechStreamRetryConfig() case "Transcription", "TranscriptionStream": // πŸŽ™οΈ Transcription tests return DefaultTranscriptionRetryConfig() + case "Reasoning": + return ReasoningRetryConfig() default: // For basic scenarios like SimpleChat, TextCompletion return DefaultTestRetryConfig() diff --git a/tests/core-providers/scenarios/text_completion.go b/tests/core-providers/scenarios/text_completion.go index 4703ec75e..02dc7200b 100644 --- a/tests/core-providers/scenarios/text_completion.go +++ b/tests/core-providers/scenarios/text_completion.go @@ -2,6 +2,7 @@ package scenarios import ( "context" + "os" "testing" "github.com/maximhq/bifrost/tests/core-providers/config" @@ -18,6 +19,10 @@ func RunTextCompletionTest(t *testing.T, client *bifrost.Bifrost, ctx context.Co } t.Run("TextCompletion", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + prompt := "In fruits, A is for apple and B is for" request := &schemas.BifrostTextCompletionRequest{ Provider: testConfig.Provider, @@ -25,7 +30,7 @@ func RunTextCompletionTest(t *testing.T, client *bifrost.Bifrost, ctx context.Co Input: &schemas.TextCompletionInput{ PromptStr: &prompt, }, - Fallbacks: testConfig.Fallbacks, + Fallbacks: testConfig.TextCompletionFallbacks, } // Use retry framework with enhanced validation diff --git a/tests/core-providers/scenarios/text_completion_stream.go b/tests/core-providers/scenarios/text_completion_stream.go index 727e88c0c..91f8d5155 100644 --- a/tests/core-providers/scenarios/text_completion_stream.go +++ b/tests/core-providers/scenarios/text_completion_stream.go @@ -2,6 +2,7 @@ package scenarios import ( "context" + "os" "strings" "testing" "time" @@ -20,6 +21,10 @@ func RunTextCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont } t.Run("TextCompletionStream", func(t *testing.T) { + if 
os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + // Create a text completion prompt prompt := "Write a short story about a robot learning to paint. Keep it under 150 words." @@ -40,7 +45,7 @@ func RunTextCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont Params: &schemas.TextCompletionParameters{ MaxTokens: bifrost.Ptr(150), }, - Fallbacks: testConfig.Fallbacks, + Fallbacks: testConfig.TextCompletionFallbacks, } // Use retry framework for stream requests @@ -92,7 +97,7 @@ func RunTextCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont if response == nil { t.Fatal("Streaming response should not be nil") } - lastResponse = response + lastResponse = DeepCopyBifrostStream(response) // Basic validation of streaming response structure if response.BifrostTextCompletionResponse != nil { @@ -103,6 +108,9 @@ func RunTextCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont t.Logf("⚠️ Warning: Response ID is empty") } + // Log latency for each chunk (can be 0 for inter-chunks) + t.Logf("πŸ“Š Chunk %d latency: %d ms", responseCount+1, response.BifrostTextCompletionResponse.ExtraFields.Latency) + // Validate text completion response structure if response.BifrostTextCompletionResponse.Choices == nil { t.Logf("⚠️ Warning: Choices should not be nil in text completion streaming") @@ -170,8 +178,17 @@ func RunTextCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont t.Fatal("Final content should be substantial") } + // Validate latency is present in the last chunk (total latency) + if lastResponse != nil && lastResponse.BifrostTextCompletionResponse != nil { + if lastResponse.BifrostTextCompletionResponse.ExtraFields.Latency <= 0 { + t.Errorf("❌ Last streaming chunk missing latency information (got %d ms)", lastResponse.BifrostTextCompletionResponse.ExtraFields.Latency) + } else { + t.Logf("βœ… Total streaming latency: %d ms", lastResponse.BifrostTextCompletionResponse.ExtraFields.Latency) + } + } + if 
!validationResult.Passed { - t.Logf("⚠️ Text completion streaming validation warnings: %v", validationResult.Errors) + t.Errorf("❌ Text completion streaming validation failed: %v", validationResult.Errors) } t.Logf("πŸ“Š Text completion streaming metrics: %d chunks, %d chars", responseCount, len(finalContent)) @@ -182,6 +199,10 @@ func RunTextCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont // Test text completion streaming with different prompts t.Run("TextCompletionStreamVariedPrompts", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + // Use TextModel if available, otherwise fall back to ChatModel model := testConfig.TextModel if model == "" { @@ -211,6 +232,10 @@ func RunTextCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont for _, testCase := range testPrompts { t.Run(testCase.name, func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + input := &schemas.TextCompletionInput{ PromptStr: &testCase.prompt, } @@ -223,7 +248,7 @@ func RunTextCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont MaxTokens: bifrost.Ptr(50), Temperature: bifrost.Ptr(0.7), }, - Fallbacks: testConfig.Fallbacks, + Fallbacks: testConfig.TextCompletionFallbacks, } responseChannel, err := client.TextCompletionStreamRequest(ctx, request) @@ -291,6 +316,10 @@ func RunTextCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont // Test text completion streaming with different parameters t.Run("TextCompletionStreamParameters", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + // Use TextModel if available, otherwise fall back to ChatModel model := testConfig.TextModel if model == "" { @@ -327,6 +356,10 @@ func RunTextCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont for _, paramTest := range parameterTests { t.Run(paramTest.name, func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != 
"true" { + t.Parallel() + } + input := &schemas.TextCompletionInput{ PromptStr: &prompt, } @@ -340,7 +373,7 @@ func RunTextCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont Temperature: paramTest.temperature, TopP: paramTest.topP, }, - Fallbacks: testConfig.Fallbacks, + Fallbacks: testConfig.TextCompletionFallbacks, } responseChannel, err := client.TextCompletionStreamRequest(ctx, request) @@ -414,6 +447,8 @@ func createConsolidatedTextCompletionResponse(finalContent string, lastResponse if len(lastResponse.BifrostTextCompletionResponse.Choices) > 0 && lastResponse.BifrostTextCompletionResponse.Choices[0].FinishReason != nil { consolidatedResponse.Choices[0].FinishReason = lastResponse.BifrostTextCompletionResponse.Choices[0].FinishReason } + + consolidatedResponse.ExtraFields = lastResponse.BifrostTextCompletionResponse.ExtraFields } return consolidatedResponse diff --git a/tests/core-providers/scenarios/tool_calls.go b/tests/core-providers/scenarios/tool_calls.go index f0d19e525..d873f8907 100644 --- a/tests/core-providers/scenarios/tool_calls.go +++ b/tests/core-providers/scenarios/tool_calls.go @@ -3,6 +3,7 @@ package scenarios import ( "context" "encoding/json" + "os" "strings" "testing" @@ -21,6 +22,10 @@ func RunToolCallsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context } t.Run("ToolCalls", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + chatMessages := []schemas.ChatMessage{ CreateBasicChatMessage("What's the weather like in New York? 
answer in celsius"), } @@ -150,4 +155,4 @@ func validateLocationInToolCalls(t *testing.T, toolCalls []ToolCallInfo, apiName } require.True(t, locationFound, "%s API tool call should specify New York as the location", apiName) -} \ No newline at end of file +} diff --git a/tests/core-providers/scenarios/transcription.go b/tests/core-providers/scenarios/transcription.go index 29e85a29b..18e099ea7 100644 --- a/tests/core-providers/scenarios/transcription.go +++ b/tests/core-providers/scenarios/transcription.go @@ -56,6 +56,10 @@ func RunTranscriptionTest(t *testing.T, client *bifrost.Bifrost, ctx context.Con for _, tc := range roundTripCases { t.Run(tc.name, func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + // Step 1: Generate TTS audio voice := GetProviderVoice(testConfig.Provider, tc.voiceType) ttsRequest := &schemas.BifrostSpeechRequest{ @@ -70,7 +74,7 @@ func RunTranscriptionTest(t *testing.T, client *bifrost.Bifrost, ctx context.Con }, ResponseFormat: tc.format, }, - Fallbacks: testConfig.Fallbacks, + Fallbacks: testConfig.TranscriptionFallbacks, } ttsResponse, err := client.SpeechRequest(ctx, ttsRequest) @@ -104,7 +108,7 @@ func RunTranscriptionTest(t *testing.T, client *bifrost.Bifrost, ctx context.Con Format: bifrost.Ptr("mp3"), ResponseFormat: tc.responseFormat, }, - Fallbacks: testConfig.Fallbacks, + Fallbacks: testConfig.TranscriptionFallbacks, } // Enhanced validation for transcription @@ -152,6 +156,10 @@ func RunTranscriptionTest(t *testing.T, client *bifrost.Bifrost, ctx context.Con for _, tc := range customCases { t.Run(tc.name, func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + // Use the utility function to generate audio audioData, _ := GenerateTTSAudioForTest(ctx, t, client, testConfig.Provider, testConfig.SpeechSynthesisModel, tc.text, "primary", "mp3") @@ -167,7 +175,7 @@ func RunTranscriptionTest(t *testing.T, client *bifrost.Bifrost, ctx context.Con Format: 
bifrost.Ptr("mp3"), ResponseFormat: tc.responseFormat, }, - Fallbacks: testConfig.Fallbacks, + Fallbacks: testConfig.TranscriptionFallbacks, } response, err := client.TranscriptionRequest(ctx, request) @@ -199,6 +207,10 @@ func RunTranscriptionAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx con for _, format := range formats { t.Run("Format_"+format, func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + formatCopy := format request := &schemas.BifrostTranscriptionRequest{ Provider: testConfig.Provider, @@ -210,7 +222,7 @@ func RunTranscriptionAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx con Format: bifrost.Ptr("mp3"), ResponseFormat: &formatCopy, }, - Fallbacks: testConfig.Fallbacks, + Fallbacks: testConfig.TranscriptionFallbacks, } response, err := client.TranscriptionRequest(ctx, request) @@ -226,6 +238,10 @@ func RunTranscriptionAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx con }) t.Run("WithCustomParameters", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + // Generate audio for custom parameters test audioData, _ := GenerateTTSAudioForTest(ctx, t, client, testConfig.Provider, testConfig.SpeechSynthesisModel, TTSTestTextMedium, "secondary", "mp3") @@ -242,7 +258,7 @@ func RunTranscriptionAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx con Prompt: bifrost.Ptr("This audio contains technical terminology and proper nouns."), ResponseFormat: bifrost.Ptr("json"), // Use json instead of verbose_json for whisper-1 }, - Fallbacks: testConfig.Fallbacks, + Fallbacks: testConfig.TranscriptionFallbacks, } response, err := client.TranscriptionRequest(ctx, request) @@ -262,6 +278,10 @@ func RunTranscriptionAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx con for _, lang := range languages { t.Run("Language_"+lang, func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + langCopy := lang request := 
&schemas.BifrostTranscriptionRequest{ Provider: testConfig.Provider, @@ -273,7 +293,7 @@ func RunTranscriptionAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx con Format: bifrost.Ptr("mp3"), Language: &langCopy, }, - Fallbacks: testConfig.Fallbacks, + Fallbacks: testConfig.TranscriptionFallbacks, } response, err := client.TranscriptionRequest(ctx, request) diff --git a/tests/core-providers/scenarios/transcription_stream.go b/tests/core-providers/scenarios/transcription_stream.go index 41ef288c4..cd96e7acd 100644 --- a/tests/core-providers/scenarios/transcription_stream.go +++ b/tests/core-providers/scenarios/transcription_stream.go @@ -23,6 +23,10 @@ func RunTranscriptionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx conte } t.Run("TranscriptionStream", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + // Generate TTS audio for streaming round-trip validation streamRoundTripCases := []struct { name string @@ -30,7 +34,6 @@ func RunTranscriptionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx conte voiceType string format string responseFormat *string - expectChunks int }{ { name: "StreamRoundTrip_Basic_MP3", @@ -38,7 +41,6 @@ func RunTranscriptionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx conte voiceType: "primary", format: "mp3", responseFormat: nil, // Default JSON streaming - expectChunks: 1, }, { name: "StreamRoundTrip_Medium_MP3", @@ -46,7 +48,6 @@ func RunTranscriptionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx conte voiceType: "secondary", format: "mp3", responseFormat: bifrost.Ptr("json"), - expectChunks: 1, }, { name: "StreamRoundTrip_Technical_MP3", @@ -54,12 +55,15 @@ func RunTranscriptionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx conte voiceType: "tertiary", format: "mp3", responseFormat: bifrost.Ptr("json"), - expectChunks: 1, }, } for _, tc := range streamRoundTripCases { t.Run(tc.name, func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + 
t.Parallel() + } + // Step 1: Generate TTS audio voice := GetProviderVoice(testConfig.Provider, tc.voiceType) ttsRequest := &schemas.BifrostSpeechRequest{ @@ -74,7 +78,7 @@ func RunTranscriptionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx conte }, ResponseFormat: tc.format, }, - Fallbacks: testConfig.Fallbacks, + Fallbacks: testConfig.TranscriptionFallbacks, } ttsResponse, err := client.SpeechRequest(ctx, ttsRequest) @@ -110,7 +114,7 @@ func RunTranscriptionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx conte Format: bifrost.Ptr(tc.format), ResponseFormat: tc.responseFormat, }, - Fallbacks: testConfig.Fallbacks, + Fallbacks: testConfig.TranscriptionFallbacks, } // Use retry framework for streaming transcription @@ -121,7 +125,6 @@ func RunTranscriptionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx conte "transcribe_streaming_audio": true, "round_trip_test": true, "original_text": tc.text, - "min_chunks": tc.expectChunks, }, TestMetadata: map[string]interface{}{ "provider": testConfig.Provider, @@ -143,10 +146,10 @@ func RunTranscriptionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx conte streamCtx, cancel := context.WithTimeout(ctx, 60*time.Second) defer cancel() - var fullTranscriptionText string - var chunkCount int - var lastResponse *schemas.BifrostStream - var streamErrors []string + fullTranscriptionText := "" + lastResponse := &schemas.BifrostStream{} + streamErrors := []string{} + lastTokenLatency := int64(0) // Read streaming chunks with enhanced validation for { @@ -173,39 +176,43 @@ func RunTranscriptionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx conte continue } + if response.BifrostTranscriptionStreamResponse != nil { + lastTokenLatency = response.BifrostTranscriptionStreamResponse.ExtraFields.Latency + } + if response.BifrostTranscriptionStreamResponse.Text == "" && response.BifrostTranscriptionStreamResponse.Delta == nil { streamErrors = append(streamErrors, "Stream response missing transcription data") 
continue } + chunkIndex := response.BifrostTranscriptionStreamResponse.ExtraFields.ChunkIndex + + // Log latency for each chunk (can be 0 for inter-chunks) + t.Logf("πŸ“Š Transcription chunk %d latency: %d ms", chunkIndex, response.BifrostTranscriptionStreamResponse.ExtraFields.Latency) + // Collect transcription chunks transcribeData := response.BifrostTranscriptionStreamResponse if transcribeData.Text != "" { - chunkText := transcribeData.Text - - // Handle delta vs complete text chunks - if transcribeData.Delta != nil { - // This is a delta chunk - deltaText := *transcribeData.Delta - fullTranscriptionText += deltaText - t.Logf("βœ… Received transcription delta chunk %d: '%s'", chunkCount+1, deltaText) - } else { - // This is a complete text chunk - fullTranscriptionText += chunkText - t.Logf("βœ… Received transcription text chunk %d: '%s'", chunkCount+1, chunkText) - } - chunkCount++ + t.Logf("βœ… Received transcription text chunk %d with latency %d ms: '%s'", chunkIndex, response.BifrostTranscriptionStreamResponse.ExtraFields.Latency, transcribeData.Text) + } - // Validate chunk structure - if response.BifrostTranscriptionStreamResponse.Type != schemas.TranscriptionStreamResponseTypeDelta { - t.Logf("⚠️ Unexpected object type in stream: %s", response.BifrostTranscriptionStreamResponse.Type) - } - if response.BifrostTranscriptionStreamResponse.ExtraFields.ModelRequested != "" && response.BifrostTranscriptionStreamResponse.ExtraFields.ModelRequested != testConfig.TranscriptionModel { - t.Logf("⚠️ Unexpected model in stream: %s", response.BifrostTranscriptionStreamResponse.ExtraFields.ModelRequested) - } + // Handle delta vs complete text chunks + if transcribeData.Delta != nil { + // This is a delta chunk + deltaText := *transcribeData.Delta + fullTranscriptionText += deltaText + t.Logf("βœ… Received transcription delta chunk %d with latency %d ms: '%s'", chunkIndex, response.BifrostTranscriptionStreamResponse.ExtraFields.Latency, deltaText) } - lastResponse = 
response + // Validate chunk structure + if response.BifrostTranscriptionStreamResponse.Type != schemas.TranscriptionStreamResponseTypeDelta { + t.Logf("⚠️ Unexpected object type in stream: %s", response.BifrostTranscriptionStreamResponse.Type) + } + if response.BifrostTranscriptionStreamResponse.ExtraFields.ModelRequested != "" && response.BifrostTranscriptionStreamResponse.ExtraFields.ModelRequested != testConfig.TranscriptionModel { + t.Logf("⚠️ Unexpected model in stream: %s", response.BifrostTranscriptionStreamResponse.ExtraFields.ModelRequested) + } + + lastResponse = DeepCopyBifrostStream(response) case <-streamCtx.Done(): streamErrors = append(streamErrors, "Stream reading timed out") @@ -219,10 +226,6 @@ func RunTranscriptionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx conte t.Logf("⚠️ Stream errors encountered: %v", streamErrors) } - if chunkCount < tc.expectChunks { - t.Fatalf("Insufficient chunks received: got %d, expected at least %d", chunkCount, tc.expectChunks) - } - if lastResponse == nil { t.Fatal("Should have received at least one response") } @@ -231,6 +234,10 @@ func RunTranscriptionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx conte t.Fatal("Transcribed text should not be empty") } + if lastTokenLatency == 0 { + t.Errorf("❌ Last token latency is 0") + } + // Normalize for comparison (lowercase, remove punctuation) originalWords := strings.Fields(strings.ToLower(tc.text)) transcribedWords := strings.Fields(strings.ToLower(fullTranscriptionText)) @@ -284,9 +291,6 @@ func RunTranscriptionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx conte } t.Fatalf("Round-trip accuracy too low: got %d/%d words, need at least %d", foundWords, len(originalWords), minExpectedWords) } - - t.Logf("βœ… Stream round-trip successful: '%s' β†’ TTS β†’ SST β†’ '%s' (%d chunks, found %d/%d words)", - tc.text, fullTranscriptionText, chunkCount, foundWords, len(originalWords)) }) } }) @@ -301,6 +305,10 @@ func RunTranscriptionStreamAdvancedTest(t 
*testing.T, client *bifrost.Bifrost, c t.Run("TranscriptionStreamAdvanced", func(t *testing.T) { t.Run("JSONStreaming", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + // Generate audio for streaming test audioData, _ := GenerateTTSAudioForTest(ctx, t, client, testConfig.Provider, testConfig.SpeechSynthesisModel, TTSTestTextBasic, "primary", "mp3") @@ -316,7 +324,7 @@ func RunTranscriptionStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, c Format: bifrost.Ptr("mp3"), ResponseFormat: bifrost.Ptr("json"), }, - Fallbacks: testConfig.Fallbacks, + Fallbacks: testConfig.TranscriptionFallbacks, } retryConfig := GetTestRetryConfigForScenario("TranscriptionStreamJSON", testConfig) @@ -342,50 +350,35 @@ func RunTranscriptionStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, c var receivedResponse bool var streamErrors []string - streamCtx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - - for { - select { - case response, ok := <-responseChannel: - if !ok { - goto verboseStreamComplete - } - - if response == nil { - streamErrors = append(streamErrors, "Received nil JSON stream response") - continue - } - - if response.BifrostError != nil { - streamErrors = append(streamErrors, FormatErrorConcise(ParseBifrostError(response.BifrostError))) - continue - } + for response := range responseChannel { + if response == nil { + streamErrors = append(streamErrors, "Received nil JSON stream response") + continue + } - if response.BifrostTranscriptionStreamResponse != nil { - receivedResponse = true + if response.BifrostError != nil { + streamErrors = append(streamErrors, FormatErrorConcise(ParseBifrostError(response.BifrostError))) + continue + } - // Check for JSON streaming specific fields - transcribeData := response.BifrostTranscriptionStreamResponse - if transcribeData.Type != "" { - t.Logf("βœ… Stream type: %v", transcribeData.Type) - if transcribeData.Delta != nil { - t.Logf("βœ… Delta: %s", 
*transcribeData.Delta) - } - } + if response.BifrostTranscriptionStreamResponse != nil { + receivedResponse = true - if transcribeData.Text != "" { - t.Logf("βœ… Received transcription text: %s", transcribeData.Text) + // Check for JSON streaming specific fields + transcribeData := response.BifrostTranscriptionStreamResponse + if transcribeData.Type != "" { + t.Logf("βœ… Stream type: %v", transcribeData.Type) + if transcribeData.Delta != nil { + t.Logf("βœ… Delta: %s", *transcribeData.Delta) } } - case <-streamCtx.Done(): - streamErrors = append(streamErrors, "JSON stream reading timed out") - goto verboseStreamComplete + if transcribeData.Text != "" { + t.Logf("βœ… Received transcription text: %s", transcribeData.Text) + } } } - verboseStreamComplete: if len(streamErrors) > 0 { t.Logf("⚠️ JSON stream errors: %v", streamErrors) } @@ -397,6 +390,10 @@ func RunTranscriptionStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, c }) t.Run("MultipleLanguages_Streaming", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + // Generate audio for language streaming tests audioData, _ := GenerateTTSAudioForTest(ctx, t, client, testConfig.Provider, testConfig.SpeechSynthesisModel, TTSTestTextBasic, "primary", "mp3") // Test streaming with different language hints (only English for now) @@ -404,6 +401,10 @@ func RunTranscriptionStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, c for _, lang := range languages { t.Run("StreamLang_"+lang, func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + langCopy := lang request := &schemas.BifrostTranscriptionRequest{ Provider: testConfig.Provider, @@ -414,7 +415,7 @@ func RunTranscriptionStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, c Params: &schemas.TranscriptionParameters{ Language: &langCopy, }, - Fallbacks: testConfig.Fallbacks, + Fallbacks: testConfig.TranscriptionFallbacks, } retryConfig := 
GetTestRetryConfigForScenario("TranscriptionStreamLang", testConfig) @@ -438,39 +439,28 @@ func RunTranscriptionStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, c var receivedData bool var streamErrors []string + var lastTokenLatency int64 - streamCtx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - - for { - select { - case response, ok := <-responseChannel: - if !ok { - goto langStreamComplete - } - - if response == nil { - streamErrors = append(streamErrors, fmt.Sprintf("Received nil stream response for language %s", lang)) - continue - } + for response := range responseChannel { + if response == nil { + streamErrors = append(streamErrors, fmt.Sprintf("Received nil stream response for language %s", lang)) + continue + } - if response.BifrostError != nil { - streamErrors = append(streamErrors, fmt.Sprintf("Error in stream for language %s: %s", lang, FormatErrorConcise(ParseBifrostError(response.BifrostError)))) - continue - } + if response.BifrostError != nil { + streamErrors = append(streamErrors, fmt.Sprintf("Error in stream for language %s: %s", lang, FormatErrorConcise(ParseBifrostError(response.BifrostError)))) + continue + } + if response.BifrostTranscriptionStreamResponse != nil { + receivedData = true + t.Logf("βœ… Received transcription data for language %s", lang) if response.BifrostTranscriptionStreamResponse != nil { - receivedData = true - t.Logf("βœ… Received transcription data for language %s", lang) + lastTokenLatency = response.BifrostTranscriptionStreamResponse.ExtraFields.Latency } - - case <-streamCtx.Done(): - streamErrors = append(streamErrors, fmt.Sprintf("Stream timed out for language %s", lang)) - goto langStreamComplete } } - langStreamComplete: if len(streamErrors) > 0 { t.Logf("⚠️ Stream errors for language %s: %v", lang, streamErrors) } @@ -478,12 +468,21 @@ func RunTranscriptionStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, c if !receivedData { t.Fatalf("Should receive transcription data for 
language %s", lang) } + + if lastTokenLatency == 0 { + t.Errorf("❌ Last token latency is 0") + } + t.Logf("βœ… Streaming successful for language: %s", lang) }) } }) t.Run("WithCustomPrompt_Streaming", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + // Generate audio for custom prompt streaming test audioData, _ := GenerateTTSAudioForTest(ctx, t, client, testConfig.Provider, testConfig.SpeechSynthesisModel, TTSTestTextTechnical, "tertiary", "mp3") @@ -498,7 +497,7 @@ func RunTranscriptionStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, c Language: bifrost.Ptr("en"), Prompt: bifrost.Ptr("This audio contains technical terms, proper nouns, and streaming-related vocabulary."), }, - Fallbacks: testConfig.Fallbacks, + Fallbacks: testConfig.TranscriptionFallbacks, } retryConfig := GetTestRetryConfigForScenario("TranscriptionStreamPrompt", testConfig) @@ -525,41 +524,31 @@ func RunTranscriptionStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, c var chunkCount int var streamErrors []string var receivedText string + var lastTokenLatency int64 - streamCtx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - - for { - select { - case response, ok := <-responseChannel: - if !ok { - goto promptStreamComplete - } - - if response == nil { - streamErrors = append(streamErrors, "Received nil stream response with custom prompt") - continue - } + for response := range responseChannel { + if response == nil { + streamErrors = append(streamErrors, "Received nil stream response with custom prompt") + continue + } - if response.BifrostError != nil { - streamErrors = append(streamErrors, FormatErrorConcise(ParseBifrostError(response.BifrostError))) - continue - } + if response.BifrostError != nil { + streamErrors = append(streamErrors, FormatErrorConcise(ParseBifrostError(response.BifrostError))) + continue + } - if response.BifrostTranscriptionStreamResponse != nil && 
response.BifrostTranscriptionStreamResponse.Text != "" { - chunkCount++ - chunkText := response.BifrostTranscriptionStreamResponse.Text - receivedText += chunkText - t.Logf("βœ… Custom prompt chunk %d: '%s'", chunkCount, chunkText) - } + if response.BifrostTranscriptionStreamResponse != nil { + lastTokenLatency = response.BifrostTranscriptionStreamResponse.ExtraFields.Latency + } - case <-streamCtx.Done(): - streamErrors = append(streamErrors, "Custom prompt stream reading timed out") - goto promptStreamComplete + if response.BifrostTranscriptionStreamResponse != nil && response.BifrostTranscriptionStreamResponse.Text != "" { + chunkCount++ + chunkText := response.BifrostTranscriptionStreamResponse.Text + receivedText += chunkText + t.Logf("βœ… Custom prompt chunk %d: '%s'", chunkCount, chunkText) } } - promptStreamComplete: if len(streamErrors) > 0 { t.Logf("⚠️ Custom prompt stream errors: %v", streamErrors) } @@ -574,6 +563,11 @@ func RunTranscriptionStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, c } else { t.Logf("⚠️ Custom prompt produced empty transcription") } + + if lastTokenLatency == 0 { + t.Errorf("❌ Last token latency is 0") + } + t.Logf("βœ… Custom prompt streaming successful: %d chunks received", chunkCount) }) }) diff --git a/tests/core-providers/scenarios/utils.go b/tests/core-providers/scenarios/utils.go index 6eca5e954..68bfc2e26 100644 --- a/tests/core-providers/scenarios/utils.go +++ b/tests/core-providers/scenarios/utils.go @@ -14,13 +14,13 @@ import ( // Shared test texts for TTS->SST round-trip validation const ( // Basic test text for simple round-trip validation - TTSTestTextBasic = "Hello, this is a test of speech synthesis from Bifrost." + TTSTestTextBasic = "Hello, this is a comprehensive test of speech synthesis capabilities from Bifrost AI Gateway. We are testing various aspects of text-to-speech conversion including clarity, pronunciation, and overall audio quality. 
This basic test should demonstrate the fundamental functionality of converting written text into natural-sounding speech audio." // Medium length text with punctuation for comprehensive testing - TTSTestTextMedium = "Testing speech synthesis and transcription round-trip. This text includes punctuation, numbers like 123, and technical terms." + TTSTestTextMedium = "Testing speech synthesis and transcription round-trip functionality with Bifrost AI Gateway. This comprehensive text includes various punctuation marks: commas, periods, exclamation points! Question marks? Semicolons; and colons: for thorough testing. We also include numbers like 123, 456.789, and technical terms such as API, HTTP, JSON, WebSocket, and machine learning algorithms. The system should handle abbreviations like Dr., Mr., Mrs., and acronyms like NASA, FBI, and CPU correctly. Additionally, we test special characters and symbols: @, #, $, %, &, *, +, =, and various currency symbols like €, Β£, Β₯." - // Short technical text for WAV format testing - TTSTestTextTechnical = "Bifrost AI gateway processes audio requests efficiently." + // Technical text for comprehensive format testing + TTSTestTextTechnical = "Bifrost AI Gateway is a sophisticated artificial intelligence proxy server that efficiently processes and routes audio requests, chat completions, embeddings, and various machine learning workloads across multiple provider endpoints. The system implements advanced load balancing algorithms, request queuing mechanisms, and intelligent failover strategies to ensure high availability and optimal performance. It supports multiple audio formats including MP3, WAV, FLAC, and OGG, with configurable bitrates, sample rates, and encoding parameters. The gateway handles authentication, rate limiting, request validation, response transformation, and comprehensive logging for enterprise-grade deployments. Performance metrics indicate sub-100ms latency for most operations with 99.9% uptime reliability." 
) // GetProviderVoice returns an appropriate voice for the given provider diff --git a/tests/core-providers/scenarios/validation_presets.go b/tests/core-providers/scenarios/validation_presets.go index 18e99a7b0..3aaa8b88f 100644 --- a/tests/core-providers/scenarios/validation_presets.go +++ b/tests/core-providers/scenarios/validation_presets.go @@ -22,6 +22,7 @@ func BasicChatExpectations() ResponseExpectations { ShouldHaveUsageStats: true, ShouldHaveTimestamps: true, ShouldHaveModel: true, + ShouldHaveLatency: true, // Global expectation: latency should always be present ShouldNotContainWords: []string{ "i can't", "i cannot", "i'm unable", "i am unable", "i don't know", "i'm not sure", "i am not sure", @@ -110,6 +111,7 @@ func EmbeddingExpectations(expectedTexts []string) ResponseExpectations { ShouldHaveContent: false, // Embeddings don't have text content ExpectedChoiceCount: 0, // Embeddings use different structure ShouldHaveModel: true, + ShouldHaveLatency: true, // Global expectation: latency should always be present // Custom validation will be needed for embedding data ProviderSpecific: map[string]interface{}{ "expected_embedding_count": len(expectedTexts), @@ -128,8 +130,8 @@ func StreamingExpectations() ResponseExpectations { // ConversationExpectations returns validation expectations for multi-turn conversation scenarios func ConversationExpectations(contextKeywords []string) ResponseExpectations { expectations := BasicChatExpectations() - expectations.MinContentLength = 15 // Conversation responses should be more substantial - expectations.ShouldContainKeywords = contextKeywords // Should reference conversation context + expectations.MinContentLength = 15 // Conversation responses should be more substantial + expectations.ShouldContainAnyOf = contextKeywords // Should reference conversation context return expectations } @@ -159,6 +161,7 @@ func SpeechExpectations(minAudioBytes int) ResponseExpectations { ShouldHaveUsageStats: true, ShouldHaveTimestamps: 
true, ShouldHaveModel: true, + ShouldHaveLatency: true, // Global expectation: latency should always be present // Speech-specific validations stored in ProviderSpecific ProviderSpecific: map[string]interface{}{ "min_audio_bytes": minAudioBytes, @@ -177,6 +180,7 @@ func TranscriptionExpectations(minTextLength int) ResponseExpectations { ShouldHaveUsageStats: true, ShouldHaveTimestamps: true, ShouldHaveModel: true, + ShouldHaveLatency: true, // Global expectation: latency should always be present // Transcription-specific validations ShouldNotContainWords: []string{ "could not transcribe", "failed to process", @@ -198,14 +202,13 @@ func ReasoningExpectations() ResponseExpectations { ShouldHaveContent: true, MinContentLength: 50, // Reasoning requires substantial content MaxContentLength: 3000, // Reasoning can be very verbose - ExpectedChoiceCount: 1, // Usually expect one choice ShouldHaveUsageStats: true, ShouldHaveTimestamps: true, ShouldHaveModel: true, // Reasoning-specific validations ShouldContainAnyOf: []string{ "step", "first", "then", "next", "calculate", "therefore", "because", - "reasoning", "think", "analysis", "conclusion", "solution", + "reasoning", "think", "analysis", "conclusion", "solution", "solve", }, ShouldNotContainWords: []string{ "i can't", "i cannot", "i'm unable", "i am unable", @@ -300,10 +303,7 @@ func GetExpectationsForScenario(scenarioName string, testConfig config.Comprehen case "Reasoning": expectations := ReasoningExpectations() if requiresReasoning, ok := customParams["requires_reasoning"].(bool); ok && requiresReasoning { - expectations.ShouldContainAnyOf = []string{"step", "first", "then", "calculate", "therefore", "because"} - } - if isMathematical, ok := customParams["mathematical_problem"].(bool); ok && isMathematical { - expectations.ShouldContainKeywords = append(expectations.ShouldContainKeywords, "calculate", "profit", "$") + expectations.ShouldContainAnyOf = []string{"step", "first", "then", "calculate", "therefore", 
"because", "solve"} } return expectations @@ -460,6 +460,9 @@ func CombineExpectations(expectations ...ResponseExpectations) ResponseExpectati if exp.ShouldHaveModel { base.ShouldHaveModel = exp.ShouldHaveModel } + if exp.ShouldHaveLatency { + base.ShouldHaveLatency = exp.ShouldHaveLatency + } // Merge provider specific data if len(exp.ProviderSpecific) > 0 { diff --git a/tests/core-providers/sgl_test.go b/tests/core-providers/sgl_test.go index 973098904..7cd9024cc 100644 --- a/tests/core-providers/sgl_test.go +++ b/tests/core-providers/sgl_test.go @@ -1,6 +1,7 @@ package tests import ( + "os" "testing" "github.com/maximhq/bifrost/tests/core-providers/config" @@ -9,12 +10,15 @@ import ( ) func TestSGL(t *testing.T) { + if os.Getenv("SGL_BASE_URL") == "" { + t.Skip("Skipping SGL tests because SGL_BASE_URL is not set") + } + client, ctx, cancel, err := config.SetupTest() if err != nil { t.Fatalf("Error initializing test setup: %v", err) } defer cancel() - defer client.Shutdown() testConfig := config.ComprehensiveTestConfig{ Provider: schemas.SGL, @@ -39,5 +43,8 @@ func TestSGL(t *testing.T) { }, } - runAllComprehensiveTests(t, client, ctx, testConfig) + t.Run("SGLTests", func(t *testing.T) { + runAllComprehensiveTests(t, client, ctx, testConfig) + }) + client.Shutdown() } diff --git a/tests/core-providers/vertex_test.go b/tests/core-providers/vertex_test.go index 72223f406..6e611ff95 100644 --- a/tests/core-providers/vertex_test.go +++ b/tests/core-providers/vertex_test.go @@ -1,6 +1,7 @@ package tests import ( + "os" "testing" "github.com/maximhq/bifrost/tests/core-providers/config" @@ -9,12 +10,15 @@ import ( ) func TestVertex(t *testing.T) { + if os.Getenv("VERTEX_API_KEY") == "" && (os.Getenv("VERTEX_PROJECT_ID") == "" || os.Getenv("VERTEX_CREDENTIALS") == "") { + t.Skip("Skipping Vertex tests because VERTEX_API_KEY is not set and VERTEX_PROJECT_ID or VERTEX_CREDENTIALS is not set") + } + client, ctx, cancel, err := config.SetupTest() if err != nil { 
t.Fatalf("Error initializing test setup: %v", err) } defer cancel() - defer client.Shutdown() testConfig := config.ComprehensiveTestConfig{ Provider: schemas.Vertex, @@ -39,5 +43,8 @@ func TestVertex(t *testing.T) { }, } - runAllComprehensiveTests(t, client, ctx, testConfig) + t.Run("VertexTests", func(t *testing.T) { + runAllComprehensiveTests(t, client, ctx, testConfig) + }) + client.Shutdown() } diff --git a/transports/bifrost-http/handlers/middlewares.go b/transports/bifrost-http/handlers/middlewares.go index 74484b327..86d0c7673 100644 --- a/transports/bifrost-http/handlers/middlewares.go +++ b/transports/bifrost-http/handlers/middlewares.go @@ -77,7 +77,7 @@ func TransportInterceptorMiddleware(config *lib.Config) lib.BifrostHTTPMiddlewar if len(bodyBytes) > 0 { if err := json.Unmarshal(bodyBytes, &requestBody); err != nil { // If body is not valid JSON, log warning and continue without interception - logger.Warn(fmt.Sprintf("TransportInterceptor: Failed to unmarshal request body: %v", err)) + logger.Warn(fmt.Sprintf("TransportInterceptor: Failed to unmarshal request body: %v, skipping interceptor", err)) next(ctx) return } diff --git a/transports/changelog.md b/transports/changelog.md index 24cf7ac6f..59bf4c7e6 100644 --- a/transports/changelog.md +++ b/transports/changelog.md @@ -1,4 +1,13 @@ -- chore: Fixes form validation for Azure deployments. \ No newline at end of file +- chore: version update core to 1.2.13 and framework to 1.1.15 +- feat: added headers support for OTel configuration. Value prefixed with env will be fetched from environment variables (env.) 
+- feat: emission of OTel resource spans is completely async - this brings down inference overhead to < 1Β΅second +- fix: added latency calculation for vertex native requests +- feat: added cached tokens and reasoning tokens to the usage in ui +- fix: cost calculation for vertex requests +- feat: added global region support for vertex API +- fix: added filter for extra fields in chat completions request for Mistral provider +- fix: added wildcard validation for allowed origins in UI security settings +- fix: fixed code field in pending_safety_checks for Responses API \ No newline at end of file diff --git a/transports/version b/transports/version index d4c4950a3..0c00f6108 100644 --- a/transports/version +++ b/transports/version @@ -1 +1 @@ -1.3.9 +1.3.10 diff --git a/ui/app/config/views/securityView.tsx b/ui/app/config/views/securityView.tsx index bd226eb46..9a43f38af 100644 --- a/ui/app/config/views/securityView.tsx +++ b/ui/app/config/views/securityView.tsx @@ -68,7 +68,9 @@ export default function SecurityView() { const validation = validateOrigins(localConfig.allowed_origins); if (!validation.isValid && localConfig.allowed_origins.length > 0) { - toast.error(`Invalid origins: ${validation.invalidOrigins.join(", ")}. Origins must be valid URLs like https://example.com`); + toast.error( + `Invalid origins: ${validation.invalidOrigins.join(", ")}. Origins must be valid URLs like https://example.com, wildcard patterns like https://*.example.com, or "*" to allow all origins`, + ); return; } @@ -109,13 +111,14 @@ export default function SecurityView() {

Comma-separated list of allowed origins for CORS and WebSocket connections. Localhost origins are always allowed. Each - origin must be a complete URL with protocol (e.g., https://app.example.com, http://10.0.0.100:3000, https://*.example.com). + origin must be a complete URL with protocol (e.g., https://app.example.com, http://10.0.0.100:3000). Wildcards are supported + for subdomains (e.g., https://*.example.com) or use "*" to allow all origins.