@@ -4,6 +4,10 @@ package governance
44import (
55 "context"
66 "fmt"
7+ "math/rand/v2"
8+ "slices"
9+ "sort"
10+ "strings"
711
812 bifrost "github.com/maximhq/bifrost/core"
913 "github.com/maximhq/bifrost/core/schemas"
@@ -28,6 +32,10 @@ type Config struct {
2832 IsVkMandatory * bool `json:"is_vk_mandatory"`
2933}
3034
35+ type InMemoryStore interface {
36+ GetConfiguredProviders () map [schemas.ModelProvider ]configstore.ProviderConfig
37+ }
38+
3139// GovernancePlugin implements the main governance plugin with hierarchical budget system
3240type GovernancePlugin struct {
3341 ctx context.Context
@@ -43,19 +51,67 @@ type GovernancePlugin struct {
4351 pricingManager * pricing.PricingManager
4452 logger schemas.Logger
4553
54+ // Transport dependencies
55+ inMemoryStore InMemoryStore
56+
4657 isVkMandatory * bool
4758}
4859
49- // Init creates a new governance plugin with cleanly segregated components
50- // All governance features are enabled by default with optimized settings
51- func Init (ctx context.Context , config * Config , logger schemas.Logger , store configstore.ConfigStore , governanceConfig * configstore.GovernanceConfig , pricingManager * pricing.PricingManager ) (* GovernancePlugin , error ) {
60+ // Init initializes and returns a governance plugin instance.
61+ //
62+ // It wires the core components (store, resolver, tracker), performs a best-effort
63+ // startup reset of expired limits when a persistent `configstore.ConfigStore` is
64+ // provided, and establishes a cancellable plugin context used by background work.
65+ //
66+ // Behavior and defaults:
67+ // - Enables all governance features with optimized defaults.
68+ // - If `store` is nil, the plugin runs in-memory only (no persistence).
69+ // - If `pricingManager` is nil, cost calculation is skipped.
70+ // - `config.IsVkMandatory` controls whether `x-bf-vk` is required in PreHook.
71+ // - `inMemoryStore` is used by TransportInterceptor to validate configured providers
72+ // and build provider-prefixed models; it may be nil. When nil, transport-level
73+ // provider validation/routing is skipped and existing model strings are left
74+ // unchanged. This is safe and recommended when using the plugin directly from
75+ // the Go SDK without the HTTP transport.
76+ //
77+ // Parameters:
78+ // - ctx: base context for the plugin; a child context with cancel is created.
79+ // - config: plugin flags; may be nil.
80+ // - logger: logger used by all subcomponents.
81+ // - store: configuration store used for persistence; may be nil.
82+ // - governanceConfig: initial/seed governance configuration for the store.
83+ // - pricingManager: optional pricing manager to compute request cost.
84+ // - inMemoryStore: provider registry used for routing/validation in transports.
85+ //
86+ // Returns:
87+ // - *GovernancePlugin on success.
88+ // - error if the governance store fails to initialize.
89+ //
90+ // Side effects:
91+ // - Logs warnings when optional dependencies are missing.
92+ // - May perform startup resets via the usage tracker when `store` is non-nil.
93+ func Init (
94+ ctx context.Context ,
95+ config * Config ,
96+ logger schemas.Logger ,
97+ store configstore.ConfigStore ,
98+ governanceConfig * configstore.GovernanceConfig ,
99+ pricingManager * pricing.PricingManager ,
100+ inMemoryStore InMemoryStore ,
101+ ) (* GovernancePlugin , error ) {
52102 if store == nil {
53103 logger .Warn ("governance plugin requires config store to persist data, running in memory only mode" )
54104 }
55105 if pricingManager == nil {
56106 logger .Warn ("governance plugin requires pricing manager to calculate cost, all cost calculations will be skipped." )
57107 }
58108
109+ // Handle nil config - use safe default for IsVkMandatory
110+ var isVkMandatory * bool
111+ if config != nil {
112+ isVkMandatory = config .IsVkMandatory
113+ }
114+
59115 governanceStore , err := NewGovernanceStore (ctx , logger , store , governanceConfig )
60116 if err != nil {
61117 return nil , fmt .Errorf ("failed to initialize governance store: %w" , err )
@@ -84,7 +140,8 @@ func Init(ctx context.Context, config *Config, logger schemas.Logger, store conf
84140 configStore : store ,
85141 pricingManager : pricingManager ,
86142 logger : logger ,
87- isVkMandatory : config .IsVkMandatory ,
143+ isVkMandatory : isVkMandatory ,
144+ inMemoryStore : inMemoryStore ,
88145 }
89146
90147 return plugin , nil
@@ -95,6 +152,112 @@ func (p *GovernancePlugin) GetName() string {
95152 return PluginName
96153}
97154
155+ // TransportInterceptor intercepts requests before they are processed (governance decision point)
156+ func (p * GovernancePlugin ) TransportInterceptor (url string , headers map [string ]string , body map [string ]any ) (map [string ]string , map [string ]any , error ) {
157+ var virtualKeyValue string
158+
159+ for header , value := range headers {
160+ if strings .ToLower (string (header )) == "x-bf-vk" {
161+ virtualKeyValue = string (value )
162+ break
163+ }
164+ }
165+ if virtualKeyValue == "" {
166+ return headers , body , nil
167+ }
168+
169+ // Check if the request has a model field
170+ modelValue , hasModel := body ["model" ]
171+ if ! hasModel {
172+ return headers , body , nil
173+ }
174+ modelStr , ok := modelValue .(string )
175+ if ! ok || modelStr == "" {
176+ return headers , body , nil
177+ }
178+
179+ // Check if model already has provider prefix (contains "/")
180+ if strings .Contains (modelStr , "/" ) {
181+ provider , _ := schemas .ParseModelString (modelStr , "" )
182+ // Checking valid provider when store is available; if store is nil,
183+ // assume the prefixed model should be left unchanged.
184+ if p .inMemoryStore != nil {
185+ if _ , ok := p .inMemoryStore .GetConfiguredProviders ()[provider ]; ok {
186+ return headers , body , nil
187+ }
188+ } else {
189+ return headers , body , nil
190+ }
191+ }
192+
193+ virtualKey , ok := p .store .GetVirtualKey (virtualKeyValue )
194+ if ! ok || virtualKey == nil || ! virtualKey .IsActive {
195+ return headers , body , nil
196+ }
197+
198+ // Get provider configs for this virtual key
199+ providerConfigs := virtualKey .ProviderConfigs
200+ if len (providerConfigs ) == 0 {
201+ // No provider configs, continue without modification
202+ return headers , body , nil
203+ }
204+ allowedProviderConfigs := make ([]configstore.TableVirtualKeyProviderConfig , 0 )
205+ for _ , config := range providerConfigs {
206+ if len (config .AllowedModels ) == 0 || slices .Contains (config .AllowedModels , modelStr ) {
207+ allowedProviderConfigs = append (allowedProviderConfigs , config )
208+ }
209+ }
210+ if len (allowedProviderConfigs ) == 0 {
211+ // No allowed provider configs, continue without modification
212+ return headers , body , nil
213+ }
214+ // Weighted random selection from allowed providers for the main model
215+ totalWeight := 0.0
216+ for _ , config := range allowedProviderConfigs {
217+ totalWeight += config .Weight
218+ }
219+ // Generate random number between 0 and totalWeight
220+ randomValue := rand .Float64 () * totalWeight
221+ // Select provider based on weighted random selection
222+ var selectedProvider schemas.ModelProvider
223+ currentWeight := 0.0
224+ for _ , config := range allowedProviderConfigs {
225+ currentWeight += config .Weight
226+ if randomValue <= currentWeight {
227+ selectedProvider = schemas .ModelProvider (config .Provider )
228+ break
229+ }
230+ }
231+ // Fallback: if no provider was selected (shouldn't happen but guard against FP issues)
232+ if selectedProvider == "" && len (allowedProviderConfigs ) > 0 {
233+ selectedProvider = schemas .ModelProvider (allowedProviderConfigs [0 ].Provider )
234+ }
235+ // Update the model field in the request body
236+ body ["model" ] = string (selectedProvider ) + "/" + modelStr
237+
238+ // Check if fallbacks field is already present
239+ _ , hasFallbacks := body ["fallbacks" ]
240+ if ! hasFallbacks && len (allowedProviderConfigs ) > 1 {
241+ // Sort allowed provider configs by weight (descending)
242+ sort .Slice (allowedProviderConfigs , func (i , j int ) bool {
243+ return allowedProviderConfigs [i ].Weight > allowedProviderConfigs [j ].Weight
244+ })
245+
246+ // Filter out the selected provider and create fallbacks array
247+ fallbacks := make ([]string , 0 , len (allowedProviderConfigs )- 1 )
248+ for _ , config := range allowedProviderConfigs {
249+ if config .Provider != string (selectedProvider ) {
250+ fallbacks = append (fallbacks , string (schemas .ModelProvider (config .Provider ))+ "/" + modelStr )
251+ }
252+ }
253+
254+ // Add fallbacks to request body
255+ body ["fallbacks" ] = fallbacks
256+ }
257+
258+ return headers , body , nil
259+ }
260+
98261// PreHook intercepts requests before they are processed (governance decision point)
99262func (p * GovernancePlugin ) PreHook (ctx * context.Context , req * schemas.BifrostRequest ) (* schemas.BifrostRequest , * schemas.PluginShortCircuit , error ) {
100263 // Extract governance headers and virtual key using utility functions
0 commit comments