@@ -16,6 +16,10 @@ def list_models():
16
16
if bedrock_models :
17
17
models .extend (bedrock_models )
18
18
19
+ bedrock_cris_models = list_bedrock_cris_models ()
20
+ if bedrock_cris_models :
21
+ models .extend (bedrock_cris_models )
22
+
19
23
fine_tuned_models = list_bedrock_finetuned_models ()
20
24
if fine_tuned_models :
21
25
models .extend (fine_tuned_models )
@@ -80,6 +84,73 @@ def list_azure_openai_models():
80
84
]
81
85
82
86
87
# Based on the table (need to support both document and system prompt):
# https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-supported-models-features.html
def does_model_support_documents(model_name):
    """Return True if the Bedrock model can accept DOCUMENT inputs in Converse.

    The previous implementation used glob-style wildcards inside ``re.match``
    (e.g. ``r"^amazon.titan-t*"``), where ``*`` actually means "zero or more of
    the preceding character" and ``.`` matches any character — so the titan
    pattern matched every ``amazon.titan-`` model, not just Titan Text.
    Prefix matching was the intent, so use ``str.startswith`` with literal
    prefixes instead.

    :param model_name: Bedrock model id, e.g. ``"anthropic.claude-3-..."``.
    :return: True when the model family is not in the unsupported list.
    """
    # Model families that do NOT support document input (per the AWS table).
    unsupported_prefixes = (
        "ai21.jamba",
        "ai21.j2",
        "amazon.titan-t",
        "cohere.command-light",
        "cohere.command-text",
        "mistral.mistral-7b-instruct",
        "mistral.mistral-small",
        "amazon.nova-reel",
        "amazon.nova-canvas",
        "amazon.nova-micro",
    )
    # startswith accepts a tuple of prefixes — one C-level call, no regex.
    return not model_name.startswith(unsupported_prefixes)
102
+
103
+
104
def create_bedrock_model_profile(bedrock_model: dict, model_name: str) -> dict:
    """Build the internal model-profile dict for a Bedrock foundation model.

    :param bedrock_model: a boto3 ``modelSummaries`` entry (must contain
        ``inputModalities`` and ``outputModalities``).
    :param model_name: the name to expose — the raw ``modelId`` for standard
        models, or the ``inferenceProfileId`` for cross-region models.
    :return: profile dict consumed by the model listing API.
    """
    model = {
        "provider": Provider.BEDROCK.value,
        "name": model_name,
        "streaming": bedrock_model.get("responseStreamingSupported", False),
        # Copy the modality lists: the original code aliased the boto3
        # response lists and appended "DOCUMENT" in place, so processing the
        # same summary twice (standard + cross-region paths) double-appended
        # and mutated shared state.
        "inputModalities": list(bedrock_model["inputModalities"]),
        "outputModalities": list(bedrock_model["outputModalities"]),
        "interface": ModelInterface.LANGCHAIN.value,
        "ragSupported": True,
        "bedrockGuardrails": True,
    }

    # Checked against the base model id family via model_name — TODO confirm
    # cross-region profile ids (e.g. "us.anthropic...") still hit the intended
    # prefixes in does_model_support_documents.
    if does_model_support_documents(model["name"]):
        model["inputModalities"].append("DOCUMENT")
    return model
119
+
120
+
121
def list_cross_region_inference_profiles():
    """Map base model id -> system-defined cross-region inference profile id.

    Only ACTIVE, SYSTEM_DEFINED profiles are included. Follows ``nextToken``
    pagination — the original single ``list_inference_profiles()`` call
    silently truncated the map once the account exceeded one page of profiles.

    :return: dict of ``{modelId: inferenceProfileId}``.
    """
    bedrock = genai_core.clients.get_bedrock_client(service_name="bedrock")

    profiles = {}
    next_token = None
    while True:
        kwargs = {"nextToken": next_token} if next_token else {}
        response = bedrock.list_inference_profiles(**kwargs)
        for inference_profile in response.get("inferenceProfileSummaries", []):
            if (
                inference_profile.get("status") == "ACTIVE"
                and inference_profile.get("type") == "SYSTEM_DEFINED"
            ):
                # A profile lists several regional model ARNs; the segment
                # after "/" in the first ARN is the base modelId — TODO confirm
                # the first entry is always representative.
                model_id = inference_profile["models"][0]["modelArn"].split("/")[1]
                profiles[model_id] = inference_profile["inferenceProfileId"]
        next_token = response.get("nextToken")
        if not next_token:
            return profiles
135
+
136
+
137
def list_bedrock_cris_models():
    """List Bedrock models reachable via cross-region inference (CRIS) profiles.

    Each returned profile uses the ``inferenceProfileId`` as the model name so
    invocations route through the cross-region endpoint.

    :return: list of model-profile dicts, or None on error (same best-effort
        convention as ``list_bedrock_models``).
    """
    try:
        cross_region_profiles = list_cross_region_inference_profiles()
        bedrock_client = genai_core.clients.get_bedrock_client(service_name="bedrock")
        all_models = bedrock_client.list_foundation_models()["modelSummaries"]

        return [
            create_bedrock_model_profile(
                model, cross_region_profiles[model["modelId"]]
            )
            for model in all_models
            if genai_core.types.InferenceType.INFERENCE_PROFILE.value
            in model.get("inferenceTypesSupported", [])
            # A model may advertise INFERENCE_PROFILE support without an
            # ACTIVE system-defined profile in this region; without this
            # guard the lookup raised KeyError and the except clause
            # discarded the ENTIRE list.
            and model["modelId"] in cross_region_profiles
        ]
    except Exception as e:
        logger.error(f"Error listing cross region inference profiles models: {e}")
        return None
152
+
153
+
83
154
def list_bedrock_models ():
84
155
try :
85
156
bedrock = genai_core .clients .get_bedrock_client (service_name = "bedrock" )
@@ -108,33 +179,9 @@ def list_bedrock_models():
108
179
)
109
180
):
110
181
continue
111
- model = {
112
- "provider" : Provider .BEDROCK .value ,
113
- "name" : bedrock_model ["modelId" ],
114
- "streaming" : bedrock_model .get ("responseStreamingSupported" , False ),
115
- "inputModalities" : bedrock_model ["inputModalities" ],
116
- "outputModalities" : bedrock_model ["outputModalities" ],
117
- "interface" : ModelInterface .LANGCHAIN .value ,
118
- "ragSupported" : True ,
119
- "bedrockGuardrails" : True ,
120
- }
121
- # Based on the table (Need to support both document and sytem prompt)
122
- # https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-supported-models-features.html
123
- if (
124
- not re .match (r"^ai21.jamba*" , model ["name" ])
125
- and not re .match (r"^ai21.j2*" , model ["name" ])
126
- and not re .match (r"^amazon.titan-t*" , model ["name" ])
127
- and not re .match (r"^cohere.command-light*" , model ["name" ])
128
- and not re .match (r"^cohere.command-text*" , model ["name" ])
129
- and not re .match (r"^mistral.mistral-7b-instruct-*" , model ["name" ])
130
- and not re .match (r"^mistral.mistral-small*" , model ["name" ])
131
- and not re .match (r"^amazon.nova-reel*" , model ["name" ])
132
- and not re .match (r"^amazon.nova-canvas*" , model ["name" ])
133
- and not re .match (r"^amazon.nova-micro*" , model ["name" ])
134
- ):
135
- model ["inputModalities" ].append ("DOCUMENT" )
136
-
137
- models .append (model )
182
+ models .append (
183
+ create_bedrock_model_profile (bedrock_model , bedrock_model ["modelId" ])
184
+ )
138
185
139
186
return models
140
187
except Exception as e :
0 commit comments