feat: add support to include application inference profiles as models #131

Merged
7 commits merged on Jun 23, 2025

42 changes: 42 additions & 0 deletions README.md
@@ -26,6 +26,7 @@ If you find this GitHub repository useful, please consider giving it a free star
- [x] Support Embedding API
- [x] Support Multimodal API
- [x] Support Cross-Region Inference
- [x] Support Application Inference Profiles (**new**)
- [x] Support Reasoning (**new**)

Please check [Usage Guide](./docs/Usage.md) for more details about how to use the new APIs.
@@ -148,7 +149,48 @@ print(completion.choices[0].message.content)

Please check [Usage Guide](./docs/Usage.md) for more details about how to use embedding API, multimodal API and tool call.

### Application Inference Profiles

This proxy now supports **Application Inference Profiles**, which let you track usage and costs for your model invocations. Any application inference profile created in your AWS account can be used as the model ID for cost tracking and monitoring.

**Using Application Inference Profiles:**

```bash
# Use an application inference profile ARN as the model ID
curl $OPENAI_BASE_URL/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $OPENAI_API_KEY" \
-d '{
"model": "arn:aws:bedrock:us-west-2:123456789012:application-inference-profile/your-profile-id",
"messages": [
{
"role": "user",
"content": "Hello!"
}
]
}'
```

**SDK Usage with Application Inference Profiles:**

```python
from openai import OpenAI

client = OpenAI()
completion = client.chat.completions.create(
model="arn:aws:bedrock:us-west-2:123456789012:application-inference-profile/your-profile-id",
messages=[{"role": "user", "content": "Hello!"}],
)

print(completion.choices[0].message.content)
```
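
Because application inference profiles are merged into the proxy's model list, you can discover the available profile ARNs before invoking them. A minimal sketch, assuming the proxy exposes the standard OpenAI-compatible `/models` endpoint and that `jq` is available locally:

```bash
# List model IDs and keep only application inference profile ARNs
curl -s $OPENAI_BASE_URL/models \
  -H "Authorization: Bearer $OPENAI_API_KEY" \
  | jq -r '.data[].id' \
  | grep application-inference-profile
```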

**Benefits of Application Inference Profiles:**
- **Cost Tracking**: Track usage and costs for specific applications or use cases
- **Usage Monitoring**: Monitor model invocation metrics through CloudWatch
- **Tag-based Cost Allocation**: Use AWS cost allocation tags for detailed billing analysis

For more information about creating and managing application inference profiles, see the [Amazon Bedrock User Guide](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-create.html).
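
As the updated deployment templates show, listing of application inference profiles is gated by the `ENABLE_APPLICATION_INFERENCE_PROFILES` environment variable, which defaults to `"true"`. A minimal sketch of turning the feature off for a locally run proxy; the container image name and port mapping below are illustrative placeholders, not the project's published values:

```bash
# Any value other than "false" (case-insensitive) keeps the feature enabled,
# mirroring the parsing in src/api/setting.py.
docker run -p 8000:80 \
  -e AWS_REGION=us-west-2 \
  -e ENABLE_APPLICATION_INFERENCE_PROFILES=false \
  bedrock-proxy-api:latest
```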

## Other Examples

2 changes: 2 additions & 0 deletions deployment/BedrockProxy.template
@@ -151,6 +151,7 @@ Resources:
Resource:
- arn:aws:bedrock:*::foundation-model/*
- arn:aws:bedrock:*:*:inference-profile/*
- arn:aws:bedrock:*:*:application-inference-profile/*
- Action:
- secretsmanager:GetSecretValue
- secretsmanager:DescribeSecret
@@ -185,6 +186,7 @@ Resources:
Ref: DefaultModelId
DEFAULT_EMBEDDING_MODEL: cohere.embed-multilingual-v3
ENABLE_CROSS_REGION_INFERENCE: "true"
ENABLE_APPLICATION_INFERENCE_PROFILES: "true"
MemorySize: 1024
PackageType: Image
Role:
3 changes: 3 additions & 0 deletions deployment/BedrockProxyFargate.template
@@ -193,6 +193,7 @@ Resources:
Resource:
- arn:aws:bedrock:*::foundation-model/*
- arn:aws:bedrock:*:*:inference-profile/*
- arn:aws:bedrock:*:*:application-inference-profile/*
Version: "2012-10-17"
PolicyName: ProxyTaskRoleDefaultPolicy933321B8
Roles:
@@ -222,6 +223,8 @@ Resources:
Value: cohere.embed-multilingual-v3
- Name: ENABLE_CROSS_REGION_INFERENCE
Value: "true"
- Name: ENABLE_APPLICATION_INFERENCE_PROFILES
Value: "true"
Essential: true
Image:
Fn::Join:
110 changes: 91 additions & 19 deletions src/api/models/bedrock.py
@@ -38,7 +38,13 @@
Usage,
UserMessage,
)
from api.setting import AWS_REGION, DEBUG, DEFAULT_MODEL, ENABLE_CROSS_REGION_INFERENCE
from api.setting import (
AWS_REGION,
DEBUG,
DEFAULT_MODEL,
ENABLE_CROSS_REGION_INFERENCE,
ENABLE_APPLICATION_INFERENCE_PROFILES,
)

logger = logging.getLogger(__name__)

@@ -83,15 +89,40 @@ def list_bedrock_models() -> dict:
Returns a model list combines:
- ON_DEMAND models.
- Cross-Region Inference Profiles (if enabled via Env)
- Application Inference Profiles (if enabled via Env)
"""
model_list = {}
try:
profile_list = []
app_profile_dict = {}

if ENABLE_CROSS_REGION_INFERENCE:
# List system defined inference profile IDs
response = bedrock_client.list_inference_profiles(maxResults=1000, typeEquals="SYSTEM_DEFINED")
profile_list = [p["inferenceProfileId"] for p in response["inferenceProfileSummaries"]]

if ENABLE_APPLICATION_INFERENCE_PROFILES:
# List application defined inference profile IDs and create mapping
response = bedrock_client.list_inference_profiles(maxResults=1000, typeEquals="APPLICATION")

for profile in response["inferenceProfileSummaries"]:
try:
profile_arn = profile.get("inferenceProfileArn")
if not profile_arn:
continue

# Process all models in the profile
models = profile.get("models", [])
for model in models:
model_arn = model.get("modelArn", "")
if model_arn:
model_id = model_arn.split('/')[-1] if '/' in model_arn else model_arn
if model_id:
app_profile_dict[model_id] = profile_arn
except Exception as e:
logger.warning(f"Error processing application profile: {e}")
continue

# List foundation models, only cares about text outputs here.
response = bedrock_client.list_foundation_models(byOutputModality="TEXT")

@@ -115,6 +146,10 @@ def list_bedrock_models() -> dict:
if profile_id in profile_list:
model_list[profile_id] = {"modalities": input_modalities}

# Add application inference profiles
if model_id in app_profile_dict:
model_list[app_profile_dict[model_id]] = {"modalities": input_modalities}

except Exception as e:
logger.error(f"Unable to list models: {str(e)}")

@@ -162,7 +197,9 @@ async def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
try:
if stream:
# Run the blocking boto3 call in a thread pool
response = await run_in_threadpool(bedrock_runtime.converse_stream, **args)
response = await run_in_threadpool(
bedrock_runtime.converse_stream, **args
)
else:
# Run the blocking boto3 call in a thread pool
response = await run_in_threadpool(bedrock_runtime.converse, **args)
@@ -274,7 +311,9 @@ def _parse_messages(self, chat_request: ChatRequest) -> list[dict]:
messages.append(
{
"role": message.role,
"content": self._parse_content_parts(message, chat_request.model),
"content": self._parse_content_parts(
message, chat_request.model
),
}
)
elif isinstance(message, AssistantMessage):
@@ -283,7 +322,9 @@ def _parse_messages(self, chat_request: ChatRequest) -> list[dict]:
messages.append(
{
"role": message.role,
"content": self._parse_content_parts(message, chat_request.model),
"content": self._parse_content_parts(
message, chat_request.model
),
}
)
if message.tool_calls:
@@ -363,7 +404,9 @@ def _reframe_multi_payloard(self, messages: list) -> list:
# If the next role is different from the previous message, add the previous role's messages to the list
if next_role != current_role:
if current_content:
reformatted_messages.append({"role": current_role, "content": current_content})
reformatted_messages.append(
{"role": current_role, "content": current_content}
)
# Switch to the new role
current_role = next_role
current_content = []
@@ -376,7 +419,9 @@ def _reframe_multi_payloard(self, messages: list) -> list:

# Add the last role's messages to the list
if current_content:
reformatted_messages.append({"role": current_role, "content": current_content})
reformatted_messages.append(
{"role": current_role, "content": current_content}
)

return reformatted_messages

@@ -414,9 +459,13 @@ def _parse_request(self, chat_request: ChatRequest) -> dict:
# Use max_completion_tokens if provided.

max_tokens = (
chat_request.max_completion_tokens if chat_request.max_completion_tokens else chat_request.max_tokens
chat_request.max_completion_tokens
if chat_request.max_completion_tokens
else chat_request.max_tokens
)
budget_tokens = self._calc_budget_tokens(
max_tokens, chat_request.reasoning_effort
)
budget_tokens = self._calc_budget_tokens(max_tokens, chat_request.reasoning_effort)
inference_config["maxTokens"] = max_tokens
# unset topP - Not supported
inference_config.pop("topP")
@@ -428,7 +477,9 @@ def _parse_request(self, chat_request: ChatRequest) -> dict:
if chat_request.tools:
tool_config = {"tools": [self._convert_tool_spec(t.function) for t in chat_request.tools]}

if chat_request.tool_choice and not chat_request.model.startswith("meta.llama3-1-"):
if chat_request.tool_choice and not chat_request.model.startswith(
"meta.llama3-1-"
):
if isinstance(chat_request.tool_choice, str):
# auto (default) is mapped to {"auto" : {}}
# required is mapped to {"any" : {}}
@@ -477,11 +528,15 @@ def _create_response(
message.content = ""
for c in content:
if "reasoningContent" in c:
message.reasoning_content = c["reasoningContent"]["reasoningText"].get("text", "")
message.reasoning_content = c["reasoningContent"][
"reasoningText"
].get("text", "")
elif "text" in c:
message.content = c["text"]
else:
logger.warning("Unknown tag in message content " + ",".join(c.keys()))
logger.warning(
"Unknown tag in message content " + ",".join(c.keys())
)

response = ChatResponse(
id=message_id,
@@ -505,7 +560,9 @@ def _create_response(
response.created = int(time.time())
return response

def _create_response_stream(self, model_id: str, message_id: str, chunk: dict) -> ChatStreamResponse | None:
def _create_response_stream(
self, model_id: str, message_id: str, chunk: dict
) -> ChatStreamResponse | None:
"""Parsing the Bedrock stream response chunk.

Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples
@@ -627,7 +684,9 @@ def _parse_image(self, image_url: str) -> tuple[bytes, str]:
image_content = response.content
return image_content, content_type
else:
raise HTTPException(status_code=500, detail="Unable to access the image url")
raise HTTPException(
status_code=500, detail="Unable to access the image url"
)

def _parse_content_parts(
self,
@@ -687,7 +746,9 @@ def _convert_tool_spec(self, func: Function) -> dict:
}
}

def _calc_budget_tokens(self, max_tokens: int, reasoning_effort: Literal["low", "medium", "high"]) -> int:
def _calc_budget_tokens(
self, max_tokens: int, reasoning_effort: Literal["low", "medium", "high"]
) -> int:
# Helper function to calculate budget_tokens based on the max_tokens.
# Ratio for efforts: Low - 30%, medium - 60%, High: Max token - 1
# Note that The minimum budget_tokens is 1,024 tokens so far.
@@ -718,7 +779,9 @@ def _convert_finish_reason(self, finish_reason: str | None) -> str | None:
"complete": "stop",
"content_filtered": "content_filter",
}
return finish_reason_mapping.get(finish_reason.lower(), finish_reason.lower())
return finish_reason_mapping.get(
finish_reason.lower(), finish_reason.lower()
)
return None


@@ -809,7 +872,9 @@ def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
return args

def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
response = self._invoke_model(args=self._parse_args(embeddings_request), model_id=embeddings_request.model)
response = self._invoke_model(
args=self._parse_args(embeddings_request), model_id=embeddings_request.model
)
response_body = json.loads(response.get("body").read())
if DEBUG:
logger.info("Bedrock response body: " + str(response_body))
@@ -825,10 +890,15 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
if isinstance(embeddings_request.input, str):
input_text = embeddings_request.input
elif isinstance(embeddings_request.input, list) and len(embeddings_request.input) == 1:
elif (
isinstance(embeddings_request.input, list)
and len(embeddings_request.input) == 1
):
input_text = embeddings_request.input[0]
else:
raise ValueError("Amazon Titan Embeddings models support only single strings as input.")
raise ValueError(
"Amazon Titan Embeddings models support only single strings as input."
)
args = {
"inputText": input_text,
# Note: inputImage is not supported!
@@ -842,7 +912,9 @@ def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
return args

def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
response = self._invoke_model(args=self._parse_args(embeddings_request), model_id=embeddings_request.model)
response = self._invoke_model(
args=self._parse_args(embeddings_request), model_id=embeddings_request.model
)
response_body = json.loads(response.get("body").read())
if DEBUG:
logger.info("Bedrock response body: " + str(response_body))
1 change: 1 addition & 0 deletions src/api/setting.py
@@ -16,3 +16,4 @@
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240229-v1:0")
DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3")
ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false"
ENABLE_APPLICATION_INFERENCE_PROFILES = os.environ.get("ENABLE_APPLICATION_INFERENCE_PROFILES", "true").lower() != "false"