diff --git a/core/bifrost.go b/core/bifrost.go index e97474319..fd94a53dc 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -1096,22 +1096,26 @@ func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (cha func (bifrost *Bifrost) shouldTryFallbacks(req *schemas.BifrostRequest, primaryErr *schemas.BifrostError) bool { // If no primary error, we succeeded if primaryErr == nil { + bifrost.logger.Debug("No primary error, we should not try fallbacks") return false } // Handle request cancellation if primaryErr.Error != nil && primaryErr.Error.Type != nil && *primaryErr.Error.Type == schemas.RequestCancelled { + bifrost.logger.Debug("Request cancelled, we should not try fallbacks") return false } // Check if this is a short-circuit error that doesn't allow fallbacks // Note: AllowFallbacks = nil is treated as true (allow fallbacks by default) if primaryErr.AllowFallbacks != nil && !*primaryErr.AllowFallbacks { + bifrost.logger.Debug("AllowFallbacks is false, we should not try fallbacks") return false } // If no fallbacks configured, return primary error if len(req.Fallbacks) == 0 { + bifrost.logger.Debug("No fallbacks configured, we should not try fallbacks") return false } @@ -1201,6 +1205,8 @@ func (bifrost *Bifrost) shouldContinueWithFallbacks(fallback schemas.Fallback, f // If the primary provider fails, it will try each fallback provider in order until one succeeds. // It is the wrapper for all non-streaming public API methods. func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + defer bifrost.releaseBifrostRequest(req) + if err := validateRequest(req); err != nil { err.ExtraFields = schemas.BifrostErrorExtraFields{ Provider: req.Provider, @@ -1215,9 +1221,18 @@ func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostR ctx = bifrost.ctx } + bifrost.logger.Debug(fmt.Sprintf("Primary provider %s with model %s and %d fallbacks", req.Provider, req.Model, len(req.Fallbacks))) + // Try the primary provider first primaryResult, primaryErr := bifrost.tryRequest(req, ctx) + if primaryErr != nil { + bifrost.logger.Debug(fmt.Sprintf("Primary provider %s with model %s returned error: %v", req.Provider, req.Model, primaryErr)) + if len(req.Fallbacks) > 0 { + bifrost.logger.Debug(fmt.Sprintf("Check if we should try %d fallbacks", len(req.Fallbacks))) + } + } + // Check if we should proceed with fallbacks shouldTryFallbacks := bifrost.shouldTryFallbacks(req, primaryErr) if !shouldTryFallbacks { @@ -1226,10 +1241,12 @@ func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostR // Try fallbacks in order for _, fallback := range req.Fallbacks { + bifrost.logger.Debug(fmt.Sprintf("Trying fallback provider %s with model %s", fallback.Provider, fallback.Model)) ctx = context.WithValue(ctx, schemas.BifrostContextKeyFallbackRequestID, uuid.New().String()) fallbackReq := bifrost.prepareFallbackRequest(req, fallback) if fallbackReq == nil { + bifrost.logger.Debug(fmt.Sprintf("Fallback provider %s with model %s is nil", fallback.Provider, fallback.Model)) continue } @@ -1255,6 +1272,8 @@ func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostR // If the primary provider fails, it will try each fallback provider in order until one succeeds. // It is the wrapper for all streaming public API methods. func (bifrost *Bifrost) handleStreamRequest(ctx context.Context, req *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + defer bifrost.releaseBifrostRequest(req) + if err := validateRequest(req); err != nil { err.ExtraFields = schemas.BifrostErrorExtraFields{ Provider: req.Provider, @@ -1867,8 +1886,7 @@ func (bifrost *Bifrost) releaseChannelMessage(msg *ChannelMessage) { bifrost.responseStreamPool.Put(msg.ResponseStream) } - // Reset and return BifrostRequest to pool - bifrost.releaseBifrostRequest(&msg.BifrostRequest) + // Release of Bifrost Request is handled in handle methods as they are required for fallbacks // Clear references and return to pool msg.Response = nil diff --git a/core/schemas/plugin.go b/core/schemas/plugin.go index 93f31917b..8a053a7ad 100644 --- a/core/schemas/plugin.go +++ b/core/schemas/plugin.go @@ -18,11 +18,13 @@ type PluginShortCircuit struct { // PreHooks are executed in the order they are registered. // PostHooks are executed in the reverse order of PreHooks. // -// PreHooks and PostHooks can be used to implement custom logic, such as: -// - Rate limiting -// - Caching -// - Logging -// - Monitoring +// Execution order: +// 1. TransportInterceptor (HTTP transport only, modifies raw headers/body before entering Bifrost core) +// 2. PreHook (executed in registration order) +// 3. Provider call +// 4. PostHook (executed in reverse order of PreHooks) +// +// Common use cases: rate limiting, caching, logging, monitoring, request transformation, governance. // // Plugin error handling: // - No Plugin errors are returned to the caller; they are logged as warnings by the Bifrost instance. @@ -44,6 +46,12 @@ type Plugin interface { // GetName returns the name of the plugin. GetName() string + // TransportInterceptor is called at the HTTP transport layer before requests enter Bifrost core. + // It allows plugins to modify raw HTTP headers and body before transformation into BifrostRequest. + // Only invoked when using HTTP transport (bifrost-http), not when using Bifrost as a Go SDK directly. + // Returns modified headers, modified body, and any error that occurred during interception. + TransportInterceptor(url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) + // PreHook is called before a request is processed by a provider. // It allows plugins to modify the request before it is sent to the provider. // The context parameter can be used to maintain state across plugin calls. diff --git a/plugins/governance/main.go b/plugins/governance/main.go index 3e7872fa6..5d2fdff6d 100644 --- a/plugins/governance/main.go +++ b/plugins/governance/main.go @@ -4,6 +4,10 @@ package governance import ( "context" "fmt" + "math/rand/v2" + "slices" + "sort" + "strings" bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" @@ -28,6 +32,10 @@ type Config struct { IsVkMandatory *bool `json:"is_vk_mandatory"` } +type InMemoryStore interface { + GetConfiguredProviders() map[schemas.ModelProvider]configstore.ProviderConfig +} + // GovernancePlugin implements the main governance plugin with hierarchical budget system type GovernancePlugin struct { ctx context.Context @@ -43,12 +51,54 @@ type GovernancePlugin struct { pricingManager *pricing.PricingManager logger schemas.Logger + // Transport dependencies + inMemoryStore InMemoryStore + isVkMandatory *bool } -// Init creates a new governance plugin with cleanly segregated components -// All governance features are enabled by default with optimized settings -func Init(ctx context.Context, config *Config, logger schemas.Logger, store configstore.ConfigStore, governanceConfig *configstore.GovernanceConfig, pricingManager *pricing.PricingManager) (*GovernancePlugin, error) { +// Init initializes and returns a governance plugin instance. +// +// It wires the core components (store, resolver, tracker), performs a best-effort +// startup reset of expired limits when a persistent `configstore.ConfigStore` is +// provided, and establishes a cancellable plugin context used by background work. +// +// Behavior and defaults: +// - Enables all governance features with optimized defaults. +// - If `store` is nil, the plugin runs in-memory only (no persistence). +// - If `pricingManager` is nil, cost calculation is skipped. +// - `config.IsVkMandatory` controls whether `x-bf-vk` is required in PreHook. +// - `inMemoryStore` is used by TransportInterceptor to validate configured providers +// and build provider-prefixed models; it may be nil. When nil, transport-level +// provider validation/routing is skipped and existing model strings are left +// unchanged. This is safe and recommended when using the plugin directly from +// the Go SDK without the HTTP transport. +// +// Parameters: +// - ctx: base context for the plugin; a child context with cancel is created. +// - config: plugin flags; may be nil. +// - logger: logger used by all subcomponents. +// - store: configuration store used for persistence; may be nil. +// - governanceConfig: initial/seed governance configuration for the store. +// - pricingManager: optional pricing manager to compute request cost. +// - inMemoryStore: provider registry used for routing/validation in transports. +// +// Returns: +// - *GovernancePlugin on success. +// - error if the governance store fails to initialize. +// +// Side effects: +// - Logs warnings when optional dependencies are missing. +// - May perform startup resets via the usage tracker when `store` is non-nil. +func Init( + ctx context.Context, + config *Config, + logger schemas.Logger, + store configstore.ConfigStore, + governanceConfig *configstore.GovernanceConfig, + pricingManager *pricing.PricingManager, + inMemoryStore InMemoryStore, +) (*GovernancePlugin, error) { if store == nil { logger.Warn("governance plugin requires config store to persist data, running in memory only mode") } @@ -56,6 +106,12 @@ func Init(ctx context.Context, config *Config, logger schemas.Logger, store conf logger.Warn("governance plugin requires pricing manager to calculate cost, all cost calculations will be skipped.") } + // Handle nil config - use safe default for IsVkMandatory + var isVkMandatory *bool + if config != nil { + isVkMandatory = config.IsVkMandatory + } + governanceStore, err := NewGovernanceStore(ctx, logger, store, governanceConfig) if err != nil { return nil, fmt.Errorf("failed to initialize governance store: %w", err) @@ -84,7 +140,8 @@ func Init(ctx context.Context, config *Config, logger schemas.Logger, store conf configStore: store, pricingManager: pricingManager, logger: logger, - isVkMandatory: config.IsVkMandatory, + isVkMandatory: isVkMandatory, + inMemoryStore: inMemoryStore, } return plugin, nil @@ -95,6 +152,112 @@ func (p *GovernancePlugin) GetName() string { return PluginName } +// TransportInterceptor intercepts requests before they are processed (governance decision point) +func (p *GovernancePlugin) TransportInterceptor(url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { + var virtualKeyValue string + + for header, value := range headers { + if strings.ToLower(string(header)) == "x-bf-vk" { + virtualKeyValue = string(value) + break + } + } + if virtualKeyValue == "" { + return headers, body, nil + } + + // Check if the request has a model field + modelValue, hasModel := body["model"] + if !hasModel { + return headers, body, nil + } + modelStr, ok := modelValue.(string) + if !ok || modelStr == "" { + return headers, body, nil + } + + // Check if model already has provider prefix (contains "/") + if strings.Contains(modelStr, "/") { + provider, _ := schemas.ParseModelString(modelStr, "") + // Checking valid provider when store is available; if store is nil, + // assume the prefixed model should be left unchanged. + if p.inMemoryStore != nil { + if _, ok := p.inMemoryStore.GetConfiguredProviders()[provider]; ok { + return headers, body, nil + } + } else { + return headers, body, nil + } + } + + virtualKey, ok := p.store.GetVirtualKey(virtualKeyValue) + if !ok || virtualKey == nil || !virtualKey.IsActive { + return headers, body, nil + } + + // Get provider configs for this virtual key + providerConfigs := virtualKey.ProviderConfigs + if len(providerConfigs) == 0 { + // No provider configs, continue without modification + return headers, body, nil + } + allowedProviderConfigs := make([]configstore.TableVirtualKeyProviderConfig, 0) + for _, config := range providerConfigs { + if len(config.AllowedModels) == 0 || slices.Contains(config.AllowedModels, modelStr) { + allowedProviderConfigs = append(allowedProviderConfigs, config) + } + } + if len(allowedProviderConfigs) == 0 { + // No allowed provider configs, continue without modification + return headers, body, nil + } + // Weighted random selection from allowed providers for the main model + totalWeight := 0.0 + for _, config := range allowedProviderConfigs { + totalWeight += config.Weight + } + // Generate random number between 0 and totalWeight + randomValue := rand.Float64() * totalWeight + // Select provider based on weighted random selection + var selectedProvider schemas.ModelProvider + currentWeight := 0.0 + for _, config := range allowedProviderConfigs { + currentWeight += config.Weight + if randomValue <= currentWeight { + selectedProvider = schemas.ModelProvider(config.Provider) + break + } + } + // Fallback: if no provider was selected (shouldn't happen but guard against FP issues) + if selectedProvider == "" && len(allowedProviderConfigs) > 0 { + selectedProvider = schemas.ModelProvider(allowedProviderConfigs[0].Provider) + } + // Update the model field in the request body + body["model"] = string(selectedProvider) + "/" + modelStr + + // Check if fallbacks field is already present + _, hasFallbacks := body["fallbacks"] + if !hasFallbacks && len(allowedProviderConfigs) > 1 { + // Sort allowed provider configs by weight (descending) + sort.Slice(allowedProviderConfigs, func(i, j int) bool { + return allowedProviderConfigs[i].Weight > allowedProviderConfigs[j].Weight + }) + + // Filter out the selected provider and create fallbacks array + fallbacks := make([]string, 0, len(allowedProviderConfigs)-1) + for _, config := range allowedProviderConfigs { + if config.Provider != string(selectedProvider) { + fallbacks = append(fallbacks, string(schemas.ModelProvider(config.Provider))+"/"+modelStr) + } + } + + // Add fallbacks to request body + body["fallbacks"] = fallbacks + } + + return headers, body, nil +} + // PreHook intercepts requests before they are processed (governance decision point) func (p *GovernancePlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { // Extract governance headers and virtual key using utility functions diff --git a/plugins/jsonparser/main.go b/plugins/jsonparser/main.go index a93e3160f..b80b17330 100644 --- a/plugins/jsonparser/main.go +++ b/plugins/jsonparser/main.go @@ -87,6 +87,11 @@ func (p *JsonParserPlugin) GetName() string { return PluginName } +// TransportInterceptor is not used for this plugin +func (p *JsonParserPlugin) TransportInterceptor(url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { + return headers, body, nil +} + // PreHook is not used for this plugin as we only process responses func (p *JsonParserPlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { return req, nil, nil diff --git a/plugins/logging/main.go b/plugins/logging/main.go index e4fddb5ba..6debdd029 100644 --- a/plugins/logging/main.go +++ b/plugins/logging/main.go @@ -211,6 +211,11 @@ func (p *LoggerPlugin) GetName() string { return PluginName } +// TransportInterceptor is not used for this plugin +func (p *LoggerPlugin) TransportInterceptor(url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { + return headers, body, nil +} + // PreHook is called before a request is processed - FULLY ASYNC, NO DATABASE I/O func (p *LoggerPlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { if ctx == nil { diff --git a/plugins/maxim/main.go b/plugins/maxim/main.go index cb991b6c2..c96f1f9f5 100644 --- a/plugins/maxim/main.go +++ b/plugins/maxim/main.go @@ -117,6 +117,11 @@ func (plugin *Plugin) GetName() string { return PluginName } +// TransportInterceptor is not used for this plugin +func (plugin *Plugin) TransportInterceptor(url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { + return headers, body, nil +} + // getEffectiveLogRepoID determines which single log repo ID to use based on priority: // 1. Header log repo ID (if provided) // 2. Default log repo ID from config (if configured) diff --git a/plugins/mocker/main.go b/plugins/mocker/main.go index 57d44b4b3..80c2a73e9 100644 --- a/plugins/mocker/main.go +++ b/plugins/mocker/main.go @@ -478,6 +478,11 @@ func (p *MockerPlugin) GetName() string { return PluginName } +// TransportInterceptor is not used for this plugin +func (p *MockerPlugin) TransportInterceptor(url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { + return headers, body, nil +} + // PreHook intercepts requests and applies mocking rules based on configuration // This is called before the actual provider request and can short-circuit the flow func (p *MockerPlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { diff --git a/plugins/otel/main.go b/plugins/otel/main.go index 52aab0217..9097124fc 100644 --- a/plugins/otel/main.go +++ b/plugins/otel/main.go @@ -110,6 +110,11 @@ func (p *OtelPlugin) GetName() string { return PluginName } +// TransportInterceptor is not used for this plugin +func (p *OtelPlugin) TransportInterceptor(url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { + return headers, body, nil +} + // ValidateConfig function for the OTEL plugin func (p *OtelPlugin) ValidateConfig(config any) (*Config, error) { var otelConfig Config diff --git a/plugins/semanticcache/main.go b/plugins/semanticcache/main.go index bf7ac1909..a5151f51a 100644 --- a/plugins/semanticcache/main.go +++ b/plugins/semanticcache/main.go @@ -335,6 +335,11 @@ func (plugin *Plugin) GetName() string { return PluginName } +// TransportInterceptor is not used for this plugin +func (plugin *Plugin) TransportInterceptor(url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { + return headers, body, nil +} + // PreHook is called before a request is processed by Bifrost. // It performs a two-stage cache lookup: first direct hash matching, then semantic similarity search. // Uses UUID-based keys for entries stored in the VectorStore. diff --git a/plugins/telemetry/main.go b/plugins/telemetry/main.go index 030922aac..828d2655f 100644 --- a/plugins/telemetry/main.go +++ b/plugins/telemetry/main.go @@ -68,6 +68,11 @@ func (p *PrometheusPlugin) GetName() string { return PluginName } +// TransportInterceptor is not used for this plugin +func (p *PrometheusPlugin) TransportInterceptor(url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { + return headers, body, nil +} + // PreHook records the start time of the request in the context. // This time is used later in PostHook to calculate request duration. func (p *PrometheusPlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { diff --git a/transports/bifrost-http/handlers/middlewares.go b/transports/bifrost-http/handlers/middlewares.go index 8236cd5e2..74484b327 100644 --- a/transports/bifrost-http/handlers/middlewares.go +++ b/transports/bifrost-http/handlers/middlewares.go @@ -3,13 +3,7 @@ package handlers import ( "encoding/json" "fmt" - "math/rand" - "slices" - "sort" - "strings" - "github.com/maximhq/bifrost/core/schemas" - "github.com/maximhq/bifrost/framework/configstore" "github.com/maximhq/bifrost/plugins/governance" "github.com/maximhq/bifrost/transports/bifrost-http/lib" "github.com/valyala/fasthttp" @@ -43,153 +37,89 @@ func CorsMiddleware(config *lib.Config) lib.BifrostHTTPMiddleware { } } -// VKProviderRoutingMiddleware routes requests to the appropriate provider based on the virtual key -func VKProviderRoutingMiddleware(config *lib.Config, logger schemas.Logger) lib.BifrostHTTPMiddleware { - isGovernanceEnabled := config.LoadedPlugins[governance.PluginName] +func TransportInterceptorMiddleware(config *lib.Config) lib.BifrostHTTPMiddleware { return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { - if !isGovernanceEnabled { + // Get plugins from config - lock-free read + plugins := config.GetLoadedPlugins() + if len(plugins) == 0 { next(ctx) return } - var virtualKeyValue string - // Extract x-bf-vk header - ctx.Request.Header.All()(func(key, value []byte) bool { - if strings.ToLower(string(key)) == "x-bf-vk" { - virtualKeyValue = string(value) + + // If governance plugin is not loaded, skip interception + hasGovernance := false + for _, p := range plugins { + if p.GetName() == governance.PluginName { + hasGovernance = true + break } - return true - }) - // If no virtual key, continue to next handler - if virtualKeyValue == "" { - next(ctx) - return - } - // Only process POST requests with a body - if string(ctx.Method()) != "POST" { - next(ctx) - return - } - // Get the request body - body := ctx.Request.Body() - if len(body) == 0 { - next(ctx) - return } - // Parse the request body to extract the model field - var requestBody map[string]interface{} - if err := json.Unmarshal(body, &requestBody); err != nil { - // If we can't parse as JSON, continue without modification + if !hasGovernance { next(ctx) return } - // Check if the request has a model field - modelValue, hasModel := requestBody["model"] - if !hasModel { - next(ctx) - return - } - modelStr, ok := modelValue.(string) - if !ok || modelStr == "" { - next(ctx) - return - } - // Check if model already has provider prefix (contains "/") - if strings.Contains(modelStr, "/") { - provider, _ := schemas.ParseModelString(modelStr, "") - // Checking valid provider - if _, ok := config.Providers[provider]; ok { + // Parse headers + headers := make(map[string]string) + originalHeaderNames := make([]string, 0, 16) + ctx.Request.Header.All()(func(key, value []byte) bool { + name := string(key) + headers[name] = string(value) + originalHeaderNames = append(originalHeaderNames, name) + + return true + }) + + // Unmarshal request body + requestBody := make(map[string]any) + bodyBytes := ctx.Request.Body() + if len(bodyBytes) > 0 { + if err := json.Unmarshal(bodyBytes, &requestBody); err != nil { + // If body is not valid JSON, log warning and continue without interception + logger.Warn(fmt.Sprintf("TransportInterceptor: Failed to unmarshal request body: %v", err)) next(ctx) return } } - var virtualKey *configstore.TableVirtualKey - var err error - for _, vk := range config.GovernanceConfig.VirtualKeys { - if vk.Value == virtualKeyValue { - virtualKey = &vk - break - } - } - if virtualKey == nil { - SendError(ctx, fasthttp.StatusBadRequest, "Invalid virtual key", logger) - return - } - if !virtualKey.IsActive { - SendError(ctx, fasthttp.StatusBadRequest, "Virtual key is not active", logger) - next(ctx) - return - } - // Get provider configs for this virtual key - providerConfigs := virtualKey.ProviderConfigs - if len(providerConfigs) == 0 { - // No provider configs, continue without modification - next(ctx) - return - } - allowedProviderConfigs := make([]configstore.TableVirtualKeyProviderConfig, 0) - for _, config := range providerConfigs { - if len(config.AllowedModels) == 0 || slices.Contains(config.AllowedModels, modelStr) { - allowedProviderConfigs = append(allowedProviderConfigs, config) + + // Call TransportInterceptor on all plugins + for _, plugin := range plugins { + modifiedHeaders, modifiedBody, err := plugin.TransportInterceptor(string(ctx.Request.URI().RequestURI()), headers, requestBody) + if err != nil { + logger.Warn(fmt.Sprintf("TransportInterceptor: Plugin '%s' returned error: %v", plugin.GetName(), err)) + // Continue with unmodified headers/body + continue } - } - if len(allowedProviderConfigs) == 0 { - // No allowed provider configs, continue without modification - next(ctx) - return - } - // Weighted random selection from allowed providers for the main model - totalWeight := 0.0 - for _, config := range allowedProviderConfigs { - totalWeight += config.Weight - } - // Generate random number between 0 and totalWeight - randomValue := rand.Float64() * totalWeight - // Select provider based on weighted random selection - var selectedProvider schemas.ModelProvider - currentWeight := 0.0 - for _, config := range allowedProviderConfigs { - currentWeight += config.Weight - if randomValue <= currentWeight { - selectedProvider = schemas.ModelProvider(config.Provider) - break + // Update headers and body with modifications + if modifiedHeaders != nil { + headers = modifiedHeaders } - } - // Fallback: if no provider was selected (shouldn't happen but guard against FP issues) - if selectedProvider == "" && len(allowedProviderConfigs) > 0 { - selectedProvider = schemas.ModelProvider(allowedProviderConfigs[0].Provider) - } - // Update the model field in the request body - requestBody["model"] = string(selectedProvider) + "/" + modelStr - // Check if fallbacks field is already present - _, hasFallbacks := requestBody["fallbacks"] - if !hasFallbacks && len(allowedProviderConfigs) > 1 { - // Sort allowed provider configs by weight (descending) - sort.Slice(allowedProviderConfigs, func(i, j int) bool { - return allowedProviderConfigs[i].Weight > allowedProviderConfigs[j].Weight - }) - - // Filter out the selected provider and create fallbacks array - fallbacks := make([]string, 0, len(allowedProviderConfigs)-1) - for _, config := range allowedProviderConfigs { - if config.Provider != string(selectedProvider) { - fallbacks = append(fallbacks, string(schemas.ModelProvider(config.Provider))+"/"+modelStr) - } + if modifiedBody != nil { + requestBody = modifiedBody } - - // Add fallbacks to request body - requestBody["fallbacks"] = fallbacks } - // Marshal the updated request body back to JSON + // Marshal the body back to JSON updatedBody, err := json.Marshal(requestBody) if err != nil { - SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to marshal updated request body: %v", err), logger) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("TransportInterceptor: Failed to marshal request body: %v", err), logger) return } - // Replace the request body with the updated one ctx.Request.SetBody(updatedBody) + + // Remove headers that were present originally but removed by plugins + for _, name := range originalHeaderNames { + if _, exists := headers[name]; !exists { + ctx.Request.Header.Del(name) + } + } + + // Set modified headers back on the request + for key, value := range headers { + ctx.Request.Header.Set(key, value) + } + next(ctx) } } diff --git a/transports/bifrost-http/handlers/server.go b/transports/bifrost-http/handlers/server.go index a4e82815e..e1111f407 100644 --- a/transports/bifrost-http/handlers/server.go +++ b/transports/bifrost-http/handlers/server.go @@ -16,6 +16,7 @@ import ( "github.com/fasthttp/router" bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" "github.com/maximhq/bifrost/plugins/governance" "github.com/maximhq/bifrost/plugins/logging" "github.com/maximhq/bifrost/plugins/maxim" @@ -155,6 +156,17 @@ func MarshalPluginConfig[T any](source any) (*T, error) { return nil, fmt.Errorf("invalid config type") } +type GovernanceInMemoryStore struct { + config *lib.Config +} + +func (s *GovernanceInMemoryStore) GetConfiguredProviders() map[schemas.ModelProvider]configstore.ProviderConfig { + // Use read lock for thread-safe access - no need to copy on hot path + s.config.Mu.RLock() + defer s.config.Mu.RUnlock() + return s.config.Providers +} + // LoadPlugin loads a plugin by name and returns it as type T. func LoadPlugin[T schemas.Plugin](ctx context.Context, name string, pluginConfig any, bifrostConfig *lib.Config) (T, error) { var zero T @@ -182,7 +194,10 @@ func LoadPlugin[T schemas.Plugin](ctx context.Context, name string, pluginConfig if err != nil { return zero, fmt.Errorf("failed to marshal governance plugin config: %v", err) } - plugin, err := governance.Init(ctx, governanceConfig, logger, bifrostConfig.ConfigStore, bifrostConfig.GovernanceConfig, bifrostConfig.PricingManager) + inMemoryStore := &GovernanceInMemoryStore{ + config: bifrostConfig, + } + plugin, err := governance.Init(ctx, governanceConfig, logger, bifrostConfig.ConfigStore, bifrostConfig.GovernanceConfig, bifrostConfig.PricingManager, inMemoryStore) if err != nil { return zero, err } @@ -238,14 +253,13 @@ func LoadPlugin[T schemas.Plugin](ctx context.Context, name string, pluginConfig func LoadPlugins(ctx context.Context, config *lib.Config) ([]schemas.Plugin, error) { var err error plugins := []schemas.Plugin{} - config.LoadedPlugins = make(map[string]bool) + // Initialize telemetry plugin promPlugin, err := LoadPlugin[*telemetry.PrometheusPlugin](ctx, telemetry.PluginName, nil, config) if err != nil { logger.Error("failed to initialize telemetry plugin: %v", err) } else { plugins = append(plugins, promPlugin) - config.LoadedPlugins[telemetry.PluginName] = true } // Initializing logger plugin var loggingPlugin *logging.LoggerPlugin @@ -256,7 +270,6 @@ func LoadPlugins(ctx context.Context, config *lib.Config) ([]schemas.Plugin, err logger.Error("failed to initialize logging plugin: %v", err) } else { plugins = append(plugins, loggingPlugin) - config.LoadedPlugins[logging.PluginName] = true } } // Initializing governance plugin @@ -270,12 +283,11 @@ func LoadPlugins(ctx context.Context, config *lib.Config) ([]schemas.Plugin, err logger.Error("failed to initialize governance plugin: %s", err.Error()) } else { plugins = append(plugins, governancePlugin) - config.LoadedPlugins[governance.PluginName] = true } } // Currently we support first party plugins only // Eventually same flow will be used for third party plugins - for _, plugin := range config.Plugins { + for _, plugin := range config.PluginConfigs { if !plugin.Enabled { continue } @@ -284,9 +296,12 @@ func LoadPlugins(ctx context.Context, config *lib.Config) ([]schemas.Plugin, err logger.Error("failed to load plugin %s: %v", plugin.Name, err) } else { plugins = append(plugins, pluginInstance) - config.LoadedPlugins[plugin.Name] = true } } + + // Atomically publish the plugin state + config.Plugins.Store(&plugins) + return plugins, nil } @@ -324,7 +339,7 @@ func (s *BifrostHTTPServer) ReloadClientConfigFromConfigStore() error { Account: account, InitialPoolSize: s.Config.ClientConfig.InitialPoolSize, DropExcessRequests: s.Config.ClientConfig.DropExcessRequests, - Plugins: s.Plugins, + Plugins: s.Config.GetLoadedPlugins(), MCPConfig: s.Config.MCPConfig, Logger: logger, }) @@ -332,7 +347,8 @@ func (s *BifrostHTTPServer) ReloadClientConfigFromConfigStore() error { return nil } -// ReloadPlugin reloads a plugin with new instance and updates Bifrost core +// ReloadPlugin reloads a plugin with new instance and updates Bifrost core. +// Uses atomic CompareAndSwap with retry loop to handle concurrent updates safely. func (s *BifrostHTTPServer) ReloadPlugin(ctx context.Context, name string, pluginConfig any) error { logger.Debug("reloading plugin %s", name) newPlugin, err := LoadPlugin[schemas.Plugin](ctx, name, pluginConfig, s.Config) @@ -342,35 +358,70 @@ func (s *BifrostHTTPServer) ReloadPlugin(ctx context.Context, name string, plugi if err := s.Client.ReloadPlugin(newPlugin); err != nil { return err } - for i, existing := range s.Plugins { - if existing.GetName() == name { - s.Plugins[i] = newPlugin - goto updated + + // CAS retry loop (matching bifrost.go pattern) + for { + oldPlugins := s.Config.Plugins.Load() + oldPluginsSlice := []schemas.Plugin{} + if oldPlugins != nil { + oldPluginsSlice = *oldPlugins } + + // Create new slice with replaced/appended plugin + newPlugins := make([]schemas.Plugin, len(oldPluginsSlice)) + copy(newPlugins, oldPluginsSlice) + + found := false + for i, existing := range newPlugins { + if existing.GetName() == name { + newPlugins[i] = newPlugin + found = true + break + } + } + if !found { + newPlugins = append(newPlugins, newPlugin) + } + + // Atomic compare-and-swap + if s.Config.Plugins.CompareAndSwap(oldPlugins, &newPlugins) { + s.Plugins = newPlugins // Keep BifrostHTTPServer.Plugins in sync + return nil + } + // Retry on contention (extremely rare for plugin updates) } - s.Plugins = append(s.Plugins, newPlugin) -updated: - if s.Config != nil && s.Config.LoadedPlugins != nil { - s.Config.LoadedPlugins[name] = true - } - return nil } // RemovePlugin removes a plugin from the server. +// Uses atomic CompareAndSwap with retry loop to handle concurrent updates safely. func (s *BifrostHTTPServer) RemovePlugin(ctx context.Context, name string) error { if err := s.Client.RemovePlugin(name); err != nil { return err } - for i, existing := range s.Plugins { - if existing.GetName() == name { - s.Plugins = append(s.Plugins[:i], s.Plugins[i+1:]...) - break + + // CAS retry loop (matching bifrost.go pattern) + for { + oldPlugins := s.Config.Plugins.Load() + oldPluginsSlice := []schemas.Plugin{} + if oldPlugins != nil { + oldPluginsSlice = *oldPlugins } + + // Create new slice without the removed plugin + newPlugins := make([]schemas.Plugin, 0, len(oldPluginsSlice)) + for _, existing := range oldPluginsSlice { + if existing.GetName() != name { + newPlugins = append(newPlugins, existing) + } + } + + // Atomic compare-and-swap + if s.Config.Plugins.CompareAndSwap(oldPlugins, &newPlugins) { + s.Plugins = newPlugins // Keep BifrostHTTPServer.Plugins in sync + return nil + } + // Retry on contention (extremely rare for plugin updates) } - if s.Config != nil && s.Config.LoadedPlugins != nil { - delete(s.Config.LoadedPlugins, name) - } - return nil } // RegisterRoutes initializes the routes for the Bifrost HTTP server. @@ -522,7 +573,7 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { } // Create fasthttp server instance s.Server = &fasthttp.Server{ - Handler: CorsMiddleware(s.Config)(VKProviderRoutingMiddleware(s.Config, logger)(s.Router.Handler)), + Handler: CorsMiddleware(s.Config)(TransportInterceptorMiddleware(s.Config)(s.Router.Handler)), MaxRequestBodySize: s.Config.ClientConfig.MaxRequestBodySizeMB * 1024 * 1024, } return nil diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index 79b6a6c76..3c10cb9f9 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -11,6 +11,7 @@ import ( "path/filepath" "strings" "sync" + "sync/atomic" "github.com/google/uuid" bifrost "github.com/maximhq/bifrost/core" @@ -112,8 +113,9 @@ func (cd *ConfigData) UnmarshalJSON(data []byte) error { // - Real-time configuration updates via HTTP API // - Automatic database persistence for all changes // - Support for provider-specific key configurations (Azure, Vertex, Bedrock) +// - Lock-free plugin reads via atomic.Pointer for minimal hot-path latency type Config struct { - mu sync.RWMutex + Mu sync.RWMutex // Exported for direct access from handlers (governance plugin) muMCP sync.RWMutex client *bifrost.Bifrost @@ -133,9 +135,11 @@ type Config struct { // Track which keys come from environment variables EnvKeys map[string][]configstore.EnvKeyInfo - // Plugin configs - Plugins []*schemas.PluginConfig - LoadedPlugins map[string]bool + // Plugin configs - atomic for lock-free reads with CAS updates + Plugins atomic.Pointer[[]schemas.Plugin] + + // Plugin configs from config file/database + PluginConfigs []*schemas.PluginConfig // Pricing manager PricingManager *pricing.PricingManager @@ -179,7 +183,7 @@ func LoadConfig(ctx context.Context, configDirPath string) (*Config, error) { EnvKeys: make(map[string][]configstore.EnvKeyInfo), Providers: make(map[schemas.ModelProvider]configstore.ProviderConfig), } - + // Getting absolute path for config file absConfigFilePath, err := filepath.Abs(configFilePath) if err != nil { return nil, fmt.Errorf("failed to get absolute path for config file: %w", err) @@ -330,9 +334,9 @@ func LoadConfig(ctx context.Context, configDirPath string) (*Config, error) { return nil, fmt.Errorf("failed to get plugins: %w", err) } if plugins == nil { - config.Plugins = []*schemas.PluginConfig{} + config.PluginConfigs = []*schemas.PluginConfig{} } else { - config.Plugins = make([]*schemas.PluginConfig, len(plugins)) + config.PluginConfigs = make([]*schemas.PluginConfig, len(plugins)) for i, plugin := range plugins { pluginConfig := &schemas.PluginConfig{ Name: plugin.Name, @@ -344,7 +348,7 @@ func LoadConfig(ctx context.Context, configDirPath string) (*Config, error) { logger.Warn("failed to add provider keys to semantic cache config: %v", err) } } - config.Plugins[i] = pluginConfig + config.PluginConfigs[i] = pluginConfig } } // Loading governance config @@ -687,7 +691,7 @@ func LoadConfig(ctx context.Context, configDirPath string) (*Config, error) { logger.Warn("failed to get plugins from store: %v", err) } if plugins != nil { - config.Plugins = make([]*schemas.PluginConfig, len(plugins)) + config.PluginConfigs = make([]*schemas.PluginConfig, len(plugins)) for i, plugin := range plugins { pluginConfig := &schemas.PluginConfig{ Name: plugin.Name, @@ -699,28 +703,28 @@ func LoadConfig(ctx context.Context, configDirPath string) (*Config, error) { logger.Warn("failed to add provider keys to semantic cache config: %v", err) } } - config.Plugins[i] = pluginConfig + config.PluginConfigs[i] = pluginConfig } } } // If plugins are not present in the store, we will use the config file - if len(config.Plugins) == 0 && len(configData.Plugins) > 0 { + if len(config.PluginConfigs) == 0 && len(configData.Plugins) > 0 { logger.Debug("no plugins found in store, processing from config file") - config.Plugins = configData.Plugins + config.PluginConfigs = configData.Plugins - for i, plugin := range config.Plugins { + for i, plugin := range config.PluginConfigs { if plugin.Name == semanticcache.PluginName { if err := config.AddProviderKeysToSemanticCacheConfig(plugin); err != nil { logger.Warn("failed to add provider keys to semantic cache config: %v", err) } - config.Plugins[i] = plugin + config.PluginConfigs[i] = plugin } } if config.ConfigStore != nil { logger.Debug("updating plugins in store") - for _, plugin := range config.Plugins { + for _, plugin := range config.PluginConfigs { pluginConfigCopy, err := DeepCopy(plugin.Config) if err != nil { logger.Warn("failed to deep copy plugin config, skipping database update: %v", err) @@ -849,8 +853,8 @@ func (s *Config) getRestoredMCPConfig(envVarsByPath map[string]string) *schemas. // // Returns a copy of the configuration to prevent external modifications. func (s *Config) GetProviderConfigRaw(provider schemas.ModelProvider) (*configstore.ProviderConfig, error) { - s.mu.RLock() - defer s.mu.RUnlock() + s.Mu.RLock() + defer s.Mu.RUnlock() config, exists := s.Providers[provider] if !exists { @@ -872,6 +876,34 @@ func (s *Config) ShouldAllowDirectKeys() bool { return s.ClientConfig.AllowDirectKeys } +// GetLoadedPlugins returns the current snapshot of loaded plugins. +// This method is lock-free and safe for concurrent access from hot paths. +// It returns the plugin slice from the atomic pointer, which is safe to iterate +// even if plugins are being updated concurrently. +func (c *Config) GetLoadedPlugins() []schemas.Plugin { + if plugins := c.Plugins.Load(); plugins != nil { + return *plugins + } + return nil +} + +// IsPluginLoaded checks if a plugin with the given name is currently loaded. +// This method is lock-free and safe for concurrent access from hot paths. +// It iterates through the plugin slice (typically 5-10 plugins, ~50ns overhead). +// For small plugin counts, this is faster than maintaining a separate map. +func (c *Config) IsPluginLoaded(name string) bool { + plugins := c.Plugins.Load() + if plugins == nil { + return false + } + for _, p := range *plugins { + if p.GetName() == name { + return true + } + } + return false +} + // GetProviderConfigRedacted retrieves a provider configuration with sensitive values redacted. // This method is intended for external API responses and logging. // @@ -881,8 +913,8 @@ func (s *Config) ShouldAllowDirectKeys() bool { // // Returns a new copy with redacted values that is safe to expose externally. func (s *Config) GetProviderConfigRedacted(provider schemas.ModelProvider) (*configstore.ProviderConfig, error) { - s.mu.RLock() - defer s.mu.RUnlock() + s.Mu.RLock() + defer s.Mu.RUnlock() config, exists := s.Providers[provider] if !exists { @@ -1040,8 +1072,8 @@ func (s *Config) GetProviderConfigRedacted(provider schemas.ModelProvider) (*con // GetAllProviders returns all configured provider names. func (s *Config) GetAllProviders() ([]schemas.ModelProvider, error) { - s.mu.RLock() - defer s.mu.RUnlock() + s.Mu.RLock() + defer s.Mu.RUnlock() providers := make([]schemas.ModelProvider, 0, len(s.Providers)) for provider := range s.Providers { @@ -1060,8 +1092,8 @@ func (s *Config) GetAllProviders() ([]schemas.ModelProvider, error) { // - Stores the processed configuration in memory // - Updates metadata and timestamps func (s *Config) AddProvider(ctx context.Context, provider schemas.ModelProvider, config configstore.ProviderConfig) error { - s.mu.Lock() - defer s.mu.Unlock() + s.Mu.Lock() + defer s.Mu.Unlock() // Check if provider already exists if _, exists := s.Providers[provider]; exists { @@ -1163,8 +1195,8 @@ func (s *Config) AddProvider(ctx context.Context, provider schemas.ModelProvider // - provider: The provider to update // - config: The new configuration func (s *Config) UpdateProviderConfig(ctx context.Context, provider schemas.ModelProvider, config configstore.ProviderConfig) error { - s.mu.Lock() - defer s.mu.Unlock() + s.Mu.Lock() + defer s.Mu.Unlock() // Get existing configuration for validation existingConfig, exists := s.Providers[provider] @@ -1247,8 +1279,8 @@ func (s *Config) UpdateProviderConfig(ctx context.Context, provider schemas.Mode // RemoveProvider removes a provider configuration from memory. func (s *Config) RemoveProvider(ctx context.Context, provider schemas.ModelProvider) error { - s.mu.Lock() - defer s.mu.Unlock() + s.Mu.Lock() + defer s.Mu.Unlock() if _, exists := s.Providers[provider]; !exists { return ErrNotFound @@ -1272,8 +1304,8 @@ func (s *Config) RemoveProvider(ctx context.Context, provider schemas.ModelProvi // GetAllKeys returns the redacted keys func (s *Config) GetAllKeys() ([]configstore.TableKey, error) { - s.mu.RLock() - defer s.mu.RUnlock() + s.Mu.RLock() + defer s.Mu.RUnlock() keys := make([]configstore.TableKey, 0) for providerKey, provider := range s.Providers { diff --git a/transports/bifrost-http/lib/lib.go b/transports/bifrost-http/lib/lib.go index 230ad4b97..4669aca21 100644 --- a/transports/bifrost-http/lib/lib.go +++ b/transports/bifrost-http/lib/lib.go @@ -1,8 +1,12 @@ package lib import ( - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" ) -var logger = bifrost.NewDefaultLogger(schemas.LogLevelInfo) +var logger schemas.Logger + +// SetLogger sets the logger for the application. +func SetLogger(l schemas.Logger) { + logger = l +} diff --git a/transports/bifrost-http/main.go b/transports/bifrost-http/main.go index c99592f0c..e1c576ef9 100644 --- a/transports/bifrost-http/main.go +++ b/transports/bifrost-http/main.go @@ -62,6 +62,7 @@ import ( bifrost "github.com/maximhq/bifrost/core" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/transports/bifrost-http/handlers" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" ) //go:embed all:ui @@ -124,6 +125,7 @@ func init() { logger.SetOutputType(schemas.LoggerOutputType(server.LogOutputStyle)) logger.SetLevel(schemas.LogLevel(server.LogLevel)) // Setting up logger + lib.SetLogger(logger) handlers.SetLogger(logger) } diff --git a/ui/app/observability/views/observabilityView.tsx b/ui/app/observability/views/observabilityView.tsx index 04e7ff901..37468885a 100644 --- a/ui/app/observability/views/observabilityView.tsx +++ b/ui/app/observability/views/observabilityView.tsx @@ -5,13 +5,23 @@ import { Badge } from "@/components/ui/badge"; import { setSelectedPlugin, useAppDispatch, useAppSelector, useGetPluginsQuery } from "@/lib/store"; import { cn } from "@/lib/utils"; import { useQueryState } from "nuqs"; -import { useEffect } from "react"; +import { useEffect, useMemo } from "react"; import DatadogView from "./plugins/datadogView"; import MaximView from "./plugins/maximView"; import NewrelicView from "./plugins/newRelicView"; import OtelView from "./plugins/otelView"; +import Image from "next/image"; +import { useTheme } from "next-themes"; -const supportedPlatforms = [ +type SupportedPlatform = { + id: string; + name: string; + icon: React.ReactNode; + tag?: string; + disabled?: boolean; +}; + +const supportedPlatformsList = (resolvedTheme: string): SupportedPlatform[] => [ { id: "otel", name: "Open Telemetry", @@ -32,12 +42,12 @@ const supportedPlatforms = [ { id: "maxim", name: "Maxim", - icon: , + icon: Maxim, }, { id: "datadog", name: "Datadog", - icon: , + icon: Datadog, disabled: true, }, { @@ -60,6 +70,10 @@ export default function ObservabilityView() { const [selectedPluginId, setSelectedPluginId] = useQueryState("plugin"); const selectedPlugin = useAppSelector((state) => state.plugin.selectedPlugin); + const { resolvedTheme } = useTheme(); + + const supportedPlatforms = useMemo(() => supportedPlatformsList(resolvedTheme || "light"), [resolvedTheme]); + useEffect(() => { if (!plugins || plugins.length === 0) return; if (!selectedPluginId) {