diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py index 9a40dcd..a9662db 100644 --- a/src/api/models/bedrock.py +++ b/src/api/models/bedrock.py @@ -92,6 +92,13 @@ def list_bedrock_models() -> dict: response = bedrock_client.list_inference_profiles(maxResults=1000, typeEquals="SYSTEM_DEFINED") profile_list = [p["inferenceProfileId"] for p in response["inferenceProfileSummaries"]] + # List application inference profile ARN defined by user, such as MAP tagged + app_profile_dict = {} + response = bedrock_client.list_inference_profiles(maxResults=1000, typeEquals="APPLICATION") + for p in response["inferenceProfileSummaries"]: + model = p['models'][0]['modelArn'].split('/')[1] + app_profile_dict[model] = p["inferenceProfileArn"] + # List foundation models, only cares about text outputs here. response = bedrock_client.list_foundation_models(byOutputModality="TEXT") @@ -115,6 +122,10 @@ def list_bedrock_models() -> dict: if profile_id in profile_list: model_list[profile_id] = {"modalities": input_modalities} + # Add Appilication inference profile list + if model_id in app_profile_dict: + model_list[app_profile_dict[model_id]] = {"modalities": input_modalities} + except Exception as e: logger.error(f"Unable to list models: {str(e)}")