Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -1096,22 +1096,26 @@ func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (cha
func (bifrost *Bifrost) shouldTryFallbacks(req *schemas.BifrostRequest, primaryErr *schemas.BifrostError) bool {
// If no primary error, we succeeded
if primaryErr == nil {
bifrost.logger.Debug("No primary error, we should not try fallbacks")
return false
}

// Handle request cancellation
if primaryErr.Error != nil && primaryErr.Error.Type != nil && *primaryErr.Error.Type == schemas.RequestCancelled {
bifrost.logger.Debug("Request cancelled, we should not try fallbacks")
return false
}

// Check if this is a short-circuit error that doesn't allow fallbacks
// Note: AllowFallbacks = nil is treated as true (allow fallbacks by default)
if primaryErr.AllowFallbacks != nil && !*primaryErr.AllowFallbacks {
bifrost.logger.Debug("AllowFallbacks is false, we should not try fallbacks")
return false
}

// If no fallbacks configured, return primary error
if len(req.Fallbacks) == 0 {
bifrost.logger.Debug("No fallbacks configured, we should not try fallbacks")
return false
}

Expand Down Expand Up @@ -1201,6 +1205,8 @@ func (bifrost *Bifrost) shouldContinueWithFallbacks(fallback schemas.Fallback, f
// If the primary provider fails, it will try each fallback provider in order until one succeeds.
// It is the wrapper for all non-streaming public API methods.
func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) {
defer bifrost.releaseBifrostRequest(req)

if err := validateRequest(req); err != nil {
err.ExtraFields = schemas.BifrostErrorExtraFields{
Provider: req.Provider,
Expand All @@ -1215,9 +1221,18 @@ func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostR
ctx = bifrost.ctx
}

bifrost.logger.Debug(fmt.Sprintf("Primary provider %s with model %s and %d fallbacks", req.Provider, req.Model, len(req.Fallbacks)))

// Try the primary provider first
primaryResult, primaryErr := bifrost.tryRequest(req, ctx)

if primaryErr != nil {
bifrost.logger.Debug(fmt.Sprintf("Primary provider %s with model %s returned error: %v", req.Provider, req.Model, primaryErr))
if len(req.Fallbacks) > 0 {
bifrost.logger.Debug(fmt.Sprintf("Check if we should try %d fallbacks", len(req.Fallbacks)))
}
}

// Check if we should proceed with fallbacks
shouldTryFallbacks := bifrost.shouldTryFallbacks(req, primaryErr)
if !shouldTryFallbacks {
Expand All @@ -1226,10 +1241,12 @@ func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostR

// Try fallbacks in order
for _, fallback := range req.Fallbacks {
bifrost.logger.Debug(fmt.Sprintf("Trying fallback provider %s with model %s", fallback.Provider, fallback.Model))
ctx = context.WithValue(ctx, schemas.BifrostContextKeyFallbackRequestID, uuid.New().String())

fallbackReq := bifrost.prepareFallbackRequest(req, fallback)
if fallbackReq == nil {
bifrost.logger.Debug(fmt.Sprintf("Fallback provider %s with model %s is nil", fallback.Provider, fallback.Model))
continue
}

Expand All @@ -1255,6 +1272,8 @@ func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostR
// If the primary provider fails, it will try each fallback provider in order until one succeeds.
// It is the wrapper for all streaming public API methods.
func (bifrost *Bifrost) handleStreamRequest(ctx context.Context, req *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
defer bifrost.releaseBifrostRequest(req)

if err := validateRequest(req); err != nil {
err.ExtraFields = schemas.BifrostErrorExtraFields{
Provider: req.Provider,
Expand Down Expand Up @@ -1867,8 +1886,7 @@ func (bifrost *Bifrost) releaseChannelMessage(msg *ChannelMessage) {
bifrost.responseStreamPool.Put(msg.ResponseStream)
}

// Reset and return BifrostRequest to pool
bifrost.releaseBifrostRequest(&msg.BifrostRequest)
// Release of Bifrost Request is handled in handle methods as they are required for fallbacks

// Clear references and return to pool
msg.Response = nil
Expand Down
18 changes: 13 additions & 5 deletions core/schemas/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ type PluginShortCircuit struct {
// PreHooks are executed in the order they are registered.
// PostHooks are executed in the reverse order of PreHooks.
//
// PreHooks and PostHooks can be used to implement custom logic, such as:
// - Rate limiting
// - Caching
// - Logging
// - Monitoring
// Execution order:
// 1. TransportInterceptor (HTTP transport only, modifies raw headers/body before entering Bifrost core)
// 2. PreHook (executed in registration order)
// 3. Provider call
// 4. PostHook (executed in reverse order of PreHooks)
//
// Common use cases: rate limiting, caching, logging, monitoring, request transformation, governance.
//
// Plugin error handling:
// - No Plugin errors are returned to the caller; they are logged as warnings by the Bifrost instance.
Expand All @@ -44,6 +46,12 @@ type Plugin interface {
// GetName returns the name of the plugin.
GetName() string

// TransportInterceptor is called at the HTTP transport layer before requests enter Bifrost core.
// It allows plugins to modify raw HTTP headers and body before transformation into BifrostRequest.
// Only invoked when using HTTP transport (bifrost-http), not when using Bifrost as a Go SDK directly.
// Returns modified headers, modified body, and any error that occurred during interception.
TransportInterceptor(url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error)

// PreHook is called before a request is processed by a provider.
// It allows plugins to modify the request before it is sent to the provider.
// The context parameter can be used to maintain state across plugin calls.
Expand Down
171 changes: 167 additions & 4 deletions plugins/governance/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ package governance
import (
"context"
"fmt"
"math/rand/v2"
"slices"
"sort"
"strings"

bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
Expand All @@ -28,6 +32,10 @@ type Config struct {
IsVkMandatory *bool `json:"is_vk_mandatory"`
}

type InMemoryStore interface {
GetConfiguredProviders() map[schemas.ModelProvider]configstore.ProviderConfig
}

// GovernancePlugin implements the main governance plugin with hierarchical budget system
type GovernancePlugin struct {
ctx context.Context
Expand All @@ -43,19 +51,67 @@ type GovernancePlugin struct {
pricingManager *pricing.PricingManager
logger schemas.Logger

// Transport dependencies
inMemoryStore InMemoryStore

isVkMandatory *bool
}

// Init creates a new governance plugin with cleanly segregated components
// All governance features are enabled by default with optimized settings
func Init(ctx context.Context, config *Config, logger schemas.Logger, store configstore.ConfigStore, governanceConfig *configstore.GovernanceConfig, pricingManager *pricing.PricingManager) (*GovernancePlugin, error) {
// Init initializes and returns a governance plugin instance.
//
// It wires the core components (store, resolver, tracker), performs a best-effort
// startup reset of expired limits when a persistent `configstore.ConfigStore` is
// provided, and establishes a cancellable plugin context used by background work.
//
// Behavior and defaults:
// - Enables all governance features with optimized defaults.
// - If `store` is nil, the plugin runs in-memory only (no persistence).
// - If `pricingManager` is nil, cost calculation is skipped.
// - `config.IsVkMandatory` controls whether `x-bf-vk` is required in PreHook.
// - `inMemoryStore` is used by TransportInterceptor to validate configured providers
// and build provider-prefixed models; it may be nil. When nil, transport-level
// provider validation/routing is skipped and existing model strings are left
// unchanged. This is safe and recommended when using the plugin directly from
// the Go SDK without the HTTP transport.
//
// Parameters:
// - ctx: base context for the plugin; a child context with cancel is created.
// - config: plugin flags; may be nil.
// - logger: logger used by all subcomponents.
// - store: configuration store used for persistence; may be nil.
// - governanceConfig: initial/seed governance configuration for the store.
// - pricingManager: optional pricing manager to compute request cost.
// - inMemoryStore: provider registry used for routing/validation in transports.
//
// Returns:
// - *GovernancePlugin on success.
// - error if the governance store fails to initialize.
//
// Side effects:
// - Logs warnings when optional dependencies are missing.
// - May perform startup resets via the usage tracker when `store` is non-nil.
func Init(
ctx context.Context,
config *Config,
logger schemas.Logger,
store configstore.ConfigStore,
governanceConfig *configstore.GovernanceConfig,
pricingManager *pricing.PricingManager,
inMemoryStore InMemoryStore,
) (*GovernancePlugin, error) {
if store == nil {
logger.Warn("governance plugin requires config store to persist data, running in memory only mode")
}
if pricingManager == nil {
logger.Warn("governance plugin requires pricing manager to calculate cost, all cost calculations will be skipped.")
}

// Handle nil config - use safe default for IsVkMandatory
var isVkMandatory *bool
if config != nil {
isVkMandatory = config.IsVkMandatory
}

governanceStore, err := NewGovernanceStore(ctx, logger, store, governanceConfig)
if err != nil {
return nil, fmt.Errorf("failed to initialize governance store: %w", err)
Expand Down Expand Up @@ -84,7 +140,8 @@ func Init(ctx context.Context, config *Config, logger schemas.Logger, store conf
configStore: store,
pricingManager: pricingManager,
logger: logger,
isVkMandatory: config.IsVkMandatory,
isVkMandatory: isVkMandatory,
inMemoryStore: inMemoryStore,
}

return plugin, nil
Expand All @@ -95,6 +152,112 @@ func (p *GovernancePlugin) GetName() string {
return PluginName
}

// TransportInterceptor intercepts requests before they are processed (governance decision point)
func (p *GovernancePlugin) TransportInterceptor(url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) {
var virtualKeyValue string

for header, value := range headers {
if strings.ToLower(string(header)) == "x-bf-vk" {
virtualKeyValue = string(value)
break
}
}
if virtualKeyValue == "" {
return headers, body, nil
}

// Check if the request has a model field
modelValue, hasModel := body["model"]
if !hasModel {
return headers, body, nil
}
modelStr, ok := modelValue.(string)
if !ok || modelStr == "" {
return headers, body, nil
}

// Check if model already has provider prefix (contains "/")
if strings.Contains(modelStr, "/") {
provider, _ := schemas.ParseModelString(modelStr, "")
// Checking valid provider when store is available; if store is nil,
// assume the prefixed model should be left unchanged.
if p.inMemoryStore != nil {
if _, ok := p.inMemoryStore.GetConfiguredProviders()[provider]; ok {
return headers, body, nil
}
} else {
return headers, body, nil
}
}

virtualKey, ok := p.store.GetVirtualKey(virtualKeyValue)
if !ok || virtualKey == nil || !virtualKey.IsActive {
return headers, body, nil
}

// Get provider configs for this virtual key
providerConfigs := virtualKey.ProviderConfigs
if len(providerConfigs) == 0 {
// No provider configs, continue without modification
return headers, body, nil
}
allowedProviderConfigs := make([]configstore.TableVirtualKeyProviderConfig, 0)
for _, config := range providerConfigs {
if len(config.AllowedModels) == 0 || slices.Contains(config.AllowedModels, modelStr) {
allowedProviderConfigs = append(allowedProviderConfigs, config)
}
}
if len(allowedProviderConfigs) == 0 {
// No allowed provider configs, continue without modification
return headers, body, nil
}
// Weighted random selection from allowed providers for the main model
totalWeight := 0.0
for _, config := range allowedProviderConfigs {
totalWeight += config.Weight
}
// Generate random number between 0 and totalWeight
randomValue := rand.Float64() * totalWeight
// Select provider based on weighted random selection
var selectedProvider schemas.ModelProvider
currentWeight := 0.0
for _, config := range allowedProviderConfigs {
currentWeight += config.Weight
if randomValue <= currentWeight {
selectedProvider = schemas.ModelProvider(config.Provider)
break
}
}
// Fallback: if no provider was selected (shouldn't happen but guard against FP issues)
if selectedProvider == "" && len(allowedProviderConfigs) > 0 {
selectedProvider = schemas.ModelProvider(allowedProviderConfigs[0].Provider)
}
// Update the model field in the request body
body["model"] = string(selectedProvider) + "/" + modelStr

// Check if fallbacks field is already present
_, hasFallbacks := body["fallbacks"]
if !hasFallbacks && len(allowedProviderConfigs) > 1 {
// Sort allowed provider configs by weight (descending)
sort.Slice(allowedProviderConfigs, func(i, j int) bool {
return allowedProviderConfigs[i].Weight > allowedProviderConfigs[j].Weight
})

// Filter out the selected provider and create fallbacks array
fallbacks := make([]string, 0, len(allowedProviderConfigs)-1)
for _, config := range allowedProviderConfigs {
if config.Provider != string(selectedProvider) {
fallbacks = append(fallbacks, string(schemas.ModelProvider(config.Provider))+"/"+modelStr)
}
}

// Add fallbacks to request body
body["fallbacks"] = fallbacks
}

return headers, body, nil
}

// PreHook intercepts requests before they are processed (governance decision point)
func (p *GovernancePlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) {
// Extract governance headers and virtual key using utility functions
Expand Down
5 changes: 5 additions & 0 deletions plugins/jsonparser/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ func (p *JsonParserPlugin) GetName() string {
return PluginName
}

// TransportInterceptor is not used for this plugin
func (p *JsonParserPlugin) TransportInterceptor(url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) {
return headers, body, nil
}

// PreHook is not used for this plugin as we only process responses
func (p *JsonParserPlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) {
return req, nil, nil
Expand Down
5 changes: 5 additions & 0 deletions plugins/logging/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,11 @@ func (p *LoggerPlugin) GetName() string {
return PluginName
}

// TransportInterceptor is not used for this plugin
func (p *LoggerPlugin) TransportInterceptor(url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) {
return headers, body, nil
}

// PreHook is called before a request is processed - FULLY ASYNC, NO DATABASE I/O
func (p *LoggerPlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) {
if ctx == nil {
Expand Down
5 changes: 5 additions & 0 deletions plugins/maxim/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@ func (plugin *Plugin) GetName() string {
return PluginName
}

// TransportInterceptor is not used for this plugin
func (plugin *Plugin) TransportInterceptor(url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) {
return headers, body, nil
}

// getEffectiveLogRepoID determines which single log repo ID to use based on priority:
// 1. Header log repo ID (if provided)
// 2. Default log repo ID from config (if configured)
Expand Down
5 changes: 5 additions & 0 deletions plugins/mocker/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,11 @@ func (p *MockerPlugin) GetName() string {
return PluginName
}

// TransportInterceptor is not used for this plugin
func (p *MockerPlugin) TransportInterceptor(url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) {
return headers, body, nil
}

// PreHook intercepts requests and applies mocking rules based on configuration
// This is called before the actual provider request and can short-circuit the flow
func (p *MockerPlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) {
Expand Down
5 changes: 5 additions & 0 deletions plugins/otel/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ func (p *OtelPlugin) GetName() string {
return PluginName
}

// TransportInterceptor is not used for this plugin
func (p *OtelPlugin) TransportInterceptor(url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) {
return headers, body, nil
}

// ValidateConfig function for the OTEL plugin
func (p *OtelPlugin) ValidateConfig(config any) (*Config, error) {
var otelConfig Config
Expand Down
Loading