# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import asyncio
import contextvars
import importlib.util

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
- from pydantic import BaseModel, Field, root_validator, validator
+ from pydantic import Field, root_validator, validator
from starlette.responses import StreamingResponse
from starlette.staticfiles import StaticFiles

from nemoguardrails import LLMRails, RailsConfig, utils
- from nemoguardrails.rails.llm.options import (
-     GenerationLog,
-     GenerationOptions,
-     GenerationResponse,
- )
+ from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse
from nemoguardrails.server.datastore.datastore import DataStore
+ from nemoguardrails.server.schemas.openai import (
+     Choice,
+     Model,
+     ModelsResponse,
+     OpenAIRequestFields,
+     ResponseBody,
+ )
from nemoguardrails.streaming import StreamingHandler

logging.basicConfig(level=logging.INFO)
@@ -191,7 +195,7 @@ async def root_handler():
app.single_config_id = None


- class RequestBody(BaseModel):
+ class RequestBody(OpenAIRequestFields):
    config_id: Optional[str] = Field(
        default=os.getenv("DEFAULT_CONFIG_ID", None),
        description="The id of the configuration to be used. If not set, the default configuration will be used.",
    )
@@ -230,47 +234,6 @@ class RequestBody(BaseModel):
        default=None,
        description="A state object that should be used to continue the interaction.",
    )
-     # Standard OpenAI completion parameters
-     model: Optional[str] = Field(
-         default=None,
-         description="The model to use for chat completion. Maps to config_id for backward compatibility.",
-     )
-     max_tokens: Optional[int] = Field(
-         default=None,
-         description="The maximum number of tokens to generate.",
-     )
-     temperature: Optional[float] = Field(
-         default=None,
-         description="Sampling temperature to use.",
-     )
-     top_p: Optional[float] = Field(
-         default=None,
-         description="Top-p sampling parameter.",
-     )
-     stop: Optional[str] = Field(
-         default=None,
-         description="Stop sequences.",
-     )
-     presence_penalty: Optional[float] = Field(
-         default=None,
-         description="Presence penalty parameter.",
-     )
-     frequency_penalty: Optional[float] = Field(
-         default=None,
-         description="Frequency penalty parameter.",
-     )
-     function_call: Optional[dict] = Field(
-         default=None,
-         description="Function call parameter.",
-     )
-     logit_bias: Optional[dict] = Field(
-         default=None,
-         description="Logit bias parameter.",
-     )
-     log_probs: Optional[bool] = Field(
-         default=None,
-         description="Log probabilities parameter.",
-     )

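The ten standard OpenAI completion fields deleted above are not lost: per the import change at the top of the diff, `RequestBody` now inherits them from `OpenAIRequestFields` in the new `nemoguardrails.server.schemas.openai` module. That module's body is not part of this diff; a minimal sketch of what the mixin presumably contains, with field names, types, and descriptions carried over from the deleted block:

```python
# Presumed sketch of OpenAIRequestFields in nemoguardrails/server/schemas/openai.py.
# The module is not shown in this diff; fields are copied from the block deleted
# from RequestBody above.
from typing import Optional

from pydantic import BaseModel, Field


class OpenAIRequestFields(BaseModel):
    """Standard OpenAI chat completion parameters shared by request models."""

    model: Optional[str] = Field(
        default=None,
        description="The model to use for chat completion. Maps to config_id for backward compatibility.",
    )
    max_tokens: Optional[int] = Field(default=None, description="The maximum number of tokens to generate.")
    temperature: Optional[float] = Field(default=None, description="Sampling temperature to use.")
    top_p: Optional[float] = Field(default=None, description="Top-p sampling parameter.")
    stop: Optional[str] = Field(default=None, description="Stop sequences.")
    presence_penalty: Optional[float] = Field(default=None, description="Presence penalty parameter.")
    frequency_penalty: Optional[float] = Field(default=None, description="Frequency penalty parameter.")
    function_call: Optional[dict] = Field(default=None, description="Function call parameter.")
    logit_bias: Optional[dict] = Field(default=None, description="Logit bias parameter.")
    log_probs: Optional[bool] = Field(default=None, description="Log probabilities parameter.")
```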
    @root_validator(pre=True)
    def ensure_config_id(cls, data: Any) -> Any:
@@ -297,75 +260,6 @@ def ensure_config_ids(cls, v, values):
        return v


- class Choice(BaseModel):
-     index: Optional[int] = Field(
-         default=None, description="The index of the choice in the list of choices."
-     )
-     messages: Optional[dict] = Field(
-         default=None, description="The message of the choice"
-     )
-     logprobs: Optional[dict] = Field(
-         default=None, description="The log probabilities of the choice"
-     )
-     finish_reason: Optional[str] = Field(
-         default=None, description="The reason the model stopped generating tokens."
-     )
-
-
- class ResponseBody(BaseModel):
-     # OpenAI-compatible fields
-     id: Optional[str] = Field(
-         default=None, description="A unique identifier for the chat completion."
-     )
-     object: str = Field(
-         default="chat.completion",
-         description="The object type, which is always chat.completion",
-     )
-     created: Optional[int] = Field(
-         default=None,
-         description="The Unix timestamp (in seconds) of when the chat completion was created.",
-     )
-     model: Optional[str] = Field(
-         default=None, description="The model used for the chat completion."
-     )
-     choices: Optional[List[Choice]] = Field(
-         default=None, description="A list of chat completion choices."
-     )
-     # NeMo-Guardrails specific fields for backward compatibility
-     state: Optional[dict] = Field(
-         default=None, description="State object for continuing the conversation."
-     )
-     llm_output: Optional[dict] = Field(
-         default=None, description="Additional LLM output data."
-     )
-     output_data: Optional[dict] = Field(
-         default=None, description="Additional output data."
-     )
-     log: Optional[dict] = Field(default=None, description="Generation log data.")
-
-
- class Model(BaseModel):
-     id: str = Field(
-         description="The model identifier, which can be referenced in the API endpoints."
-     )
-     object: str = Field(
-         default="model", description="The object type, which is always 'model'."
-     )
-     created: int = Field(
-         description="The Unix timestamp (in seconds) of when the model was created."
-     )
-     owned_by: str = Field(
-         default="nemo-guardrails", description="The organization that owns the model."
-     )
-
-
- class ModelsResponse(BaseModel):
-     object: str = Field(
-         default="list", description="The object type, which is always 'list'."
-     )
-     data: List[Model] = Field(description="The list of models.")
-
-
@app.get(
    "/v1/models",
    response_model=ModelsResponse,
@@ -540,7 +434,7 @@ async def chat_completion(body: RequestBody, request: Request):
        choices=[
            Choice(
                index=0,
-                 messages={
+                 message={
                    "content": f"Could not load the {config_ids} guardrails configuration. "
                    f"An internal error has occurred.",
                    "role": "assistant",
@@ -573,7 +467,7 @@ async def chat_completion(body: RequestBody, request: Request):
        choices=[
            Choice(
                index=0,
-                 messages={
+                 message={
                    "content": "The `thread_id` must have a minimum length of 16 characters.",
                    "role": "assistant",
                },
@@ -591,19 +485,25 @@ async def chat_completion(body: RequestBody, request: Request):
        # And prepend them.
        messages = thread_messages + messages

-     generation_options = body.options
-     if body.max_tokens:
-         generation_options.max_tokens = body.max_tokens
-     if body.temperature is not None:
-         generation_options.temperature = body.temperature
-     if body.top_p is not None:
-         generation_options.top_p = body.top_p
-     if body.stop:
-         generation_options.stop = body.stop
-     if body.presence_penalty is not None:
-         generation_options.presence_penalty = body.presence_penalty
-     if body.frequency_penalty is not None:
-         generation_options.frequency_penalty = body.frequency_penalty
+     generation_options = body.options
+
+     # Initialize llm_params if not already set
+     if generation_options.llm_params is None:
+         generation_options.llm_params = {}
+
+     # Set OpenAI-compatible parameters in llm_params
+     if body.max_tokens:
+         generation_options.llm_params["max_tokens"] = body.max_tokens
+     if body.temperature is not None:
+         generation_options.llm_params["temperature"] = body.temperature
+     if body.top_p is not None:
+         generation_options.llm_params["top_p"] = body.top_p
+     if body.stop:
+         generation_options.llm_params["stop"] = body.stop
+     if body.presence_penalty is not None:
+         generation_options.llm_params["presence_penalty"] = body.presence_penalty
+     if body.frequency_penalty is not None:
+         generation_options.llm_params["frequency_penalty"] = body.frequency_penalty

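Instead of assigning the OpenAI sampling parameters as attributes directly on `GenerationOptions`, the handler now collects them in the `llm_params` dict, the pass-through channel for per-call LLM keyword arguments. An end-to-end illustration of the effect, assuming the server runs locally on port 8000 and that this handler serves the usual `/v1/chat/completions` route:

```python
# Illustration only: host, port, and endpoint path are assumptions, and
# `requests` is just one possible HTTP client.
import requests

response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "config_id": "my_config",  # hypothetical configuration id
        "messages": [{"role": "user", "content": "Hello!"}],
        # These OpenAI-style fields should now land in
        # generation_options.llm_params as
        # {"temperature": 0.2, "max_tokens": 128} before generation runs.
        "temperature": 0.2,
        "max_tokens": 128,
    },
)
print(response.json())
```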
    if (
        body.stream
@@ -654,7 +554,7 @@ async def chat_completion(body: RequestBody, request: Request):
            "choices": [
                Choice(
                    index=0,
-                     messages=bot_message,
+                     message=bot_message,
                    finish_reason="stop",
                    logprobs=None,
                )
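With the rename applied here and in the error paths, every element of `choices` serializes with a singular `message` key, as in the OpenAI chat-completion response format. An illustrative success payload (keys follow the `ResponseBody`/`Choice` fields above; the values are made up):

```python
# Made-up example of the serialized response shape after the rename; only the
# key names are meaningful.
example_response = {
    "object": "chat.completion",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Hello! How can I help you?"},
            "finish_reason": "stop",
            "logprobs": None,
        }
    ],
}
```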
@@ -680,7 +580,7 @@ async def chat_completion(body: RequestBody, request: Request):
        choices=[
            Choice(
                index=0,
-                 messages={
+                 message={
                    "content": "Internal server error",
                    "role": "assistant",
                },