Skip to content

Commit 59c1644

Browse files
chore: Move OpenAPI schema and fix typos
1 parent 1266ab1 commit 59c1644

File tree

4 files changed: 184 additions and 141 deletions.

nemoguardrails/server/api.py

Lines changed: 34 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
1516
import asyncio
1617
import contextvars
1718
import importlib.util
@@ -27,17 +28,20 @@
2728

2829
from fastapi import FastAPI, Request
2930
from fastapi.middleware.cors import CORSMiddleware
30-
from pydantic import BaseModel, Field, root_validator, validator
31+
from pydantic import Field, root_validator, validator
3132
from starlette.responses import StreamingResponse
3233
from starlette.staticfiles import StaticFiles
3334

3435
from nemoguardrails import LLMRails, RailsConfig, utils
35-
from nemoguardrails.rails.llm.options import (
36-
GenerationLog,
37-
GenerationOptions,
38-
GenerationResponse,
39-
)
36+
from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse
4037
from nemoguardrails.server.datastore.datastore import DataStore
38+
from nemoguardrails.server.schemas.openai import (
39+
Choice,
40+
Model,
41+
ModelsResponse,
42+
OpenAIRequestFields,
43+
ResponseBody,
44+
)
4145
from nemoguardrails.streaming import StreamingHandler
4246

4347
logging.basicConfig(level=logging.INFO)
@@ -191,7 +195,7 @@ async def root_handler():
191195
app.single_config_id = None
192196

193197

194-
class RequestBody(BaseModel):
198+
class RequestBody(OpenAIRequestFields):
195199
config_id: Optional[str] = Field(
196200
default=os.getenv("DEFAULT_CONFIG_ID", None),
197201
description="The id of the configuration to be used. If not set, the default configuration will be used.",
@@ -230,47 +234,6 @@ class RequestBody(BaseModel):
230234
default=None,
231235
description="A state object that should be used to continue the interaction.",
232236
)
233-
# Standard OpenAI completion parameters
234-
model: Optional[str] = Field(
235-
default=None,
236-
description="The model to use for chat completion. Maps to config_id for backward compatibility.",
237-
)
238-
max_tokens: Optional[int] = Field(
239-
default=None,
240-
description="The maximum number of tokens to generate.",
241-
)
242-
temperature: Optional[float] = Field(
243-
default=None,
244-
description="Sampling temperature to use.",
245-
)
246-
top_p: Optional[float] = Field(
247-
default=None,
248-
description="Top-p sampling parameter.",
249-
)
250-
stop: Optional[str] = Field(
251-
default=None,
252-
description="Stop sequences.",
253-
)
254-
presence_penalty: Optional[float] = Field(
255-
default=None,
256-
description="Presence penalty parameter.",
257-
)
258-
frequency_penalty: Optional[float] = Field(
259-
default=None,
260-
description="Frequency penalty parameter.",
261-
)
262-
function_call: Optional[dict] = Field(
263-
default=None,
264-
description="Function call parameter.",
265-
)
266-
logit_bias: Optional[dict] = Field(
267-
default=None,
268-
description="Logit bias parameter.",
269-
)
270-
log_probs: Optional[bool] = Field(
271-
default=None,
272-
description="Log probabilities parameter.",
273-
)
274237

275238
@root_validator(pre=True)
276239
def ensure_config_id(cls, data: Any) -> Any:
@@ -297,75 +260,6 @@ def ensure_config_ids(cls, v, values):
297260
return v
298261

299262

300-
class Choice(BaseModel):
301-
index: Optional[int] = Field(
302-
default=None, description="The index of the choice in the list of choices."
303-
)
304-
messages: Optional[dict] = Field(
305-
default=None, description="The message of the choice"
306-
)
307-
logprobs: Optional[dict] = Field(
308-
default=None, description="The log probabilities of the choice"
309-
)
310-
finish_reason: Optional[str] = Field(
311-
default=None, description="The reason the model stopped generating tokens."
312-
)
313-
314-
315-
class ResponseBody(BaseModel):
316-
# OpenAI-compatible fields
317-
id: Optional[str] = Field(
318-
default=None, description="A unique identifier for the chat completion."
319-
)
320-
object: str = Field(
321-
default="chat.completion",
322-
description="The object type, which is always chat.completion",
323-
)
324-
created: Optional[int] = Field(
325-
default=None,
326-
description="The Unix timestamp (in seconds) of when the chat completion was created.",
327-
)
328-
model: Optional[str] = Field(
329-
default=None, description="The model used for the chat completion."
330-
)
331-
choices: Optional[List[Choice]] = Field(
332-
default=None, description="A list of chat completion choices."
333-
)
334-
# NeMo-Guardrails specific fields for backward compatibility
335-
state: Optional[dict] = Field(
336-
default=None, description="State object for continuing the conversation."
337-
)
338-
llm_output: Optional[dict] = Field(
339-
default=None, description="Additional LLM output data."
340-
)
341-
output_data: Optional[dict] = Field(
342-
default=None, description="Additional output data."
343-
)
344-
log: Optional[dict] = Field(default=None, description="Generation log data.")
345-
346-
347-
class Model(BaseModel):
348-
id: str = Field(
349-
description="The model identifier, which can be referenced in the API endpoints."
350-
)
351-
object: str = Field(
352-
default="model", description="The object type, which is always 'model'."
353-
)
354-
created: int = Field(
355-
description="The Unix timestamp (in seconds) of when the model was created."
356-
)
357-
owned_by: str = Field(
358-
default="nemo-guardrails", description="The organization that owns the model."
359-
)
360-
361-
362-
class ModelsResponse(BaseModel):
363-
object: str = Field(
364-
default="list", description="The object type, which is always 'list'."
365-
)
366-
data: List[Model] = Field(description="The list of models.")
367-
368-
369263
@app.get(
370264
"/v1/models",
371265
response_model=ModelsResponse,
@@ -540,7 +434,7 @@ async def chat_completion(body: RequestBody, request: Request):
540434
choices=[
541435
Choice(
542436
index=0,
543-
messages={
437+
message={
544438
"content": f"Could not load the {config_ids} guardrails configuration. "
545439
f"An internal error has occurred.",
546440
"role": "assistant",
@@ -573,7 +467,7 @@ async def chat_completion(body: RequestBody, request: Request):
573467
choices=[
574468
Choice(
575469
index=0,
576-
messages={
470+
message={
577471
"content": "The `thread_id` must have a minimum length of 16 characters.",
578472
"role": "assistant",
579473
},
@@ -591,19 +485,25 @@ async def chat_completion(body: RequestBody, request: Request):
591485
# And prepend them.
592486
messages = thread_messages + messages
593487

594-
generation_options = body.options
595-
if body.max_tokens:
596-
generation_options.max_tokens = body.max_tokens
597-
if body.temperature is not None:
598-
generation_options.temperature = body.temperature
599-
if body.top_p is not None:
600-
generation_options.top_p = body.top_p
601-
if body.stop:
602-
generation_options.stop = body.stop
603-
if body.presence_penalty is not None:
604-
generation_options.presence_penalty = body.presence_penalty
605-
if body.frequency_penalty is not None:
606-
generation_options.frequency_penalty = body.frequency_penalty
488+
generation_options = body.options
489+
490+
# Initialize llm_params if not already set
491+
if generation_options.llm_params is None:
492+
generation_options.llm_params = {}
493+
494+
# Set OpenAI-compatible parameters in llm_params
495+
if body.max_tokens:
496+
generation_options.llm_params["max_tokens"] = body.max_tokens
497+
if body.temperature is not None:
498+
generation_options.llm_params["temperature"] = body.temperature
499+
if body.top_p is not None:
500+
generation_options.llm_params["top_p"] = body.top_p
501+
if body.stop:
502+
generation_options.llm_params["stop"] = body.stop
503+
if body.presence_penalty is not None:
504+
generation_options.llm_params["presence_penalty"] = body.presence_penalty
505+
if body.frequency_penalty is not None:
506+
generation_options.llm_params["frequency_penalty"] = body.frequency_penalty
607507

608508
if (
609509
body.stream
@@ -654,7 +554,7 @@ async def chat_completion(body: RequestBody, request: Request):
654554
"choices": [
655555
Choice(
656556
index=0,
657-
messages=bot_message,
557+
message=bot_message,
658558
finish_reason="stop",
659559
logprobs=None,
660560
)
@@ -680,7 +580,7 @@ async def chat_completion(body: RequestBody, request: Request):
680580
choices=[
681581
Choice(
682582
index=0,
683-
messages={
583+
message={
684584
"content": "Internal server error",
685585
"role": "assistant",
686586
},
nemoguardrails/server/schemas/openai.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""OpenAI API schema definitions for the NeMo Guardrails server."""
17+
18+
from typing import List, Optional, Union
19+
20+
from pydantic import BaseModel, Field
21+
22+
23+
class OpenAIRequestFields(BaseModel):
24+
"""OpenAI API request fields that can be mixed into other request schemas."""
25+
26+
# Standard OpenAI completion parameters
27+
model: Optional[str] = Field(
28+
default=None,
29+
description="The model to use for chat completion. Maps to config_id for backward compatibility.",
30+
)
31+
max_tokens: Optional[int] = Field(
32+
default=None,
33+
description="The maximum number of tokens to generate.",
34+
)
35+
temperature: Optional[float] = Field(
36+
default=None,
37+
description="Sampling temperature to use.",
38+
)
39+
top_p: Optional[float] = Field(
40+
default=None,
41+
description="Top-p sampling parameter.",
42+
)
43+
stop: Optional[Union[str, List[str]]] = Field(
44+
default=None,
45+
description="Stop sequences.",
46+
)
47+
presence_penalty: Optional[float] = Field(
48+
default=None,
49+
description="Presence penalty parameter.",
50+
)
51+
frequency_penalty: Optional[float] = Field(
52+
default=None,
53+
description="Frequency penalty parameter.",
54+
)
55+
function_call: Optional[dict] = Field(
56+
default=None,
57+
description="Function call parameter.",
58+
)
59+
logit_bias: Optional[dict] = Field(
60+
default=None,
61+
description="Logit bias parameter.",
62+
)
63+
log_probs: Optional[bool] = Field(
64+
default=None,
65+
description="Log probabilities parameter.",
66+
)
67+
68+
69+
class Choice(BaseModel):
70+
"""OpenAI API choice structure in chat completion responses."""
71+
72+
index: Optional[int] = Field(
73+
default=None, description="The index of the choice in the list of choices."
74+
)
75+
message: Optional[dict] = Field(
76+
default=None, description="The message of the choice"
77+
)
78+
logprobs: Optional[dict] = Field(
79+
default=None, description="The log probabilities of the choice"
80+
)
81+
finish_reason: Optional[str] = Field(
82+
default=None, description="The reason the model stopped generating tokens."
83+
)
84+
85+
86+
class ResponseBody(BaseModel):
87+
"""OpenAI API response body with NeMo-Guardrails extensions."""
88+
89+
# OpenAI API fields
90+
id: Optional[str] = Field(
91+
default=None, description="A unique identifier for the chat completion."
92+
)
93+
object: str = Field(
94+
default="chat.completion",
95+
description="The object type, which is always chat.completion",
96+
)
97+
created: Optional[int] = Field(
98+
default=None,
99+
description="The Unix timestamp (in seconds) of when the chat completion was created.",
100+
)
101+
model: Optional[str] = Field(
102+
default=None, description="The model used for the chat completion."
103+
)
104+
choices: Optional[List[Choice]] = Field(
105+
default=None, description="A list of chat completion choices."
106+
)
107+
# NeMo-Guardrails specific fields for backward compatibility
108+
state: Optional[dict] = Field(
109+
default=None, description="State object for continuing the conversation."
110+
)
111+
llm_output: Optional[dict] = Field(
112+
default=None, description="Additional LLM output data."
113+
)
114+
output_data: Optional[dict] = Field(
115+
default=None, description="Additional output data."
116+
)
117+
log: Optional[dict] = Field(default=None, description="Generation log data.")
118+
119+
120+
class Model(BaseModel):
121+
"""OpenAI API model representation."""
122+
123+
id: str = Field(
124+
description="The model identifier, which can be referenced in the API endpoints."
125+
)
126+
object: str = Field(
127+
default="model", description="The object type, which is always 'model'."
128+
)
129+
created: int = Field(
130+
description="The Unix timestamp (in seconds) of when the model was created."
131+
)
132+
owned_by: str = Field(
133+
default="nemo-guardrails", description="The organization that owns the model."
134+
)
135+
136+
137+
class ModelsResponse(BaseModel):
138+
"""OpenAI API models list response."""
139+
140+
object: str = Field(
141+
default="list", description="The object type, which is always 'list'."
142+
)
143+
data: List[Model] = Field(description="The list of models.")

0 commit comments

Comments
 (0)