diff --git a/README.md b/README.md
index c82f70d2..158f78bb 100644
--- a/README.md
+++ b/README.md
@@ -26,6 +26,7 @@ If you find this GitHub repository useful, please consider giving it a free star
 - [x] Support Embedding API
 - [x] Support Multimodal API
 - [x] Support Cross-Region Inference
+- [x] Support Application Inference Profiles (**new**)
 - [x] Support Reasoning (**new**)
 
 Please check [Usage Guide](./docs/Usage.md) for more details about how to use the new APIs.
@@ -148,7 +149,48 @@ print(completion.choices[0].message.content)
 
 Please check [Usage Guide](./docs/Usage.md) for more details about how to use embedding API, multimodal API and tool call.
 
+### Application Inference Profiles
+This proxy supports **Application Inference Profiles**, which let you track usage and costs for your model invocations. To use one, pass the ARN of an application inference profile created in your AWS account as the model ID.
+
+**Using Application Inference Profiles:**
+
+```bash
+# Use an application inference profile ARN as the model ID
+curl $OPENAI_BASE_URL/chat/completions \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer $OPENAI_API_KEY" \
+  -d '{
+    "model": "arn:aws:bedrock:us-west-2:123456789012:application-inference-profile/your-profile-id",
+    "messages": [
+      {
+        "role": "user",
+        "content": "Hello!"
+      }
+    ]
+  }'
+```
+
+**SDK Usage with Application Inference Profiles:**
+
+```python
+from openai import OpenAI
+
+client = OpenAI()
+completion = client.chat.completions.create(
+    model="arn:aws:bedrock:us-west-2:123456789012:application-inference-profile/your-profile-id",
+    messages=[{"role": "user", "content": "Hello!"}],
+)
+
+print(completion.choices[0].message.content)
+```
+
+**Benefits of Application Inference Profiles:**
+- **Cost Tracking**: Track usage and costs for specific applications or use cases
+- **Usage Monitoring**: Monitor model invocation metrics through CloudWatch
+- **Tag-based Cost Allocation**: Use AWS cost allocation tags for detailed billing analysis
+
+For more information about creating and managing application inference profiles, see the [Amazon Bedrock User Guide](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-create.html).
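+
+**Configuration:**
+
+Application inference profile discovery is controlled by the `ENABLE_APPLICATION_INFERENCE_PROFILES` environment variable (default `true`, see `src/api/setting.py` and the deployment templates). To turn it off, set the variable to `false` in the proxy's environment, for example:
+
+```bash
+# Disable listing of application inference profiles (enabled by default)
+export ENABLE_APPLICATION_INFERENCE_PROFILES=false
+```
+
+When enabled, discovered profiles are added to the proxy's model list, so their ARNs also appear in the Models API response (`$OPENAI_BASE_URL/models`) alongside the regular model IDs.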
 
 ## Other Examples
 
diff --git a/deployment/BedrockProxy.template b/deployment/BedrockProxy.template
index 17387dfb..5d432671 100644
--- a/deployment/BedrockProxy.template
+++ b/deployment/BedrockProxy.template
@@ -151,6 +151,7 @@ Resources:
             Resource:
               - arn:aws:bedrock:*::foundation-model/*
               - arn:aws:bedrock:*:*:inference-profile/*
+              - arn:aws:bedrock:*:*:application-inference-profile/*
           - Action:
               - secretsmanager:GetSecretValue
              - secretsmanager:DescribeSecret
@@ -185,6 +186,7 @@ Resources:
             Ref: DefaultModelId
           DEFAULT_EMBEDDING_MODEL: cohere.embed-multilingual-v3
           ENABLE_CROSS_REGION_INFERENCE: "true"
+          ENABLE_APPLICATION_INFERENCE_PROFILES: "true"
       MemorySize: 1024
       PackageType: Image
       Role:
diff --git a/deployment/BedrockProxyFargate.template b/deployment/BedrockProxyFargate.template
index bae785cc..8299a8ad 100644
--- a/deployment/BedrockProxyFargate.template
+++ b/deployment/BedrockProxyFargate.template
@@ -193,6 +193,7 @@ Resources:
             Resource:
               - arn:aws:bedrock:*::foundation-model/*
               - arn:aws:bedrock:*:*:inference-profile/*
+              - arn:aws:bedrock:*:*:application-inference-profile/*
         Version: "2012-10-17"
       PolicyName: ProxyTaskRoleDefaultPolicy933321B8
       Roles:
@@ -222,6 +223,8 @@ Resources:
               Value: cohere.embed-multilingual-v3
             - Name: ENABLE_CROSS_REGION_INFERENCE
               Value: "true"
+            - Name: ENABLE_APPLICATION_INFERENCE_PROFILES
+              Value: "true"
           Essential: true
           Image:
             Fn::Join:
diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py
index 9a40dcd9..d17b3002 100644
--- a/src/api/models/bedrock.py
+++ b/src/api/models/bedrock.py
@@ -38,7 +38,13 @@
     Usage,
     UserMessage,
 )
-from api.setting import AWS_REGION, DEBUG, DEFAULT_MODEL, ENABLE_CROSS_REGION_INFERENCE
+from api.setting import (
+    AWS_REGION,
+    DEBUG,
+    DEFAULT_MODEL,
+    ENABLE_CROSS_REGION_INFERENCE,
+    ENABLE_APPLICATION_INFERENCE_PROFILES,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -83,15 +89,40 @@ def list_bedrock_models() -> dict:
     Returns a model list combines:
         - ON_DEMAND models.
         - Cross-Region Inference Profiles (if enabled via Env)
+        - Application Inference Profiles (if enabled via Env)
     """
     model_list = {}
     try:
         profile_list = []
+        app_profile_dict = {}
+
         if ENABLE_CROSS_REGION_INFERENCE:
             # List system defined inference profile IDs
             response = bedrock_client.list_inference_profiles(maxResults=1000, typeEquals="SYSTEM_DEFINED")
             profile_list = [p["inferenceProfileId"] for p in response["inferenceProfileSummaries"]]
 
+        if ENABLE_APPLICATION_INFERENCE_PROFILES:
+            # List application defined inference profile IDs and create mapping
+            response = bedrock_client.list_inference_profiles(maxResults=1000, typeEquals="APPLICATION")
+
+            for profile in response["inferenceProfileSummaries"]:
+                try:
+                    profile_arn = profile.get("inferenceProfileArn")
+                    if not profile_arn:
+                        continue
+
+                    # Process all models in the profile
+                    models = profile.get("models", [])
+                    for model in models:
+                        model_arn = model.get("modelArn", "")
+                        if model_arn:
+                            model_id = model_arn.split('/')[-1] if '/' in model_arn else model_arn
+                            if model_id:
+                                app_profile_dict[model_id] = profile_arn
+                except Exception as e:
+                    logger.warning(f"Error processing application profile: {e}")
+                    continue
+
         # List foundation models, only cares about text outputs here.
         response = bedrock_client.list_foundation_models(byOutputModality="TEXT")
 
@@ -115,6 +146,10 @@
             if profile_id in profile_list:
                 model_list[profile_id] = {"modalities": input_modalities}
 
+            # Add application inference profiles
+            if model_id in app_profile_dict:
+                model_list[app_profile_dict[model_id]] = {"modalities": input_modalities}
+
     except Exception as e:
         logger.error(f"Unable to list models: {str(e)}")
 
@@ -162,7 +197,9 @@ async def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
         try:
             if stream:
                 # Run the blocking boto3 call in a thread pool
-                response = await run_in_threadpool(bedrock_runtime.converse_stream, **args)
+                response = await run_in_threadpool(
+                    bedrock_runtime.converse_stream, **args
+                )
             else:
                 # Run the blocking boto3 call in a thread pool
                 response = await run_in_threadpool(bedrock_runtime.converse, **args)
@@ -274,7 +311,9 @@ def _parse_messages(self, chat_request: ChatRequest) -> list[dict]:
                 messages.append(
                     {
                         "role": message.role,
-                        "content": self._parse_content_parts(message, chat_request.model),
+                        "content": self._parse_content_parts(
+                            message, chat_request.model
+                        ),
                     }
                 )
             elif isinstance(message, AssistantMessage):
@@ -283,7 +322,9 @@
                 messages.append(
                     {
                         "role": message.role,
-                        "content": self._parse_content_parts(message, chat_request.model),
+                        "content": self._parse_content_parts(
+                            message, chat_request.model
+                        ),
                     }
                 )
                 if message.tool_calls:
@@ -363,7 +404,9 @@ def _reframe_multi_payloard(self, messages: list) -> list:
             # If the next role is different from the previous message, add the previous role's messages to the list
             if next_role != current_role:
                 if current_content:
-                    reformatted_messages.append({"role": current_role, "content": current_content})
+                    reformatted_messages.append(
+                        {"role": current_role, "content": current_content}
+                    )
                 # Switch to the new role
                 current_role = next_role
                 current_content = []
@@ -376,7 +419,9 @@
 
         # Add the last role's messages to the list
         if current_content:
-            reformatted_messages.append({"role": current_role, "content": current_content})
+            reformatted_messages.append(
+                {"role": current_role, "content": current_content}
+            )
 
         return reformatted_messages
 
@@ -414,9 +459,13 @@ def _parse_request(self, chat_request: ChatRequest) -> dict:
             # Use max_completion_tokens if provided.
             max_tokens = (
-                chat_request.max_completion_tokens if chat_request.max_completion_tokens else chat_request.max_tokens
+                chat_request.max_completion_tokens
+                if chat_request.max_completion_tokens
+                else chat_request.max_tokens
+            )
+            budget_tokens = self._calc_budget_tokens(
+                max_tokens, chat_request.reasoning_effort
             )
-            budget_tokens = self._calc_budget_tokens(max_tokens, chat_request.reasoning_effort)
             inference_config["maxTokens"] = max_tokens
             # unset topP - Not supported
             inference_config.pop("topP")
@@ -428,7 +477,9 @@
         if chat_request.tools:
             tool_config = {"tools": [self._convert_tool_spec(t.function) for t in chat_request.tools]}
 
-            if chat_request.tool_choice and not chat_request.model.startswith("meta.llama3-1-"):
+            if chat_request.tool_choice and not chat_request.model.startswith(
+                "meta.llama3-1-"
+            ):
                 if isinstance(chat_request.tool_choice, str):
                     # auto (default) is mapped to {"auto" : {}}
                     # required is mapped to {"any" : {}}
@@ -477,11 +528,15 @@ def _create_response(
         message.content = ""
         for c in content:
             if "reasoningContent" in c:
-                message.reasoning_content = c["reasoningContent"]["reasoningText"].get("text", "")
+                message.reasoning_content = c["reasoningContent"][
+                    "reasoningText"
+                ].get("text", "")
             elif "text" in c:
                 message.content = c["text"]
             else:
-                logger.warning("Unknown tag in message content " + ",".join(c.keys()))
+                logger.warning(
+                    "Unknown tag in message content " + ",".join(c.keys())
+                )
 
         response = ChatResponse(
             id=message_id,
@@ -505,7 +560,9 @@
         response.created = int(time.time())
         return response
 
-    def _create_response_stream(self, model_id: str, message_id: str, chunk: dict) -> ChatStreamResponse | None:
+    def _create_response_stream(
+        self, model_id: str, message_id: str, chunk: dict
+    ) -> ChatStreamResponse | None:
         """Parsing the Bedrock stream response chunk.
 
         Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples
@@ -627,7 +684,9 @@ def _parse_image(self, image_url: str) -> tuple[bytes, str]:
             image_content = response.content
             return image_content, content_type
         else:
-            raise HTTPException(status_code=500, detail="Unable to access the image url")
+            raise HTTPException(
+                status_code=500, detail="Unable to access the image url"
+            )
 
     def _parse_content_parts(
         self,
@@ -687,7 +746,9 @@ def _convert_tool_spec(self, func: Function) -> dict:
             }
         }
 
-    def _calc_budget_tokens(self, max_tokens: int, reasoning_effort: Literal["low", "medium", "high"]) -> int:
+    def _calc_budget_tokens(
+        self, max_tokens: int, reasoning_effort: Literal["low", "medium", "high"]
+    ) -> int:
         # Helper function to calculate budget_tokens based on the max_tokens.
         # Ratio for efforts: Low - 30%, medium - 60%, High: Max token - 1
         # Note that The minimum budget_tokens is 1,024 tokens so far.
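+        # For example, with max_tokens=4096 this yields roughly 1,229 budget tokens
+        # for "low" (30%), roughly 2,458 for "medium" (60%), and 4,095 for "high"
+        # (max_tokens - 1), subject to the 1,024-token minimum noted above.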
@@ -718,7 +779,9 @@ def _convert_finish_reason(self, finish_reason: str | None) -> str | None:
                 "complete": "stop",
                 "content_filtered": "content_filter",
             }
-            return finish_reason_mapping.get(finish_reason.lower(), finish_reason.lower())
+            return finish_reason_mapping.get(
+                finish_reason.lower(), finish_reason.lower()
+            )
         return None
 
@@ -809,7 +872,9 @@ def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
         return args
 
     def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
-        response = self._invoke_model(args=self._parse_args(embeddings_request), model_id=embeddings_request.model)
+        response = self._invoke_model(
+            args=self._parse_args(embeddings_request), model_id=embeddings_request.model
+        )
         response_body = json.loads(response.get("body").read())
         if DEBUG:
             logger.info("Bedrock response body: " + str(response_body))
@@ -825,10 +890,15 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
     def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
         if isinstance(embeddings_request.input, str):
             input_text = embeddings_request.input
-        elif isinstance(embeddings_request.input, list) and len(embeddings_request.input) == 1:
+        elif (
+            isinstance(embeddings_request.input, list)
+            and len(embeddings_request.input) == 1
+        ):
             input_text = embeddings_request.input[0]
         else:
-            raise ValueError("Amazon Titan Embeddings models support only single strings as input.")
+            raise ValueError(
+                "Amazon Titan Embeddings models support only single strings as input."
+            )
         args = {
             "inputText": input_text,
             # Note: inputImage is not supported!
@@ -842,7 +912,9 @@ def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
         return args
 
     def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
-        response = self._invoke_model(args=self._parse_args(embeddings_request), model_id=embeddings_request.model)
+        response = self._invoke_model(
+            args=self._parse_args(embeddings_request), model_id=embeddings_request.model
+        )
         response_body = json.loads(response.get("body").read())
         if DEBUG:
             logger.info("Bedrock response body: " + str(response_body))
diff --git a/src/api/setting.py b/src/api/setting.py
index e090300a..4e0a7bbd 100644
--- a/src/api/setting.py
+++ b/src/api/setting.py
@@ -16,3 +16,4 @@
 DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240229-v1:0")
 DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3")
 ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false"
+ENABLE_APPLICATION_INFERENCE_PROFILES = os.environ.get("ENABLE_APPLICATION_INFERENCE_PROFILES", "true").lower() != "false"
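
For reference, the mapping that the new code in `list_bedrock_models()` builds from application inference profiles can be illustrated with a small, self-contained sketch. The sample data below is hand-written for illustration only: it mimics just the fields the new code reads from `list_inference_profiles(typeEquals="APPLICATION")` (`inferenceProfileArn` and `models[].modelArn`), and the ARNs are placeholders.

```python
# Standalone illustration of the profile-to-model mapping added in list_bedrock_models().
sample_summaries = [
    {
        "inferenceProfileArn": "arn:aws:bedrock:us-west-2:123456789012:application-inference-profile/your-profile-id",
        "models": [
            {"modelArn": "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0"},
        ],
    },
]

app_profile_dict = {}
for profile in sample_summaries:
    profile_arn = profile.get("inferenceProfileArn")
    if not profile_arn:
        continue
    for model in profile.get("models", []):
        model_arn = model.get("modelArn", "")
        if not model_arn:
            continue
        # The trailing ARN segment is the underlying model ID.
        model_id = model_arn.split("/")[-1] if "/" in model_arn else model_arn
        if model_id:
            app_profile_dict[model_id] = profile_arn

# Each underlying model ID now points at its application inference profile ARN.
print(app_profile_dict)
```

Note that the profile ARN itself (not the underlying model ID) is the key added to `model_list`, which is why requests can pass the ARN directly in the `model` field, as shown in the README examples above.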