Skip to content

Commit c67b125

Browse files
feat: transport interceptor method added to plugin schema
1 parent f461587 commit c67b125

File tree

16 files changed

+459
-202
lines changed

16 files changed

+459
-202
lines changed

core/bifrost.go

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,22 +1096,26 @@ func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (cha
10961096
func (bifrost *Bifrost) shouldTryFallbacks(req *schemas.BifrostRequest, primaryErr *schemas.BifrostError) bool {
10971097
// If no primary error, we succeeded
10981098
if primaryErr == nil {
1099+
bifrost.logger.Debug("No primary error, we should not try fallbacks")
10991100
return false
11001101
}
11011102

11021103
// Handle request cancellation
11031104
if primaryErr.Error != nil && primaryErr.Error.Type != nil && *primaryErr.Error.Type == schemas.RequestCancelled {
1105+
bifrost.logger.Debug("Request cancelled, we should not try fallbacks")
11041106
return false
11051107
}
11061108

11071109
// Check if this is a short-circuit error that doesn't allow fallbacks
11081110
// Note: AllowFallbacks = nil is treated as true (allow fallbacks by default)
11091111
if primaryErr.AllowFallbacks != nil && !*primaryErr.AllowFallbacks {
1112+
bifrost.logger.Debug("AllowFallbacks is false, we should not try fallbacks")
11101113
return false
11111114
}
11121115

11131116
// If no fallbacks configured, return primary error
11141117
if len(req.Fallbacks) == 0 {
1118+
bifrost.logger.Debug("No fallbacks configured, we should not try fallbacks")
11151119
return false
11161120
}
11171121

@@ -1201,6 +1205,8 @@ func (bifrost *Bifrost) shouldContinueWithFallbacks(fallback schemas.Fallback, f
12011205
// If the primary provider fails, it will try each fallback provider in order until one succeeds.
12021206
// It is the wrapper for all non-streaming public API methods.
12031207
func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) {
1208+
defer bifrost.releaseBifrostRequest(req)
1209+
12041210
if err := validateRequest(req); err != nil {
12051211
err.ExtraFields = schemas.BifrostErrorExtraFields{
12061212
Provider: req.Provider,
@@ -1215,9 +1221,18 @@ func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostR
12151221
ctx = bifrost.ctx
12161222
}
12171223

1224+
bifrost.logger.Debug(fmt.Sprintf("Primary provider %s with model %s and %d fallbacks", req.Provider, req.Model, len(req.Fallbacks)))
1225+
12181226
// Try the primary provider first
12191227
primaryResult, primaryErr := bifrost.tryRequest(req, ctx)
12201228

1229+
if primaryErr != nil {
1230+
bifrost.logger.Debug(fmt.Sprintf("Primary provider %s with model %s returned error: %v", req.Provider, req.Model, primaryErr))
1231+
if len(req.Fallbacks) > 0 {
1232+
bifrost.logger.Debug(fmt.Sprintf("Check if we should try %d fallbacks", len(req.Fallbacks)))
1233+
}
1234+
}
1235+
12211236
// Check if we should proceed with fallbacks
12221237
shouldTryFallbacks := bifrost.shouldTryFallbacks(req, primaryErr)
12231238
if !shouldTryFallbacks {
@@ -1226,10 +1241,12 @@ func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostR
12261241

12271242
// Try fallbacks in order
12281243
for _, fallback := range req.Fallbacks {
1244+
bifrost.logger.Debug(fmt.Sprintf("Trying fallback provider %s with model %s", fallback.Provider, fallback.Model))
12291245
ctx = context.WithValue(ctx, schemas.BifrostContextKeyFallbackRequestID, uuid.New().String())
12301246

12311247
fallbackReq := bifrost.prepareFallbackRequest(req, fallback)
12321248
if fallbackReq == nil {
1249+
bifrost.logger.Debug(fmt.Sprintf("Fallback provider %s with model %s is nil", fallback.Provider, fallback.Model))
12331250
continue
12341251
}
12351252

@@ -1255,6 +1272,8 @@ func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostR
12551272
// If the primary provider fails, it will try each fallback provider in order until one succeeds.
12561273
// It is the wrapper for all streaming public API methods.
12571274
func (bifrost *Bifrost) handleStreamRequest(ctx context.Context, req *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
1275+
defer bifrost.releaseBifrostRequest(req)
1276+
12581277
if err := validateRequest(req); err != nil {
12591278
err.ExtraFields = schemas.BifrostErrorExtraFields{
12601279
Provider: req.Provider,
@@ -1867,8 +1886,7 @@ func (bifrost *Bifrost) releaseChannelMessage(msg *ChannelMessage) {
18671886
bifrost.responseStreamPool.Put(msg.ResponseStream)
18681887
}
18691888

1870-
// Reset and return BifrostRequest to pool
1871-
bifrost.releaseBifrostRequest(&msg.BifrostRequest)
1889+
// Release of Bifrost Request is handled in handle methods as they are required for fallbacks
18721890

18731891
// Clear references and return to pool
18741892
msg.Response = nil

core/schemas/plugin.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@ type PluginShortCircuit struct {
1818
// PreHooks are executed in the order they are registered.
1919
// PostHooks are executed in the reverse order of PreHooks.
2020
//
21-
// PreHooks and PostHooks can be used to implement custom logic, such as:
22-
// - Rate limiting
23-
// - Caching
24-
// - Logging
25-
// - Monitoring
21+
// Execution order:
22+
// 1. TransportInterceptor (HTTP transport only, modifies raw headers/body before entering Bifrost core)
23+
// 2. PreHook (executed in registration order)
24+
// 3. Provider call
25+
// 4. PostHook (executed in reverse order of PreHooks)
26+
//
27+
// Common use cases: rate limiting, caching, logging, monitoring, request transformation, governance.
2628
//
2729
// Plugin error handling:
2830
// - No Plugin errors are returned to the caller; they are logged as warnings by the Bifrost instance.
@@ -44,6 +46,12 @@ type Plugin interface {
4446
// GetName returns the name of the plugin.
4547
GetName() string
4648

49+
// TransportInterceptor is called at the HTTP transport layer before requests enter Bifrost core.
50+
// It allows plugins to modify raw HTTP headers and body before transformation into BifrostRequest.
51+
// Only invoked when using HTTP transport (bifrost-http), not when using Bifrost as a Go SDK directly.
52+
// Returns modified headers, modified body, and any error that occurred during interception.
53+
TransportInterceptor(url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error)
54+
4755
// PreHook is called before a request is processed by a provider.
4856
// It allows plugins to modify the request before it is sent to the provider.
4957
// The context parameter can be used to maintain state across plugin calls.

plugins/governance/main.go

Lines changed: 167 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ package governance
44
import (
55
"context"
66
"fmt"
7+
"math/rand/v2"
8+
"slices"
9+
"sort"
10+
"strings"
711

812
bifrost "github.com/maximhq/bifrost/core"
913
"github.com/maximhq/bifrost/core/schemas"
@@ -28,6 +32,10 @@ type Config struct {
2832
IsVkMandatory *bool `json:"is_vk_mandatory"`
2933
}
3034

35+
type InMemoryStore interface {
36+
GetConfiguredProviders() map[schemas.ModelProvider]configstore.ProviderConfig
37+
}
38+
3139
// GovernancePlugin implements the main governance plugin with hierarchical budget system
3240
type GovernancePlugin struct {
3341
ctx context.Context
@@ -43,19 +51,67 @@ type GovernancePlugin struct {
4351
pricingManager *pricing.PricingManager
4452
logger schemas.Logger
4553

54+
// Transport dependencies
55+
inMemoryStore InMemoryStore
56+
4657
isVkMandatory *bool
4758
}
4859

49-
// Init creates a new governance plugin with cleanly segregated components
50-
// All governance features are enabled by default with optimized settings
51-
func Init(ctx context.Context, config *Config, logger schemas.Logger, store configstore.ConfigStore, governanceConfig *configstore.GovernanceConfig, pricingManager *pricing.PricingManager) (*GovernancePlugin, error) {
60+
// Init initializes and returns a governance plugin instance.
61+
//
62+
// It wires the core components (store, resolver, tracker), performs a best-effort
63+
// startup reset of expired limits when a persistent `configstore.ConfigStore` is
64+
// provided, and establishes a cancellable plugin context used by background work.
65+
//
66+
// Behavior and defaults:
67+
// - Enables all governance features with optimized defaults.
68+
// - If `store` is nil, the plugin runs in-memory only (no persistence).
69+
// - If `pricingManager` is nil, cost calculation is skipped.
70+
// - `config.IsVkMandatory` controls whether `x-bf-vk` is required in PreHook.
71+
// - `inMemoryStore` is used by TransportInterceptor to validate configured providers
72+
// and build provider-prefixed models; it may be nil. When nil, transport-level
73+
// provider validation/routing is skipped and existing model strings are left
74+
// unchanged. This is safe and recommended when using the plugin directly from
75+
// the Go SDK without the HTTP transport.
76+
//
77+
// Parameters:
78+
// - ctx: base context for the plugin; a child context with cancel is created.
79+
// - config: plugin flags; may be nil.
80+
// - logger: logger used by all subcomponents.
81+
// - store: configuration store used for persistence; may be nil.
82+
// - governanceConfig: initial/seed governance configuration for the store.
83+
// - pricingManager: optional pricing manager to compute request cost.
84+
// - inMemoryStore: provider registry used for routing/validation in transports.
85+
//
86+
// Returns:
87+
// - *GovernancePlugin on success.
88+
// - error if the governance store fails to initialize.
89+
//
90+
// Side effects:
91+
// - Logs warnings when optional dependencies are missing.
92+
// - May perform startup resets via the usage tracker when `store` is non-nil.
93+
func Init(
94+
ctx context.Context,
95+
config *Config,
96+
logger schemas.Logger,
97+
store configstore.ConfigStore,
98+
governanceConfig *configstore.GovernanceConfig,
99+
pricingManager *pricing.PricingManager,
100+
inMemoryStore InMemoryStore,
101+
) (*GovernancePlugin, error) {
52102
if store == nil {
53103
logger.Warn("governance plugin requires config store to persist data, running in memory only mode")
54104
}
55105
if pricingManager == nil {
56106
logger.Warn("governance plugin requires pricing manager to calculate cost, all cost calculations will be skipped.")
57107
}
58108

109+
// Handle nil config - use safe default for IsVkMandatory
110+
var isVkMandatory *bool
111+
if config != nil {
112+
isVkMandatory = config.IsVkMandatory
113+
}
114+
59115
governanceStore, err := NewGovernanceStore(ctx, logger, store, governanceConfig)
60116
if err != nil {
61117
return nil, fmt.Errorf("failed to initialize governance store: %w", err)
@@ -84,7 +140,8 @@ func Init(ctx context.Context, config *Config, logger schemas.Logger, store conf
84140
configStore: store,
85141
pricingManager: pricingManager,
86142
logger: logger,
87-
isVkMandatory: config.IsVkMandatory,
143+
isVkMandatory: isVkMandatory,
144+
inMemoryStore: inMemoryStore,
88145
}
89146

90147
return plugin, nil
@@ -95,6 +152,112 @@ func (p *GovernancePlugin) GetName() string {
95152
return PluginName
96153
}
97154

155+
// TransportInterceptor intercepts requests before they are processed (governance decision point)
156+
func (p *GovernancePlugin) TransportInterceptor(url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) {
157+
var virtualKeyValue string
158+
159+
for header, value := range headers {
160+
if strings.ToLower(string(header)) == "x-bf-vk" {
161+
virtualKeyValue = string(value)
162+
break
163+
}
164+
}
165+
if virtualKeyValue == "" {
166+
return headers, body, nil
167+
}
168+
169+
// Check if the request has a model field
170+
modelValue, hasModel := body["model"]
171+
if !hasModel {
172+
return headers, body, nil
173+
}
174+
modelStr, ok := modelValue.(string)
175+
if !ok || modelStr == "" {
176+
return headers, body, nil
177+
}
178+
179+
// Check if model already has provider prefix (contains "/")
180+
if strings.Contains(modelStr, "/") {
181+
provider, _ := schemas.ParseModelString(modelStr, "")
182+
// Checking valid provider when store is available; if store is nil,
183+
// assume the prefixed model should be left unchanged.
184+
if p.inMemoryStore != nil {
185+
if _, ok := p.inMemoryStore.GetConfiguredProviders()[provider]; ok {
186+
return headers, body, nil
187+
}
188+
} else {
189+
return headers, body, nil
190+
}
191+
}
192+
193+
virtualKey, ok := p.store.GetVirtualKey(virtualKeyValue)
194+
if !ok || virtualKey == nil || !virtualKey.IsActive {
195+
return headers, body, nil
196+
}
197+
198+
// Get provider configs for this virtual key
199+
providerConfigs := virtualKey.ProviderConfigs
200+
if len(providerConfigs) == 0 {
201+
// No provider configs, continue without modification
202+
return headers, body, nil
203+
}
204+
allowedProviderConfigs := make([]configstore.TableVirtualKeyProviderConfig, 0)
205+
for _, config := range providerConfigs {
206+
if len(config.AllowedModels) == 0 || slices.Contains(config.AllowedModels, modelStr) {
207+
allowedProviderConfigs = append(allowedProviderConfigs, config)
208+
}
209+
}
210+
if len(allowedProviderConfigs) == 0 {
211+
// No allowed provider configs, continue without modification
212+
return headers, body, nil
213+
}
214+
// Weighted random selection from allowed providers for the main model
215+
totalWeight := 0.0
216+
for _, config := range allowedProviderConfigs {
217+
totalWeight += config.Weight
218+
}
219+
// Generate random number between 0 and totalWeight
220+
randomValue := rand.Float64() * totalWeight
221+
// Select provider based on weighted random selection
222+
var selectedProvider schemas.ModelProvider
223+
currentWeight := 0.0
224+
for _, config := range allowedProviderConfigs {
225+
currentWeight += config.Weight
226+
if randomValue <= currentWeight {
227+
selectedProvider = schemas.ModelProvider(config.Provider)
228+
break
229+
}
230+
}
231+
// Fallback: if no provider was selected (shouldn't happen but guard against FP issues)
232+
if selectedProvider == "" && len(allowedProviderConfigs) > 0 {
233+
selectedProvider = schemas.ModelProvider(allowedProviderConfigs[0].Provider)
234+
}
235+
// Update the model field in the request body
236+
body["model"] = string(selectedProvider) + "/" + modelStr
237+
238+
// Check if fallbacks field is already present
239+
_, hasFallbacks := body["fallbacks"]
240+
if !hasFallbacks && len(allowedProviderConfigs) > 1 {
241+
// Sort allowed provider configs by weight (descending)
242+
sort.Slice(allowedProviderConfigs, func(i, j int) bool {
243+
return allowedProviderConfigs[i].Weight > allowedProviderConfigs[j].Weight
244+
})
245+
246+
// Filter out the selected provider and create fallbacks array
247+
fallbacks := make([]string, 0, len(allowedProviderConfigs)-1)
248+
for _, config := range allowedProviderConfigs {
249+
if config.Provider != string(selectedProvider) {
250+
fallbacks = append(fallbacks, string(schemas.ModelProvider(config.Provider))+"/"+modelStr)
251+
}
252+
}
253+
254+
// Add fallbacks to request body
255+
body["fallbacks"] = fallbacks
256+
}
257+
258+
return headers, body, nil
259+
}
260+
98261
// PreHook intercepts requests before they are processed (governance decision point)
99262
func (p *GovernancePlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) {
100263
// Extract governance headers and virtual key using utility functions

plugins/jsonparser/main.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,11 @@ func (p *JsonParserPlugin) GetName() string {
8787
return PluginName
8888
}
8989

90+
// TransportInterceptor is not used for this plugin
91+
func (p *JsonParserPlugin) TransportInterceptor(url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) {
92+
return headers, body, nil
93+
}
94+
9095
// PreHook is not used for this plugin as we only process responses
9196
func (p *JsonParserPlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) {
9297
return req, nil, nil

plugins/logging/main.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,11 @@ func (p *LoggerPlugin) GetName() string {
211211
return PluginName
212212
}
213213

214+
// TransportInterceptor is not used for this plugin
215+
func (p *LoggerPlugin) TransportInterceptor(url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) {
216+
return headers, body, nil
217+
}
218+
214219
// PreHook is called before a request is processed - FULLY ASYNC, NO DATABASE I/O
215220
func (p *LoggerPlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) {
216221
if ctx == nil {

plugins/maxim/main.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,11 @@ func (plugin *Plugin) GetName() string {
117117
return PluginName
118118
}
119119

120+
// TransportInterceptor is not used for this plugin
121+
func (plugin *Plugin) TransportInterceptor(url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) {
122+
return headers, body, nil
123+
}
124+
120125
// getEffectiveLogRepoID determines which single log repo ID to use based on priority:
121126
// 1. Header log repo ID (if provided)
122127
// 2. Default log repo ID from config (if configured)

plugins/mocker/main.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,11 @@ func (p *MockerPlugin) GetName() string {
478478
return PluginName
479479
}
480480

481+
// TransportInterceptor is not used for this plugin
482+
func (p *MockerPlugin) TransportInterceptor(url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) {
483+
return headers, body, nil
484+
}
485+
481486
// PreHook intercepts requests and applies mocking rules based on configuration
482487
// This is called before the actual provider request and can short-circuit the flow
483488
func (p *MockerPlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) {

plugins/otel/main.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,11 @@ func (p *OtelPlugin) GetName() string {
110110
return PluginName
111111
}
112112

113+
// TransportInterceptor is not used for this plugin
114+
func (p *OtelPlugin) TransportInterceptor(url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) {
115+
return headers, body, nil
116+
}
117+
113118
// ValidateConfig function for the OTEL plugin
114119
func (p *OtelPlugin) ValidateConfig(config any) (*Config, error) {
115120
var otelConfig Config

0 commit comments

Comments
 (0)