Skip to content

Commit 8618bf8

Browse files
committed
feat: added list models request
1 parent f3e74fa commit 8618bf8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+2097
-153
lines changed

core/bifrost.go

Lines changed: 221 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -33,21 +33,22 @@ type ChannelMessage struct {
3333
// It handles request routing, provider management, and response processing.
3434
type Bifrost struct {
3535
ctx context.Context
36-
account schemas.Account // account interface
37-
plugins atomic.Pointer[[]schemas.Plugin] // list of plugins
38-
requestQueues sync.Map // provider request queues (thread-safe)
39-
waitGroups sync.Map // wait groups for each provider (thread-safe)
40-
providerMutexes sync.Map // mutexes for each provider to prevent concurrent updates (thread-safe)
41-
channelMessagePool sync.Pool // Pool for ChannelMessage objects, initial pool size is set in Init
42-
responseChannelPool sync.Pool // Pool for response channels, initial pool size is set in Init
43-
errorChannelPool sync.Pool // Pool for error channels, initial pool size is set in Init
44-
responseStreamPool sync.Pool // Pool for response stream channels, initial pool size is set in Init
45-
pluginPipelinePool sync.Pool // Pool for PluginPipeline objects
46-
bifrostRequestPool sync.Pool // Pool for BifrostRequest objects
47-
logger schemas.Logger // logger instance, default logger is used if not provided
48-
mcpManager *MCPManager // MCP integration manager (nil if MCP not configured)
49-
dropExcessRequests atomic.Bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead.
50-
keySelector schemas.KeySelector // Custom key selector function
36+
account schemas.Account // account interface
37+
plugins atomic.Pointer[[]schemas.Plugin] // list of plugins
38+
providers atomic.Pointer[[]schemas.Provider] // list of providers
39+
requestQueues sync.Map // provider request queues (thread-safe)
40+
waitGroups sync.Map // wait groups for each provider (thread-safe)
41+
providerMutexes sync.Map // mutexes for each provider to prevent concurrent updates (thread-safe)
42+
channelMessagePool sync.Pool // Pool for ChannelMessage objects, initial pool size is set in Init
43+
responseChannelPool sync.Pool // Pool for response channels, initial pool size is set in Init
44+
errorChannelPool sync.Pool // Pool for error channels, initial pool size is set in Init
45+
responseStreamPool sync.Pool // Pool for response stream channels, initial pool size is set in Init
46+
pluginPipelinePool sync.Pool // Pool for PluginPipeline objects
47+
bifrostRequestPool sync.Pool // Pool for BifrostRequest objects
48+
logger schemas.Logger // logger instance, default logger is used if not provided
49+
mcpManager *MCPManager // MCP integration manager (nil if MCP not configured)
50+
dropExcessRequests atomic.Bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead.
51+
keySelector schemas.KeySelector // Custom key selector function
5152
}
5253

5354
// PluginPipeline encapsulates the execution of plugin PreHooks and PostHooks, tracks how many plugins ran, and manages short-circuiting and error aggregation.
@@ -91,6 +92,10 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) {
9192
keySelector: config.KeySelector,
9293
}
9394
bifrost.plugins.Store(&config.Plugins)
95+
96+
// Initialize providers slice
97+
bifrost.providers.Store(&[]schemas.Provider{})
98+
9499
bifrost.dropExcessRequests.Store(config.DropExcessRequests)
95100

96101
if bifrost.keySelector == nil {
@@ -203,6 +208,157 @@ func (bifrost *Bifrost) ReloadConfig(config schemas.BifrostConfig) error {
203208

204209
// PUBLIC API METHODS
205210

211+
// ListModelsRequest sends a list models request to the specified provider.
212+
func (bifrost *Bifrost) ListModelsRequest(ctx context.Context, req *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
213+
if req == nil {
214+
return nil, &schemas.BifrostError{
215+
IsBifrostError: false,
216+
Error: &schemas.ErrorField{
217+
Message: "list models request is nil",
218+
},
219+
}
220+
}
221+
if req.Provider == "" {
222+
return nil, &schemas.BifrostError{
223+
IsBifrostError: false,
224+
Error: &schemas.ErrorField{
225+
Message: "provider is required for list models request",
226+
},
227+
}
228+
}
229+
230+
request := &schemas.BifrostListModelsRequest{
231+
Provider: req.Provider,
232+
PageSize: req.PageSize,
233+
PageToken: req.PageToken,
234+
ExtraParams: req.ExtraParams,
235+
}
236+
237+
provider := bifrost.getProviderByKey(req.Provider)
238+
if provider == nil {
239+
return nil, &schemas.BifrostError{
240+
IsBifrostError: false,
241+
Error: &schemas.ErrorField{
242+
Message: "provider not found for list models request",
243+
},
244+
}
245+
}
246+
247+
// Determine the base provider type for key requirement checks
248+
baseProvider := req.Provider
249+
providerConfig, err := bifrost.account.GetConfigForProvider(req.Provider)
250+
if err == nil && providerConfig.CustomProviderConfig != nil && providerConfig.CustomProviderConfig.BaseProviderType != "" {
251+
baseProvider = providerConfig.CustomProviderConfig.BaseProviderType
252+
}
253+
254+
// Get API key for the provider if required
255+
key := schemas.Key{}
256+
if providerRequiresKey(baseProvider) {
257+
key, err = bifrost.selectKeyFromProviderForModel(&ctx, req.Provider, "", baseProvider, schemas.ListModelsRequest)
258+
if err != nil {
259+
return nil, &schemas.BifrostError{
260+
IsBifrostError: false,
261+
Error: &schemas.ErrorField{
262+
Message: err.Error(),
263+
Error: err,
264+
},
265+
}
266+
}
267+
}
268+
269+
response, bifrostErr := provider.ListModels(ctx, key, request)
270+
if bifrostErr != nil {
271+
return nil, bifrostErr
272+
}
273+
return response, nil
274+
}
275+
276+
// ListAllModels lists all models from all configured providers.
277+
// It accumulates responses from all providers with a limit of 1000 per provider to get all results.
278+
func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
279+
startTime := time.Now()
280+
281+
providerKeys, err := bifrost.account.GetConfiguredProviders()
282+
if err != nil {
283+
return nil, &schemas.BifrostError{
284+
IsBifrostError: false,
285+
Error: &schemas.ErrorField{
286+
Message: "failed to get configured providers",
287+
Error: err,
288+
},
289+
}
290+
}
291+
292+
// Accumulate all models from all providers
293+
allModels := make([]schemas.Model, 0)
294+
var firstError *schemas.BifrostError
295+
296+
for _, providerKey := range providerKeys {
297+
if strings.TrimSpace(string(providerKey)) == "" {
298+
continue
299+
}
300+
301+
// Create request for this provider with limit of 1000
302+
providerRequest := &schemas.BifrostListModelsRequest{
303+
Provider: providerKey,
304+
PageSize: 1000,
305+
}
306+
307+
// Get all pages for this provider
308+
for {
309+
response, bifrostErr := bifrost.ListModelsRequest(ctx, providerRequest)
310+
if bifrostErr != nil {
311+
// Log the error but continue with other providers
312+
bifrost.logger.Warn(fmt.Sprintf("failed to list models for provider %s: %v", providerKey, bifrostErr.Error.Message))
313+
if firstError == nil {
314+
firstError = bifrostErr
315+
}
316+
break
317+
}
318+
319+
if response != nil {
320+
if len(response.Data) > 0 {
321+
allModels = append(allModels, response.Data...)
322+
}
323+
}
324+
325+
// Check if there are more pages
326+
if response.NextPageToken == "" {
327+
break
328+
}
329+
330+
// Set the page token for the next request
331+
providerRequest.PageToken = response.NextPageToken
332+
}
333+
}
334+
335+
// If we couldn't get any models from any provider, return the first error
336+
if len(allModels) == 0 && firstError != nil {
337+
return nil, firstError
338+
}
339+
340+
// Sort models alphabetically by ID
341+
sort.Slice(allModels, func(i, j int) bool {
342+
return allModels[i].ID < allModels[j].ID
343+
})
344+
345+
// Calculate total elapsed time
346+
elapsedTime := time.Since(startTime).Milliseconds()
347+
348+
// Return aggregated response with accumulated latency
349+
response := &schemas.BifrostListModelsResponse{
350+
Data: allModels,
351+
ExtraFields: schemas.BifrostResponseExtraFields{
352+
RequestType: schemas.ListModelsRequest,
353+
Latency: elapsedTime,
354+
},
355+
}
356+
357+
response = schemas.ApplyPagination(response, request.PageSize, request.PageToken)
358+
359+
return response, nil
360+
}
361+
206362
// TextCompletionRequest sends a text completion request to the specified provider.
207363
func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
208364
if req == nil {
@@ -1039,6 +1195,16 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi
10391195

10401196
waitGroupValue, _ := bifrost.waitGroups.Load(providerKey)
10411197
currentWaitGroup := waitGroupValue.(*sync.WaitGroup)
1198+
// Add provider to list of providers
1199+
loadedProviders := bifrost.providers.Load()
1200+
if loadedProviders == nil {
1201+
// Initialize if somehow nil
1202+
emptyProviders := make([]schemas.Provider, 0)
1203+
loadedProviders = &emptyProviders
1204+
}
1205+
currentProviders := *loadedProviders
1206+
updatedProviders := append(currentProviders, provider)
1207+
bifrost.providers.Store(&updatedProviders)
10421208

10431209
for range providerConfig.ConcurrencyAndBufferSize.Concurrency {
10441210
currentWaitGroup.Add(1)
@@ -1092,6 +1258,23 @@ func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (cha
10921258
return queue, nil
10931259
}
10941260

1261+
// getProviderByKey retrieves a provider instance from the providers array by its provider key.
1262+
// Returns the provider if found, or nil if no provider with the given key exists.
1263+
func (bifrost *Bifrost) getProviderByKey(providerKey schemas.ModelProvider) schemas.Provider {
1264+
providers := bifrost.providers.Load()
1265+
if providers == nil {
1266+
return nil
1267+
}
1268+
1269+
for _, provider := range *providers {
1270+
if provider.GetProviderKey() == providerKey {
1271+
return provider
1272+
}
1273+
}
1274+
1275+
return nil
1276+
}
1277+
10951278
// CORE INTERNAL LOGIC
10961279

10971280
// shouldTryFallbacks handles the primary error and returns true if we should proceed with fallbacks, false if we should return immediately
@@ -1585,7 +1768,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
15851768
key := schemas.Key{}
15861769
if providerRequiresKey(baseProvider) {
15871770
// Use the custom provider name for actual key selection, but pass base provider type for key validation
1588-
key, err = bifrost.selectKeyFromProviderForModel(&req.Context, provider.GetProviderKey(), model, baseProvider)
1771+
key, err = bifrost.selectKeyFromProviderForModel(&req.Context, provider.GetProviderKey(), model, baseProvider, req.RequestType)
15891772
if err != nil {
15901773
bifrost.logger.Warn("error selecting key for model %s: %v", model, err)
15911774
req.Err <- schemas.BifrostError{
@@ -1950,7 +2133,7 @@ func (bifrost *Bifrost) releaseBifrostRequest(req *schemas.BifrostRequest) {
19502133

19512134
// selectKeyFromProviderForModel selects an appropriate API key for a given provider and model.
19522135
// It uses weighted random selection if multiple keys are available.
1953-
func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, providerKey schemas.ModelProvider, model string, baseProviderType schemas.ModelProvider) (schemas.Key, error) {
2136+
func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, providerKey schemas.ModelProvider, model string, baseProviderType schemas.ModelProvider, requestType schemas.RequestType) (schemas.Key, error) {
19542137
// Check if key has been set in the context explicitly
19552138
if ctx != nil {
19562139
key, ok := (*ctx).Value(schemas.BifrostContextKeyDirectKey).(schemas.Key)
@@ -1970,28 +2153,31 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, prov
19702153

19712154
// filter out keys which dont support the model, if the key has no models, it is supported for all models
19722155
var supportedKeys []schemas.Key
1973-
for _, key := range keys {
1974-
modelSupported := (slices.Contains(key.Models, model) && (strings.TrimSpace(key.Value) != "" || canProviderKeyValueBeEmpty(baseProviderType))) || len(key.Models) == 0
1975-
1976-
// Additional deployment checks for Azure and Bedrock
1977-
deploymentSupported := true
1978-
if baseProviderType == schemas.Azure && key.AzureKeyConfig != nil {
1979-
// For Azure, check if deployment exists for this model
1980-
if len(key.AzureKeyConfig.Deployments) > 0 {
1981-
_, deploymentSupported = key.AzureKeyConfig.Deployments[model]
1982-
}
1983-
} else if baseProviderType == schemas.Bedrock && key.BedrockKeyConfig != nil {
1984-
// For Bedrock, check if deployment exists for this model
1985-
if len(key.BedrockKeyConfig.Deployments) > 0 {
1986-
_, deploymentSupported = key.BedrockKeyConfig.Deployments[model]
2156+
if requestType == schemas.ListModelsRequest {
2157+
supportedKeys = keys
2158+
} else {
2159+
for _, key := range keys {
2160+
modelSupported := (slices.Contains(key.Models, model) && (strings.TrimSpace(key.Value) != "" || canProviderKeyValueBeEmpty(baseProviderType))) || len(key.Models) == 0
2161+
2162+
// Additional deployment checks for Azure and Bedrock
2163+
deploymentSupported := true
2164+
if baseProviderType == schemas.Azure && key.AzureKeyConfig != nil {
2165+
// For Azure, check if deployment exists for this model
2166+
if len(key.AzureKeyConfig.Deployments) > 0 {
2167+
_, deploymentSupported = key.AzureKeyConfig.Deployments[model]
2168+
}
2169+
} else if baseProviderType == schemas.Bedrock && key.BedrockKeyConfig != nil {
2170+
// For Bedrock, check if deployment exists for this model
2171+
if len(key.BedrockKeyConfig.Deployments) > 0 {
2172+
_, deploymentSupported = key.BedrockKeyConfig.Deployments[model]
2173+
}
19872174
}
1988-
}
19892175

1990-
if modelSupported && deploymentSupported {
1991-
supportedKeys = append(supportedKeys, key)
2176+
if modelSupported && deploymentSupported {
2177+
supportedKeys = append(supportedKeys, key)
2178+
}
19922179
}
19932180
}
1994-
19952181
if len(supportedKeys) == 0 {
19962182
if baseProviderType == schemas.Azure || baseProviderType == schemas.Bedrock {
19972183
return schemas.Key{}, fmt.Errorf("no keys found that support model/deployment: %s", model)

0 commit comments

Comments
 (0)