diff --git a/core/bifrost.go b/core/bifrost.go index 0f9f4b62d..7600fe483 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -979,6 +979,8 @@ func (bifrost *Bifrost) createBaseProvider(providerKey schemas.ModelProvider, co return providers.NewOpenAIProvider(config, bifrost.logger), nil case schemas.Anthropic: return providers.NewAnthropicProvider(config, bifrost.logger), nil + case schemas.AnthropicPassthrough: + return providers.NewAnthropicPassthroughProvider(config, bifrost.logger), nil case schemas.Bedrock: return providers.NewBedrockProvider(config, bifrost.logger) case schemas.Cohere: @@ -1440,6 +1442,7 @@ func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx contex outputStream <- &schemas.BifrostStream{ BifrostResponse: processedResp, BifrostError: processedErr, + RawSSEEvent: streamMsg.RawSSEEvent, } } }() diff --git a/core/providers/anthropic_passthrough.go b/core/providers/anthropic_passthrough.go new file mode 100644 index 000000000..f5d5a49f1 --- /dev/null +++ b/core/providers/anthropic_passthrough.go @@ -0,0 +1,468 @@ +// Package providers implements various LLM providers and their utility functions. +// This file contains the Anthropic Passthrough provider implementation for OAuth mode. +package providers + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/bytedance/sonic" + "github.com/maximhq/bifrost/core/schemas" + anthropic "github.com/maximhq/bifrost/core/schemas/providers/anthropic" + "github.com/valyala/fasthttp" +) + +// hopByHopHeaders are HTTP/1.1 headers that must not be forwarded by proxies. +var hopByHopHeaders = map[string]bool{ + "connection": true, + "proxy-connection": true, + "keep-alive": true, + "proxy-authenticate": true, + "proxy-authorization": true, + "te": true, + "trailer": true, + "transfer-encoding": true, + "upgrade": true, +} + +// filterHeaders filters out hop-by-hop headers and returns only the allowed headers. +func filterHeaders(headers map[string][]string) map[string][]string { + filtered := make(map[string][]string, len(headers)) + for k, v := range headers { + if !hopByHopHeaders[strings.ToLower(k)] { + filtered[k] = v + } + } + return filtered +} + +// AnthropicPassthroughProvider implements OAuth passthrough mode for Anthropic's Claude API. +// This provider is used when the API key starts with "sk-ant-oat" (OAuth Access Token). +// It passes through the original request body and headers from Claude Code without modification. +type AnthropicPassthroughProvider struct { + logger schemas.Logger + client *fasthttp.Client // For non-streaming requests + streamClient *http.Client // For streaming requests + apiVersion string + networkConfig schemas.NetworkConfig + sendBackRawResponse bool + customProviderConfig *schemas.CustomProviderConfig +} + +// NewAnthropicPassthroughProvider creates a new Anthropic passthrough provider instance. +// It initializes the HTTP client with the provided configuration for OAuth passthrough mode. +// The client is configured with timeouts, concurrency limits, and optional proxy settings. +func NewAnthropicPassthroughProvider(config *schemas.ProviderConfig, logger schemas.Logger) *AnthropicPassthroughProvider { + config.CheckAndSetDefaults() + + client := &fasthttp.Client{ + ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + MaxConnsPerHost: config.ConcurrencyAndBufferSize.Concurrency, + } + + transport := &http.Transport{ + ResponseHeaderTimeout: time.Second * time.Duration( + max(config.NetworkConfig.DefaultRequestTimeoutInSeconds, 60), + ), + } + + // Configure proxy for streamClient transport + configureHTTPClientTransport(transport, config.ProxyConfig, logger) + + streamClient := &http.Client{ + Transport: transport, + } + + // Configure proxy if provided + client = configureProxy(client, config.ProxyConfig, logger) + + // Set default BaseURL if not provided + if config.NetworkConfig.BaseURL == "" { + config.NetworkConfig.BaseURL = "https://api.anthropic.com" + } + config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") + + return &AnthropicPassthroughProvider{ + logger: logger, + client: client, + streamClient: streamClient, + apiVersion: "2023-06-01", + networkConfig: config.NetworkConfig, + sendBackRawResponse: config.SendBackRawResponse, + customProviderConfig: config.CustomProviderConfig, + } +} + +// GetProviderKey returns the provider identifier for Anthropic passthrough mode. +func (provider *AnthropicPassthroughProvider) GetProviderKey() schemas.ModelProvider { + return getProviderName(schemas.AnthropicPassthrough, provider.customProviderConfig) +} + +// extractRawBodyFromContext extracts the original request body from context. +// In passthrough mode, it retrieves the unmodified request body from Claude Code. +// Returns the raw body bytes or an error if extraction fails. +func (provider *AnthropicPassthroughProvider) extractRawBodyFromContext(ctx context.Context) ([]byte, *schemas.BifrostError) { + originalBody := ctx.Value(schemas.BifrostContextKeyOriginalRequest) + if originalBody == nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, fmt.Errorf("passthrough mode requires original request body in context"), provider.GetProviderKey()) + } + + rawBody, ok := originalBody.(json.RawMessage) + if !ok { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, fmt.Errorf("original request body has invalid type: %T, expected json.RawMessage", originalBody), provider.GetProviderKey()) + } + + if len(rawBody) == 0 { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, fmt.Errorf("original request body is empty"), provider.GetProviderKey()) + } + + return rawBody, nil +} + +// extractOriginalPathFromContext extracts the original request path from context. +// In passthrough mode, it retrieves the original path (e.g., /v1/messages/count_tokens?beta=true). +// Returns the path string or default path if not found in context. +func (provider *AnthropicPassthroughProvider) extractOriginalPathFromContext(ctx context.Context) string { + defaultPath := "/v1/messages?beta=true" + + if originalPath := ctx.Value(schemas.BifrostContextKeyOriginalPath); originalPath != nil { + if pathStr, ok := originalPath.(string); ok && pathStr != "" { + return pathStr + } + } + + return defaultPath +} + +// sendRequest sends a request to Anthropic's API in passthrough mode and handles the response using fasthttp. +// It passes through the original headers and body from Claude Code without modification. +// Returns: rawBytes (original response), decodedBytes (), headers, error +func (provider *AnthropicPassthroughProvider) sendRequest(ctx context.Context, rawBody []byte, url string, key string) ([]byte, []byte, http.Header, *schemas.BifrostError) { + + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + req.SetRequestURI(url) + req.Header.SetMethod("POST") + req.SetBody(rawBody) + + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + if originalHeaders, ok := ctx.Value(schemas.BifrostContextKeyOriginalHeaders).(map[string][]string); ok { + for k, values := range filterHeaders(originalHeaders) { + for i, v := range values { + if i == 0 { + req.Header.Set(k, v) + } else { + req.Header.Add(k, v) + } + } + } + } + + // Send the request + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, nil, nil, bifrostErr + } + + rawBytes := resp.Body() + + if resp.StatusCode() != fasthttp.StatusOK { + + var errorResp anthropic.AnthropicMessageError + + bifrostErr := handleProviderAPIError(resp, &errorResp) + bifrostErr.Error.Type = &errorResp.Error.Type + bifrostErr.Error.Message = errorResp.Error.Message + + return nil, nil, nil, bifrostErr + } + + decodedBody := rawBytes + contentEncoding := string(resp.Header.Peek("Content-Encoding")) + + if contentEncoding == "gzip" { + var err error + decodedBody, err = resp.BodyGunzip() + if err != nil { + return nil, nil, nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, fmt.Errorf("failed to decompress gzip response: %w", err), provider.GetProviderKey()) + } + } + + httpHeaders := make(http.Header) + for k, v := range resp.Header.All() { + httpHeaders.Add(string(k), string(v)) + } + + return rawBytes, decodedBody, httpHeaders, nil +} + +// TextCompletion is not supported by the Anthropic passthrough provider. +func (provider *AnthropicPassthroughProvider) TextCompletion(context.Context, schemas.Key, *schemas.BifrostTextCompletionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("text completion in passthrough mode", "anthropic") +} + +// ChatCompletion performs a chat completion request to Anthropic's API in passthrough mode. +// It forwards the original request from Claude Code without modification, preserving all headers. +// Returns a BifrostResponse containing the completion results or an error if the request fails. +func (provider *AnthropicPassthroughProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + + if err := checkOperationAllowed(schemas.AnthropicPassthrough, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil { + return nil, err + } + + rawBody, err := provider.extractRawBodyFromContext(ctx) + if err != nil { + return nil, err + } + + path := provider.extractOriginalPathFromContext(ctx) + url := provider.networkConfig.BaseURL + path + + rawBytes, decodedBody, respHeaders, bifrostErr := provider.sendRequest(ctx, rawBody, url, key.Value) + if bifrostErr != nil { + return nil, bifrostErr + } + + response := acquireAnthropicChatResponse() + defer releaseAnthropicChatResponse(response) + + _, bifrostErr = handleProviderResponse(decodedBody, response, false) + if bifrostErr != nil { + return nil, bifrostErr + } + + bifrostResponse := response.ToBifrostResponse() + + bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{ + RequestType: schemas.ChatCompletionRequest, + Provider: provider.GetProviderKey(), + ModelRequested: request.Model, + RawResponse: rawBytes, + RawHeaders: respHeaders, + } + + return bifrostResponse, nil +} + +// Embedding is not supported by the Anthropic passthrough provider. +func (provider *AnthropicPassthroughProvider) Embedding(context.Context, schemas.Key, *schemas.BifrostEmbeddingRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("embedding", "anthropic") +} + +// ChatCompletionStream performs a streaming chat completion request to the Anthropic API in passthrough mode. +// It supports real-time streaming of responses using Server-Sent Events (SSE) while preserving original headers. +// Returns a channel containing BifrostStream objects representing the stream or an error if the request fails. +func (provider *AnthropicPassthroughProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + + if err := checkOperationAllowed(schemas.AnthropicPassthrough, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { + return nil, err + } + + rawBody, err := provider.extractRawBodyFromContext(ctx) + if err != nil { + return nil, err + } + + headers, ok := ctx.Value(schemas.BifrostContextKeyOriginalHeaders).(map[string][]string) + if !ok { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, fmt.Errorf("passthrough mode requires original headers in context"), provider.GetProviderKey()) + } + + path := provider.extractOriginalPathFromContext(ctx) + url := provider.networkConfig.BaseURL + path + + return provider.handleAnthropicStreamingPassthrough( + ctx, + postHookRunner, + request, + url, + rawBody, + headers, + ) +} + +// Speech is not supported by the Anthropic passthrough provider. +func (provider *AnthropicPassthroughProvider) Speech(context.Context, schemas.Key, *schemas.BifrostSpeechRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech", "anthropic_passthrough") +} + +// SpeechStream is not supported by the Anthropic passthrough provider. +func (provider *AnthropicPassthroughProvider) SpeechStream(context.Context, schemas.PostHookRunner, schemas.Key, *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("speech stream", "anthropic_passthrough") +} + +// Transcription is not supported by the Anthropic passthrough provider. +func (provider *AnthropicPassthroughProvider) Transcription(context.Context, schemas.Key, *schemas.BifrostTranscriptionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription", "anthropic_passthrough") +} + +// TranscriptionStream is not supported by the Anthropic passthrough provider. +func (provider *AnthropicPassthroughProvider) TranscriptionStream(context.Context, schemas.PostHookRunner, schemas.Key, *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("transcription stream", "anthropic_passthrough") +} + +// Responses is not supported by the Anthropic passthrough provider. +func (provider *AnthropicPassthroughProvider) Responses(context.Context, schemas.Key, *schemas.BifrostResponsesRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("responses", "anthropic_passthrough") +} + +// ResponsesStream is not supported by the Anthropic passthrough provider. +func (provider *AnthropicPassthroughProvider) ResponsesStream(context.Context, schemas.PostHookRunner, schemas.Key, *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("responses stream", "anthropic_passthrough") +} + +// handleAnthropicStreamingPassthrough implements passthrough streaming. +// It reads raw SSE events from Anthropic and forwards them as-is without parsing/reconstruction. +// This preserves the exact stream format that Claude Code expects. +// Additionally, it parses SSE events to extract usage information for telemetry. +func (provider *AnthropicPassthroughProvider) handleAnthropicStreamingPassthrough( + ctx context.Context, + postHookRunner schemas.PostHookRunner, + request *schemas.BifrostChatRequest, + url string, + rawBody []byte, + headers map[string][]string, +) (chan *schemas.BifrostStream, *schemas.BifrostError) { + + // Create HTTP request for streaming - use rawBody as-is (TRUE passthrough) + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(rawBody)) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, provider.GetProviderKey()) + } + + setExtraHeadersHTTP(req, provider.networkConfig.ExtraHeaders, nil) + + for k, values := range filterHeaders(headers) { + for i, v := range values { + if i == 0 { + req.Header.Set(k, v) + } else { + req.Header.Add(k, v) + } + } + } + + // Make the request using streaming client + resp, err := provider.streamClient.Do(req) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, provider.GetProviderKey()) + } + + // Check for HTTP errors + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + return nil, newProviderAPIError(fmt.Sprintf("HTTP error from Anthropic: %d", resp.StatusCode), fmt.Errorf("%s", string(body)), resp.StatusCode, provider.GetProviderKey(), nil, nil) + } + + // Create response channel + responseChan := make(chan *schemas.BifrostStream, schemas.DefaultStreamBufferSize) + + // Start streaming in a goroutine - forward raw SSE events and parse for telemetry + go func() { + defer close(responseChan) + defer resp.Body.Close() + + reader := bufio.NewReader(resp.Body) + // Track metadata for telemetry + var messageID string + var usage *schemas.LLMUsage + var finishReason *string + var eventType string + var eventData string + + for { + line, err := reader.ReadBytes('\n') + + // Send data if we have any + if len(line) > 0 { + // Forward raw SSE event as-is + select { + case responseChan <- &schemas.BifrostStream{ + RawSSEEvent: line, + }: + // Successfully sent + case <-ctx.Done(): + return + } + + // Parse SSE event for telemetry (parallel to forwarding) + lineStr := strings.TrimSpace(string(line)) + + // Skip empty lines and comments + if lineStr == "" || strings.HasPrefix(lineStr, ":") { + continue + } + + // Parse SSE event type + if strings.HasPrefix(lineStr, "event: ") { + eventType = strings.TrimSpace(strings.TrimPrefix(lineStr, "event: ")) + continue + } + + // Parse SSE event data + if strings.HasPrefix(lineStr, "data: ") { + eventData = strings.TrimSpace(strings.TrimPrefix(lineStr, "data: ")) + + // Only parse if we have both event type and data + if eventType != "" && eventData != "" { + var event anthropic.AnthropicStreamEvent + if err := sonic.Unmarshal([]byte(eventData), &event); err == nil { + // Extract usage information + if event.Usage != nil { + usage = &schemas.LLMUsage{ + PromptTokens: event.Usage.InputTokens, + CompletionTokens: event.Usage.OutputTokens, + TotalTokens: event.Usage.InputTokens + event.Usage.OutputTokens, + } + } + + // Extract finish reason from delta + if event.Delta != nil && event.Delta.StopReason != nil { + mapped := anthropic.MapAnthropicFinishReasonToBifrost(*event.Delta.StopReason) + finishReason = &mapped + } + + // Extract message ID from message_start event + if eventType == "message_start" && event.Message != nil && event.Message.ID != "" { + messageID = event.Message.ID + } + } + + // Reset event parsing state + eventType = "" + eventData = "" + } + } + } + + // Stop on any error (including EOF) + if err != nil { + if err == io.EOF { + // Send final chunk with usage for telemetry + response := createBifrostChatCompletionChunkResponse(messageID, usage, finishReason, -1, schemas.ChatCompletionStreamRequest, provider.GetProviderKey(), request.Model) + handleStreamEndForPassthrough(ctx, response, postHookRunner, provider.logger) + } else { + // Stream error - log and send to client + provider.logger.Warn(fmt.Sprintf("Error reading Anthropic passthrough stream: %v", err)) + processAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, provider.GetProviderKey(), request.Model, provider.logger) + } + return + } + } + }() + + return responseChan, nil +} diff --git a/core/providers/utils.go b/core/providers/utils.go index 3bb7f3cca..0486c06f8 100644 --- a/core/providers/utils.go +++ b/core/providers/utils.go @@ -6,6 +6,7 @@ import ( "context" "errors" "fmt" + "net" "net/http" "net/textproto" "net/url" @@ -17,6 +18,7 @@ import ( schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttpproxy" + "golang.org/x/net/proxy" ) // IMPORTANT: This function does NOT truly cancel the underlying fasthttp network request if the @@ -129,6 +131,136 @@ func configureProxy(client *fasthttp.Client, proxyConfig *schemas.ProxyConfig, l return client } +// configureHTTPClientProxy configures proxy settings for a standard http.Client. +// It returns a proxy function that can be assigned to http.Transport.Proxy. +// Returns nil if no proxy should be used or if the configuration is invalid. +// Note: For SOCKS5 proxies, use configureHTTPClientTransport instead as SOCKS5 +// requires DialContext configuration, not just the Proxy field. +func configureHTTPClientProxy(proxyConfig *schemas.ProxyConfig, logger schemas.Logger) func(*http.Request) (*url.URL, error) { + if proxyConfig == nil { + return nil + } + + switch proxyConfig.Type { + case schemas.NoProxy: + // No proxy - return nil to use direct connection + return nil + + case schemas.HTTPProxy: + if proxyConfig.URL == "" { + logger.Warn(fmt.Sprintf("Warning: %s proxy URL is required for setting up proxy", proxyConfig.Type)) + return nil + } + + // Parse the proxy URL + parsedURL, err := url.Parse(proxyConfig.URL) + if err != nil { + logger.Warn(fmt.Sprintf("Invalid proxy configuration: invalid %s proxy URL: %v", proxyConfig.Type, err)) + return nil + } + + // Add authentication if provided + if proxyConfig.Username != "" && proxyConfig.Password != "" { + parsedURL.User = url.UserPassword(proxyConfig.Username, proxyConfig.Password) + } + + // Return a proxy function that always returns this proxy URL + return http.ProxyURL(parsedURL) + + case schemas.Socks5Proxy: + // SOCKS5 requires DialContext configuration, not Proxy field + // Return nil and log a warning + logger.Warn("SOCKS5 proxy requires transport-level configuration; use configureHTTPClientTransport instead") + return nil + + case schemas.EnvProxy: + // Use environment variables for proxy configuration (http_proxy, https_proxy, no_proxy) + return http.ProxyFromEnvironment + + default: + logger.Warn(fmt.Sprintf("Invalid proxy configuration: unsupported proxy type: %s", proxyConfig.Type)) + return nil + } +} + +// configureHTTPClientTransport configures an http.Transport with proxy settings. +// This function handles all proxy types including SOCKS5 which requires DialContext. +// It modifies the transport in place and returns it for convenience. +func configureHTTPClientTransport(transport *http.Transport, proxyConfig *schemas.ProxyConfig, logger schemas.Logger) *http.Transport { + if proxyConfig == nil || transport == nil { + return transport + } + + switch proxyConfig.Type { + case schemas.NoProxy: + // No proxy - leave transport as is + return transport + + case schemas.HTTPProxy: + if proxyConfig.URL == "" { + logger.Warn("Warning: HTTP proxy URL is required for setting up proxy") + return transport + } + + // Parse the proxy URL + parsedURL, err := url.Parse(proxyConfig.URL) + if err != nil { + logger.Warn(fmt.Sprintf("Invalid proxy configuration: invalid HTTP proxy URL: %v", err)) + return transport + } + + // Add authentication if provided + if proxyConfig.Username != "" && proxyConfig.Password != "" { + parsedURL.User = url.UserPassword(proxyConfig.Username, proxyConfig.Password) + } + + // Set the proxy function + transport.Proxy = http.ProxyURL(parsedURL) + + case schemas.Socks5Proxy: + if proxyConfig.URL == "" { + logger.Warn("Warning: SOCKS5 proxy URL is required for setting up proxy") + return transport + } + + // Parse the proxy URL to extract host and port + parsedURL, err := url.Parse(proxyConfig.URL) + if err != nil { + logger.Warn(fmt.Sprintf("Invalid proxy configuration: invalid SOCKS5 proxy URL: %v", err)) + return transport + } + + // Create SOCKS5 dialer + var auth *proxy.Auth + if proxyConfig.Username != "" && proxyConfig.Password != "" { + auth = &proxy.Auth{ + User: proxyConfig.Username, + Password: proxyConfig.Password, + } + } + + dialer, err := proxy.SOCKS5("tcp", parsedURL.Host, auth, proxy.Direct) + if err != nil { + logger.Warn(fmt.Sprintf("Failed to create SOCKS5 dialer: %v", err)) + return transport + } + + // Set custom DialContext that uses the SOCKS5 proxy + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.Dial(network, addr) + } + + case schemas.EnvProxy: + // Use environment variables for proxy configuration (http_proxy, https_proxy, no_proxy) + transport.Proxy = http.ProxyFromEnvironment + + default: + logger.Warn(fmt.Sprintf("Invalid proxy configuration: unsupported proxy type: %s", proxyConfig.Type)) + } + + return transport +} + // setExtraHeaders sets additional headers from NetworkConfig to the fasthttp request. // This allows users to configure custom headers for their provider requests. // Header keys are canonicalized using textproto.CanonicalMIMEHeaderKey to avoid duplicates. @@ -518,6 +650,24 @@ func handleStreamEndWithSuccess( processAndSendResponse(ctx, postHookRunner, response, responseChan, logger) } +// handleStreamEndForPassthrough sets the stream end indicator and runs post-hooks +// WITHOUT sending the response to the channel. This is specifically for passthrough modes +// where raw SSE events are forwarded as-is, and we only need to trigger telemetry/logging +// without sending a duplicate event to the client. +func handleStreamEndForPassthrough( + ctx context.Context, + response *schemas.BifrostResponse, + postHookRunner schemas.PostHookRunner, + logger schemas.Logger, +) { + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + + _, bifrostErr := postHookRunner(&ctx, response, nil) + if bifrostErr != nil && handleStreamControlSkip(logger, bifrostErr) { + return + } +} + func handleStreamControlSkip(logger schemas.Logger, bifrostErr *schemas.BifrostError) bool { if bifrostErr == nil || bifrostErr.StreamControl == nil { return false diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 595cdfef5..2213cb7a6 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -42,20 +42,21 @@ type BifrostConfig struct { type ModelProvider string const ( - OpenAI ModelProvider = "openai" - Azure ModelProvider = "azure" - Anthropic ModelProvider = "anthropic" - Bedrock ModelProvider = "bedrock" - Cohere ModelProvider = "cohere" - Vertex ModelProvider = "vertex" - Mistral ModelProvider = "mistral" - Ollama ModelProvider = "ollama" - Groq ModelProvider = "groq" - SGL ModelProvider = "sgl" - Parasail ModelProvider = "parasail" - Cerebras ModelProvider = "cerebras" - Gemini ModelProvider = "gemini" - OpenRouter ModelProvider = "openrouter" + OpenAI ModelProvider = "openai" + Azure ModelProvider = "azure" + Anthropic ModelProvider = "anthropic" + AnthropicPassthrough ModelProvider = "anthropic_passthrough" + Bedrock ModelProvider = "bedrock" + Cohere ModelProvider = "cohere" + Vertex ModelProvider = "vertex" + Mistral ModelProvider = "mistral" + Ollama ModelProvider = "ollama" + Groq ModelProvider = "groq" + SGL ModelProvider = "sgl" + Parasail ModelProvider = "parasail" + Cerebras ModelProvider = "cerebras" + Gemini ModelProvider = "gemini" + OpenRouter ModelProvider = "openrouter" ) // SupportedBaseProviders is the list of base providers allowed for custom providers. @@ -112,6 +113,9 @@ const ( BifrostContextKeyVirtualKeyHeader BifrostContextKey = "x-bf-vk" BifrostContextKeyDirectKey BifrostContextKey = "bifrost-direct-key" BifrostContextKeyStreamEndIndicator BifrostContextKey = "bifrost-stream-end-indicator" + BifrostContextKeyOriginalRequest BifrostContextKey = "bifrost-original-request" + BifrostContextKeyOriginalHeaders BifrostContextKey = "bifrost-original-headers" + BifrostContextKeyOriginalPath BifrostContextKey = "bifrost-original-path" ) // NOTE: for custom plugin implementation dealing with streaming short circuit, @@ -491,6 +495,7 @@ type BifrostResponseExtraFields struct { BilledUsage *BilledLLMUsage `json:"billed_usage,omitempty"` ChunkIndex int `json:"chunk_index"` // used for streaming responses to identify the chunk index, will be 0 for non-streaming responses RawResponse interface{} `json:"raw_response,omitempty"` + RawHeaders interface{} `json:"raw_headers,omitempty"` CacheDebug *BifrostCacheDebug `json:"cache_debug,omitempty"` } @@ -520,6 +525,7 @@ const ( type BifrostStream struct { *BifrostResponse *BifrostError + RawSSEEvent []byte } // BifrostError represents an error from the Bifrost system. diff --git a/core/utils.go b/core/utils.go index bfa1fcecd..cc93e660b 100644 --- a/core/utils.go +++ b/core/utils.go @@ -16,7 +16,7 @@ func Ptr[T any](v T) *T { // providerRequiresKey returns true if the given provider requires an API key for authentication. // Some providers like Ollama and SGL are keyless and don't require API keys. func providerRequiresKey(providerKey schemas.ModelProvider) bool { - return providerKey != schemas.Ollama && providerKey != schemas.SGL + return providerKey != schemas.Ollama && providerKey != schemas.SGL && providerKey != schemas.AnthropicPassthrough } // canProviderKeyValueBeEmpty returns true if the given provider allows the API key to be empty. diff --git a/transports/bifrost-http/handlers/inference.go b/transports/bifrost-http/handlers/inference.go index ca83b8073..fc84eaffb 100644 --- a/transports/bifrost-http/handlers/inference.go +++ b/transports/bifrost-http/handlers/inference.go @@ -856,13 +856,29 @@ func (h *CompletionHandler) handleStreamingResponse(ctx *fasthttp.RequestCtx, ge ctx.Response.SetBodyStreamWriter(func(w *bufio.Writer) { defer w.Flush() + sendDoneMarker := true + // Process streaming responses for response := range stream { if response == nil { continue } - // Extract and validate the response data + // Check for raw SSE event + if response.RawSSEEvent != nil { + sendDoneMarker = false + if _, err := w.Write(response.RawSSEEvent); err != nil { + h.logger.Warn(fmt.Sprintf("Failed to write raw SSE event: %v", err)) + break + } + if err := w.Flush(); err != nil { + h.logger.Warn(fmt.Sprintf("Failed to flush raw SSE event: %v", err)) + break + } + continue + } + + // Standard mode: Extract and validate the response data data, valid := extractResponse(response) if !valid { continue @@ -889,8 +905,10 @@ func (h *CompletionHandler) handleStreamingResponse(ctx *fasthttp.RequestCtx, ge } // Send the [DONE] marker to indicate the end of the stream - if _, err := fmt.Fprint(w, "data: [DONE]\n\n"); err != nil { - h.logger.Warn(fmt.Sprintf("Failed to write SSE done marker: %v", err)) + if sendDoneMarker { + if _, err := fmt.Fprint(w, "data: [DONE]\n\n"); err != nil { + h.logger.Warn(fmt.Sprintf("Failed to write SSE done marker: %v", err)) + } } }) } diff --git a/transports/bifrost-http/integrations/anthropic.go b/transports/bifrost-http/integrations/anthropic.go index f18d06bac..858ae104a 100644 --- a/transports/bifrost-http/integrations/anthropic.go +++ b/transports/bifrost-http/integrations/anthropic.go @@ -16,22 +16,30 @@ type AnthropicRouter struct { // CreateAnthropicRouteConfigs creates route configurations for Anthropic endpoints. func CreateAnthropicRouteConfigs(pathPrefix string) []RouteConfig { - return []RouteConfig{ - { - Path: pathPrefix + "/v1/messages", + createConfig := func(path string) RouteConfig { + return RouteConfig{ + Path: path, Method: "POST", GetRequestTypeInstance: func() interface{} { return &anthropic.AnthropicMessageRequest{} }, RequestConverter: func(req interface{}) (*schemas.BifrostRequest, error) { if anthropicReq, ok := req.(*anthropic.AnthropicMessageRequest); ok { + chatReq := anthropicReq.ToBifrostRequest() return &schemas.BifrostRequest{ - ChatRequest: anthropicReq.ToBifrostRequest(), + Provider: chatReq.Provider, + Model: chatReq.Model, + ChatRequest: chatReq, }, nil } return nil, errors.New("invalid request type") }, ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { + if resp.ExtraFields.RawResponse != nil { + if rawBytes, ok := resp.ExtraFields.RawResponse.([]byte); ok { + return rawBytes, nil + } + } return anthropic.ToAnthropicChatCompletionResponse(resp), nil }, ErrorConverter: func(err *schemas.BifrostError) interface{} { @@ -45,7 +53,12 @@ func CreateAnthropicRouteConfigs(pathPrefix string) []RouteConfig { return anthropic.ToAnthropicChatCompletionStreamError(err) }, }, - }, + } + } + + return []RouteConfig{ + createConfig(pathPrefix + "/v1/messages"), + createConfig(pathPrefix + "/v1/messages/{path:*}"), } } diff --git a/transports/bifrost-http/integrations/anthropic/types.go b/transports/bifrost-http/integrations/anthropic/types.go new file mode 100644 index 000000000..e69de29bb diff --git a/transports/bifrost-http/integrations/utils.go b/transports/bifrost-http/integrations/utils.go index ea4136474..d3cf25e97 100644 --- a/transports/bifrost-http/integrations/utils.go +++ b/transports/bifrost-http/integrations/utils.go @@ -48,16 +48,16 @@ package integrations import ( + "bufio" "context" "encoding/json" "fmt" "log" + "net/http" "reflect" "strconv" "strings" - "bufio" - "github.com/fasthttp/router" bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" @@ -215,6 +215,45 @@ func (g *GenericRouter) RegisterRoutes(r *router.Router) { } } +// isAPIKeyAuth checks if the request uses standard API key authentication. +// Returns true for API key auth (x-api-key header), false for OAuth (Bearer sk-ant-oat*). +// This is required for Claude Code specifically, which may use OAuth authentication. +// Default behavior is to assume API mode when neither x-api-key nor OAuth token is present. +func isAPIKeyAuth(ctx *fasthttp.RequestCtx) bool { + // If x-api-key header is present - this is definitely API mode + if apiKey := string(ctx.Request.Header.Peek("x-api-key")); apiKey != "" { + return true + } + + // Check for OAuth token in Authorization header + if authHeader := string(ctx.Request.Header.Peek("Authorization")); authHeader != "" { + if strings.HasPrefix(strings.ToLower(authHeader), "bearer sk-ant-oat") { + return false // OAuth mode, NOT API + } + } + + // Default to API mode + return true +} + +// isAnthropicRequest checks if the request is targeting an Anthropic endpoint. +// This is used to determine if we need to preserve the raw request for potential passthrough. +func isAnthropicRequest(ctx *fasthttp.RequestCtx) bool { + path := string(ctx.Path()) + return strings.Contains(path, "/anthropic/") || strings.Contains(path, "/v1/messages") +} + +// extractHeaders converts fasthttp headers to a map for passthrough mode. +// This preserves all original headers for the anthropic_passthrough provider. +func extractHeaders(ctx *fasthttp.RequestCtx) map[string][]string { + headers := make(map[string][]string) + for key, value := range ctx.Request.Header.All() { + keyStr := string(key) + headers[keyStr] = append(headers[keyStr], string(value)) + } + return headers +} + // createHandler creates a fasthttp handler for the given route configuration. // The handler follows this flow: // 1. Parse JSON request body into the configured request type (for methods that expect bodies) @@ -231,8 +270,33 @@ func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandle method := string(ctx.Method()) + var body []byte + var originalHeaders map[string][]string + var originalPath string + // Parse request body based on configuration if method != fasthttp.MethodGet && method != fasthttp.MethodDelete { + + body = ctx.Request.Body() + + // If this is an Anthropic request with OAuth, preserve headers and path for passthrough + if isAnthropicRequest(ctx) && !isAPIKeyAuth(ctx) { + originalHeaders = extractHeaders(ctx) + // Extract the original path, removing the /anthropic prefix + fullPath := string(ctx.Path()) + + if strings.HasPrefix(fullPath, "/anthropic") { + originalPath = strings.TrimPrefix(fullPath, "/anthropic") + } else { + originalPath = fullPath + } + + // Preserve query string if present + if queryString := string(ctx.URI().QueryString()); queryString != "" { + originalPath += "?" + queryString + } + } + if config.RequestParser != nil { // Use custom parser (e.g., for multipart/form-data) if err := config.RequestParser(ctx, req); err != nil { @@ -241,7 +305,6 @@ func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandle } } else { // Use default JSON parsing - body := ctx.Request.Body() if len(body) > 0 { if err := json.Unmarshal(body, req); err != nil { g.sendError(ctx, config.ErrorConverter, newBifrostError(err, "Invalid JSON")) @@ -277,6 +340,17 @@ func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandle g.sendError(ctx, config.ErrorConverter, newBifrostError(err, "failed to parse fallbacks: "+err.Error())) return } + // Auto-detect auth mode for Anthropic: API keys use standard provider, + // OAuth tokens (sk-ant-oat-*) use passthrough to preserve request structure + if bifrostReq.Provider == schemas.Anthropic { + // Switch to passthrough provider if NOT API key auth (i.e., OAuth) + if !isAPIKeyAuth(ctx) { + bifrostReq.Provider = schemas.AnthropicPassthrough + if bifrostReq.ChatRequest != nil { + bifrostReq.ChatRequest.Provider = schemas.AnthropicPassthrough + } + } + } // Check if streaming is requested isStreaming := false @@ -294,6 +368,17 @@ func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandle } } + if bifrostReq.Provider == schemas.AnthropicPassthrough && len(body) > 0 { + *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeyOriginalRequest, json.RawMessage(body)) + if originalHeaders != nil { + *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeyOriginalHeaders, originalHeaders) + } + + if originalPath != "" { + *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeyOriginalPath, originalPath) + } + } + if isStreaming { g.handleStreamingRequest(ctx, config, bifrostReq, bifrostCtx) } else { @@ -342,6 +427,11 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf return } + var headers http.Header + if h, ok := result.ExtraFields.RawHeaders.(http.Header); ok { + headers = h + } + // Convert Bifrost response to integration-specific format and send response, err := config.ResponseConverter(result) if err != nil { @@ -360,7 +450,7 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf } } - g.sendSuccess(ctx, config.ErrorConverter, response) + g.sendSuccess(ctx, config.ErrorConverter, response, headers) } // handleStreamingRequest handles streaming requests using Server-Sent Events (SSE) @@ -466,6 +556,16 @@ func (g *GenericRouter) handleStreaming(ctx *fasthttp.RequestCtx, config RouteCo default: } + if response.RawSSEEvent != nil { + if _, err := w.Write(response.RawSSEEvent); err != nil { + return + } + if err := w.Flush(); err != nil { + return + } + continue + } + // Handle errors if response.BifrostError != nil { var errorResponse interface{} @@ -625,8 +725,21 @@ func (g *GenericRouter) sendError(ctx *fasthttp.RequestCtx, errorConverter Error } // sendSuccess sends a successful response with HTTP 200 status and JSON body. -func (g *GenericRouter) sendSuccess(ctx *fasthttp.RequestCtx, errorConverter ErrorConverter, response interface{}) { +func (g *GenericRouter) sendSuccess(ctx *fasthttp.RequestCtx, errorConverter ErrorConverter, response interface{}, headers http.Header) { ctx.SetStatusCode(fasthttp.StatusOK) + + if headers != nil { + if rawBytes, okBody := response.([]byte); okBody { + for k, vv := range headers { + for _, v := range vv { + ctx.Response.Header.Add(k, v) + } + } + ctx.SetBody(rawBytes) + return + } + } + ctx.SetContentType("application/json") responseBody, err := json.Marshal(response)