@@ -33,21 +33,22 @@ type ChannelMessage struct {
3333// It handles request routing, provider management, and response processing.
3434type Bifrost struct {
3535 ctx context.Context
36- account schemas.Account // account interface
37- plugins atomic.Pointer [[]schemas.Plugin ] // list of plugins
38- requestQueues sync.Map // provider request queues (thread-safe)
39- waitGroups sync.Map // wait groups for each provider (thread-safe)
40- providerMutexes sync.Map // mutexes for each provider to prevent concurrent updates (thread-safe)
41- channelMessagePool sync.Pool // Pool for ChannelMessage objects, initial pool size is set in Init
42- responseChannelPool sync.Pool // Pool for response channels, initial pool size is set in Init
43- errorChannelPool sync.Pool // Pool for error channels, initial pool size is set in Init
44- responseStreamPool sync.Pool // Pool for response stream channels, initial pool size is set in Init
45- pluginPipelinePool sync.Pool // Pool for PluginPipeline objects
46- bifrostRequestPool sync.Pool // Pool for BifrostRequest objects
47- logger schemas.Logger // logger instance, default logger is used if not provided
48- mcpManager * MCPManager // MCP integration manager (nil if MCP not configured)
49- dropExcessRequests atomic.Bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead.
50- keySelector schemas.KeySelector // Custom key selector function
36+ account schemas.Account // account interface
37+ plugins atomic.Pointer [[]schemas.Plugin ] // list of plugins
38+ providers atomic.Pointer [[]schemas.Provider ] // list of providers
39+ requestQueues sync.Map // provider request queues (thread-safe)
40+ waitGroups sync.Map // wait groups for each provider (thread-safe)
41+ providerMutexes sync.Map // mutexes for each provider to prevent concurrent updates (thread-safe)
42+ channelMessagePool sync.Pool // Pool for ChannelMessage objects, initial pool size is set in Init
43+ responseChannelPool sync.Pool // Pool for response channels, initial pool size is set in Init
44+ errorChannelPool sync.Pool // Pool for error channels, initial pool size is set in Init
45+ responseStreamPool sync.Pool // Pool for response stream channels, initial pool size is set in Init
46+ pluginPipelinePool sync.Pool // Pool for PluginPipeline objects
47+ bifrostRequestPool sync.Pool // Pool for BifrostRequest objects
48+ logger schemas.Logger // logger instance, default logger is used if not provided
49+ mcpManager * MCPManager // MCP integration manager (nil if MCP not configured)
50+ dropExcessRequests atomic.Bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead.
51+ keySelector schemas.KeySelector // Custom key selector function
5152}
5253
5354// PluginPipeline encapsulates the execution of plugin PreHooks and PostHooks, tracks how many plugins ran, and manages short-circuiting and error aggregation.
@@ -91,6 +92,10 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) {
9192 keySelector : config .KeySelector ,
9293 }
9394 bifrost .plugins .Store (& config .Plugins )
95+
96+ // Initialize providers slice
97+ bifrost .providers .Store (& []schemas.Provider {})
98+
9499 bifrost .dropExcessRequests .Store (config .DropExcessRequests )
95100
96101 if bifrost .keySelector == nil {
@@ -203,6 +208,157 @@ func (bifrost *Bifrost) ReloadConfig(config schemas.BifrostConfig) error {
203208
204209// PUBLIC API METHODS
205210
211+ // ListModelsRequest sends a list models request to the specified provider.
212+ func (bifrost * Bifrost ) ListModelsRequest (ctx context.Context , req * schemas.BifrostListModelsRequest ) (* schemas.BifrostListModelsResponse , * schemas.BifrostError ) {
213+ if req == nil {
214+ return nil , & schemas.BifrostError {
215+ IsBifrostError : false ,
216+ Error : & schemas.ErrorField {
217+ Message : "list models request is nil" ,
218+ },
219+ }
220+ }
221+ if req .Provider == "" {
222+ return nil , & schemas.BifrostError {
223+ IsBifrostError : false ,
224+ Error : & schemas.ErrorField {
225+ Message : "provider is required for list models request" ,
226+ },
227+ }
228+ }
229+
230+ request := & schemas.BifrostListModelsRequest {
231+ Provider : req .Provider ,
232+ PageSize : req .PageSize ,
233+ PageToken : req .PageToken ,
234+ ExtraParams : req .ExtraParams ,
235+ }
236+
237+ provider := bifrost .getProviderByKey (req .Provider )
238+ if provider == nil {
239+ return nil , & schemas.BifrostError {
240+ IsBifrostError : false ,
241+ Error : & schemas.ErrorField {
242+ Message : "provider not found for list models request" ,
243+ },
244+ }
245+ }
246+
247+ // Determine the base provider type for key requirement checks
248+ baseProvider := req .Provider
249+ providerConfig , err := bifrost .account .GetConfigForProvider (req .Provider )
250+ if err == nil && providerConfig .CustomProviderConfig != nil && providerConfig .CustomProviderConfig .BaseProviderType != "" {
251+ baseProvider = providerConfig .CustomProviderConfig .BaseProviderType
252+ }
253+
254+ // Get API key for the provider if required
255+ key := schemas.Key {}
256+ if providerRequiresKey (baseProvider ) {
257+ key , err = bifrost .selectKeyFromProviderForModel (& ctx , req .Provider , "" , baseProvider , schemas .ListModelsRequest )
258+ if err != nil {
259+ return nil , & schemas.BifrostError {
260+ IsBifrostError : false ,
261+ Error : & schemas.ErrorField {
262+ Message : err .Error (),
263+ Error : err ,
264+ },
265+ }
266+ }
267+ }
268+
269+ response , bifrostErr := provider .ListModels (ctx , key , request )
270+ if bifrostErr != nil {
271+ return nil , bifrostErr
272+ }
273+ return response , nil
274+ }
275+
276+ // ListAllModels lists all models from all configured providers.
277+ // It accumulates responses from all providers with a limit of 1000 per provider to get all results.
278+ func (bifrost * Bifrost ) ListAllModels (ctx context.Context , request * schemas.BifrostListModelsRequest ) (* schemas.BifrostListModelsResponse , * schemas.BifrostError ) {
279+ startTime := time .Now ()
280+
281+ providerKeys , err := bifrost .account .GetConfiguredProviders ()
282+ if err != nil {
283+ return nil , & schemas.BifrostError {
284+ IsBifrostError : false ,
285+ Error : & schemas.ErrorField {
286+ Message : "failed to get configured providers" ,
287+ Error : err ,
288+ },
289+ }
290+ }
291+
292+ // Accumulate all models from all providers
293+ allModels := make ([]schemas.Model , 0 )
294+ var firstError * schemas.BifrostError
295+
296+ for _ , providerKey := range providerKeys {
297+ if strings .TrimSpace (string (providerKey )) == "" {
298+ continue
299+ }
300+
301+ // Create request for this provider with limit of 1000
302+ providerRequest := & schemas.BifrostListModelsRequest {
303+ Provider : providerKey ,
304+ PageSize : 1000 ,
305+ }
306+
307+ // Get all pages for this provider
308+ for {
309+ response , bifrostErr := bifrost .ListModelsRequest (ctx , providerRequest )
310+ if bifrostErr != nil {
311+ // Log the error but continue with other providers
312+ bifrost .logger .Warn (fmt .Sprintf ("failed to list models for provider %s: %v" , providerKey , bifrostErr .Error .Message ))
313+ if firstError == nil {
314+ firstError = bifrostErr
315+ }
316+ break
317+ }
318+
319+ if response != nil {
320+ if len (response .Data ) > 0 {
321+ allModels = append (allModels , response .Data ... )
322+ }
323+ }
324+
325+ // Check if there are more pages
326+ if response .NextPageToken == "" {
327+ break
328+ }
329+
330+ // Set the page token for the next request
331+ providerRequest .PageToken = response .NextPageToken
332+ }
333+ }
334+
335+ // If we couldn't get any models from any provider, return the first error
336+ if len (allModels ) == 0 && firstError != nil {
337+ return nil , firstError
338+ }
339+
340+ // Sort models alphabetically by ID
341+ sort .Slice (allModels , func (i , j int ) bool {
342+ return allModels [i ].ID < allModels [j ].ID
343+ })
344+
345+ // Calculate total elapsed time
346+ elapsedTime := time .Since (startTime ).Milliseconds ()
347+
348+ // Return aggregated response with accumulated latency
349+ response := & schemas.BifrostListModelsResponse {
350+ Data : allModels ,
351+ ExtraFields : schemas.BifrostResponseExtraFields {
352+ RequestType : schemas .ListModelsRequest ,
353+ Latency : elapsedTime ,
354+ },
355+ }
356+
357+ response = schemas .ApplyPagination (response , request .PageSize , request .PageToken )
358+
359+ return response , nil
360+ }
361+
206362// TextCompletionRequest sends a text completion request to the specified provider.
207363func (bifrost * Bifrost ) TextCompletionRequest (ctx context.Context , req * schemas.BifrostTextCompletionRequest ) (* schemas.BifrostTextCompletionResponse , * schemas.BifrostError ) {
208364 if req == nil {
@@ -1039,6 +1195,16 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi
10391195
10401196 waitGroupValue , _ := bifrost .waitGroups .Load (providerKey )
10411197 currentWaitGroup := waitGroupValue .(* sync.WaitGroup )
1198+ // Add provider to list of providers
1199+ loadedProviders := bifrost .providers .Load ()
1200+ if loadedProviders == nil {
1201+ // Initialize if somehow nil
1202+ emptyProviders := make ([]schemas.Provider , 0 )
1203+ loadedProviders = & emptyProviders
1204+ }
1205+ currentProviders := * loadedProviders
1206+ updatedProviders := append (currentProviders , provider )
1207+ bifrost .providers .Store (& updatedProviders )
10421208
10431209 for range providerConfig .ConcurrencyAndBufferSize .Concurrency {
10441210 currentWaitGroup .Add (1 )
@@ -1092,6 +1258,23 @@ func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (cha
10921258 return queue , nil
10931259}
10941260
1261+ // getProviderByKey retrieves a provider instance from the providers array by its provider key.
1262+ // Returns the provider if found, or nil if no provider with the given key exists.
1263+ func (bifrost * Bifrost ) getProviderByKey (providerKey schemas.ModelProvider ) schemas.Provider {
1264+ providers := bifrost .providers .Load ()
1265+ if providers == nil {
1266+ return nil
1267+ }
1268+
1269+ for _ , provider := range * providers {
1270+ if provider .GetProviderKey () == providerKey {
1271+ return provider
1272+ }
1273+ }
1274+
1275+ return nil
1276+ }
1277+
10951278// CORE INTERNAL LOGIC
10961279
10971280// shouldTryFallbacks handles the primary error and returns true if we should proceed with fallbacks, false if we should return immediately
@@ -1585,7 +1768,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
15851768 key := schemas.Key {}
15861769 if providerRequiresKey (baseProvider ) {
15871770 // Use the custom provider name for actual key selection, but pass base provider type for key validation
1588- key , err = bifrost .selectKeyFromProviderForModel (& req .Context , provider .GetProviderKey (), model , baseProvider )
1771+ key , err = bifrost .selectKeyFromProviderForModel (& req .Context , provider .GetProviderKey (), model , baseProvider , req . RequestType )
15891772 if err != nil {
15901773 bifrost .logger .Warn ("error selecting key for model %s: %v" , model , err )
15911774 req .Err <- schemas.BifrostError {
@@ -1950,7 +2133,7 @@ func (bifrost *Bifrost) releaseBifrostRequest(req *schemas.BifrostRequest) {
19502133
19512134// selectKeyFromProviderForModel selects an appropriate API key for a given provider and model.
19522135// It uses weighted random selection if multiple keys are available.
1953- func (bifrost * Bifrost ) selectKeyFromProviderForModel (ctx * context.Context , providerKey schemas.ModelProvider , model string , baseProviderType schemas.ModelProvider ) (schemas.Key , error ) {
2136+ func (bifrost * Bifrost ) selectKeyFromProviderForModel (ctx * context.Context , providerKey schemas.ModelProvider , model string , baseProviderType schemas.ModelProvider , requestType schemas. RequestType ) (schemas.Key , error ) {
19542137 // Check if key has been set in the context explicitly
19552138 if ctx != nil {
19562139 key , ok := (* ctx ).Value (schemas .BifrostContextKeyDirectKey ).(schemas.Key )
@@ -1970,28 +2153,31 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, prov
19702153
19712154 // filter out keys which dont support the model, if the key has no models, it is supported for all models
19722155 var supportedKeys []schemas.Key
1973- for _ , key := range keys {
1974- modelSupported := (slices .Contains (key .Models , model ) && (strings .TrimSpace (key .Value ) != "" || canProviderKeyValueBeEmpty (baseProviderType ))) || len (key .Models ) == 0
1975-
1976- // Additional deployment checks for Azure and Bedrock
1977- deploymentSupported := true
1978- if baseProviderType == schemas .Azure && key .AzureKeyConfig != nil {
1979- // For Azure, check if deployment exists for this model
1980- if len (key .AzureKeyConfig .Deployments ) > 0 {
1981- _ , deploymentSupported = key .AzureKeyConfig .Deployments [model ]
1982- }
1983- } else if baseProviderType == schemas .Bedrock && key .BedrockKeyConfig != nil {
1984- // For Bedrock, check if deployment exists for this model
1985- if len (key .BedrockKeyConfig .Deployments ) > 0 {
1986- _ , deploymentSupported = key .BedrockKeyConfig .Deployments [model ]
2156+ if requestType == schemas .ListModelsRequest {
2157+ supportedKeys = keys
2158+ } else {
2159+ for _ , key := range keys {
2160+ modelSupported := (slices .Contains (key .Models , model ) && (strings .TrimSpace (key .Value ) != "" || canProviderKeyValueBeEmpty (baseProviderType ))) || len (key .Models ) == 0
2161+
2162+ // Additional deployment checks for Azure and Bedrock
2163+ deploymentSupported := true
2164+ if baseProviderType == schemas .Azure && key .AzureKeyConfig != nil {
2165+ // For Azure, check if deployment exists for this model
2166+ if len (key .AzureKeyConfig .Deployments ) > 0 {
2167+ _ , deploymentSupported = key .AzureKeyConfig .Deployments [model ]
2168+ }
2169+ } else if baseProviderType == schemas .Bedrock && key .BedrockKeyConfig != nil {
2170+ // For Bedrock, check if deployment exists for this model
2171+ if len (key .BedrockKeyConfig .Deployments ) > 0 {
2172+ _ , deploymentSupported = key .BedrockKeyConfig .Deployments [model ]
2173+ }
19872174 }
1988- }
19892175
1990- if modelSupported && deploymentSupported {
1991- supportedKeys = append (supportedKeys , key )
2176+ if modelSupported && deploymentSupported {
2177+ supportedKeys = append (supportedKeys , key )
2178+ }
19922179 }
19932180 }
1994-
19952181 if len (supportedKeys ) == 0 {
19962182 if baseProviderType == schemas .Azure || baseProviderType == schemas .Bedrock {
19972183 return schemas.Key {}, fmt .Errorf ("no keys found that support model/deployment: %s" , model )
0 commit comments