12 changes: 8 additions & 4 deletions nemoguardrails/colang/v2_x/runtime/runtime.py
@@ -33,6 +33,7 @@
ColangSyntaxError,
)
from nemoguardrails.colang.v2_x.runtime.flows import Event, FlowStatus
from nemoguardrails.colang.v2_x.runtime.serialization import json_to_state
from nemoguardrails.colang.v2_x.runtime.statemachine import (
FlowConfig,
InternalEvent,
@@ -439,10 +440,13 @@ async def process_events(
)
initialize_state(state)
elif isinstance(state, dict):
# TODO: Implement dict to State conversion
raise NotImplementedError()
# if isinstance(state, dict):
# state = State.from_dict(state)
# Convert dict to State object
if state.get("version") == "2.x" and "state" in state:
# Handle the serialized state format from API calls
state = json_to_state(state["state"])
else:
# TODO: Implement other dict to State conversion formats if needed
raise NotImplementedError("Unsupported state dict format")

assert isinstance(state, State)
assert state.main_flow_state is not None
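For reference, the new dict-to-State branch above behaves like the stand-alone helper sketched below. It assumes the serialized state arrives as {"version": "2.x", "state": "<json string>"}, which is the shape the server returns through the `state` field of the response (see api.py below); `restore_state` is a hypothetical name used only for illustration.

from nemoguardrails.colang.v2_x.runtime.serialization import json_to_state

def restore_state(state_dict: dict):
    # Mirrors the new branch in process_events: only the "2.x" serialized
    # format produced by the server API is handled for now.
    if state_dict.get("version") == "2.x" and "state" in state_dict:
        return json_to_state(state_dict["state"])
    raise NotImplementedError("Unsupported state dict format")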
188 changes: 139 additions & 49 deletions nemoguardrails/server/api.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import contextvars
import importlib.util
@@ -20,23 +21,27 @@
import os.path
import re
import time
import uuid
import warnings
from contextlib import asynccontextmanager
from typing import Any, Callable, List, Optional

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, root_validator, validator
from pydantic import Field, root_validator, validator
from starlette.responses import StreamingResponse
from starlette.staticfiles import StaticFiles

from nemoguardrails import LLMRails, RailsConfig, utils
from nemoguardrails.rails.llm.options import (
GenerationLog,
GenerationOptions,
GenerationResponse,
)
from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse
from nemoguardrails.server.datastore.datastore import DataStore
from nemoguardrails.server.schemas.openai import (
Choice,
Model,
ModelsResponse,
OpenAIRequestFields,
ResponseBody,
)
from nemoguardrails.streaming import StreamingHandler

logging.basicConfig(level=logging.INFO)
@@ -190,7 +195,7 @@ async def root_handler():
app.single_config_id = None


class RequestBody(BaseModel):
class RequestBody(OpenAIRequestFields):
config_id: Optional[str] = Field(
default=os.getenv("DEFAULT_CONFIG_ID", None),
description="The id of the configuration to be used. If not set, the default configuration will be used.",
Expand Down Expand Up @@ -233,6 +238,8 @@ class RequestBody(BaseModel):
@root_validator(pre=True)
def ensure_config_id(cls, data: Any) -> Any:
if isinstance(data, dict):
if data.get("model") is not None and data.get("config_id") is None:
data["config_id"] = data["model"]
if data.get("config_id") is not None and data.get("config_ids") is not None:
raise ValueError(
"Only one of config_id or config_ids should be specified"
@@ -253,25 +260,44 @@ def ensure_config_ids(cls, v, values):
return v


class ResponseBody(BaseModel):
messages: Optional[List[dict]] = Field(
default=None, description="The new messages in the conversation"
)
llm_output: Optional[dict] = Field(
default=None,
description="Contains any additional output coming from the LLM.",
)
output_data: Optional[dict] = Field(
default=None,
description="The output data, i.e. a dict with the values corresponding to the `output_vars`.",
)
log: Optional[GenerationLog] = Field(
default=None, description="Additional logging information."
)
state: Optional[dict] = Field(
default=None,
description="A state object that should be used to continue the interaction in the future.",
)
@app.get(
"/v1/models",
response_model=ModelsResponse,
summary="List available models",
description="Lists the currently available models, mapping guardrails configurations to OpenAI-compatible model format.",
)
async def get_models():
"""Returns the list of available models (guardrails configurations) in OpenAI-compatible format."""

# Use the same logic as get_rails_configs to find available configurations
if app.single_config_mode:
config_ids = [app.single_config_id] if app.single_config_id else []
else:
config_ids = [
f
for f in os.listdir(app.rails_config_path)
if os.path.isdir(os.path.join(app.rails_config_path, f))
and f[0] != "."
and f[0] != "_"
# Filter out all the configs for which there is no `config.yml` file.
and (
os.path.exists(os.path.join(app.rails_config_path, f, "config.yml"))
or os.path.exists(os.path.join(app.rails_config_path, f, "config.yaml"))
)
]

# Convert configurations to OpenAI model format
models = []
for config_id in config_ids:
model = Model(
id=config_id,
object="model",
created=int(time.time()), # Use current time as created timestamp
owned_by="nemo-guardrails",
)
models.append(model)

return ModelsResponse(data=models)
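A quick way to exercise the new endpoint is to query it like any OpenAI models listing. A minimal sketch, assuming `requests` is installed and the guardrails server is reachable at http://localhost:8000:

import requests

# Each entry corresponds to one guardrails configuration directory.
resp = requests.get("http://localhost:8000/v1/models")
resp.raise_for_status()
for model in resp.json()["data"]:
    print(model["id"], model["owned_by"])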


@app.get(
Expand Down Expand Up @@ -401,13 +427,22 @@ async def chat_completion(body: RequestBody, request: Request):
except ValueError as ex:
log.exception(ex)
return ResponseBody(
messages=[
{
"role": "assistant",
"content": f"Could not load the {config_ids} guardrails configuration. "
f"An internal error has occurred.",
}
]
id=f"chatcmpl-{uuid.uuid4()}",
object="chat.completion",
created=int(time.time()),
model=config_ids[0] if config_ids else None,
choices=[
Choice(
index=0,
message={
"content": f"Could not load the {config_ids} guardrails configuration. "
f"An internal error has occurred.",
"role": "assistant",
},
finish_reason="error",
logprobs=None,
)
],
)

try:
@@ -425,12 +460,21 @@
# We make sure the `thread_id` meets the minimum complexity requirement.
if len(body.thread_id) < 16:
return ResponseBody(
messages=[
{
"role": "assistant",
"content": "The `thread_id` must have a minimum length of 16 characters.",
}
]
id=f"chatcmpl-{uuid.uuid4()}",
object="chat.completion",
created=int(time.time()),
model=None,
choices=[
Choice(
index=0,
message={
"content": "The `thread_id` must have a minimum length of 16 characters.",
"role": "assistant",
},
finish_reason="error",
logprobs=None,
)
],
)

# Fetch the existing thread messages. For easier management, we prepend
@@ -441,6 +485,26 @@
# And prepend them.
messages = thread_messages + messages

generation_options = body.options

# Initialize llm_params if not already set
if generation_options.llm_params is None:
generation_options.llm_params = {}

# Set OpenAI-compatible parameters in llm_params
if body.max_tokens:
generation_options.llm_params["max_tokens"] = body.max_tokens
if body.temperature is not None:
generation_options.llm_params["temperature"] = body.temperature
if body.top_p is not None:
generation_options.llm_params["top_p"] = body.top_p
if body.stop:
generation_options.llm_params["stop"] = body.stop
if body.presence_penalty is not None:
generation_options.llm_params["presence_penalty"] = body.presence_penalty
if body.frequency_penalty is not None:
generation_options.llm_params["frequency_penalty"] = body.frequency_penalty

if (
body.stream
and llm_rails.config.streaming_supported
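To see the field mapping above end to end: the `model` field of an OpenAI-style request is copied into `config_id` by the `ensure_config_id` validator, and the sampling parameters are forwarded into `generation_options.llm_params`. A rough sketch, assuming the handler is mounted at `/v1/chat/completions`, the server runs on localhost:8000, and a configuration named `my_config` exists:

import requests

payload = {
    "model": "my_config",  # used as config_id by ensure_config_id
    "messages": [{"role": "user", "content": "Hello!"}],
    "temperature": 0.2,    # forwarded into generation_options.llm_params
    "max_tokens": 128,     # likewise forwarded
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])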
@@ -459,8 +523,6 @@
)
)

# TODO: Add support for thread_ids in streaming mode

return StreamingResponse(streaming_handler)
else:
res = await llm_rails.generate_async(
@@ -483,21 +545,49 @@
if body.thread_id and datastore is not None and datastore_key is not None:
await datastore.set(datastore_key, json.dumps(messages + [bot_message]))

result = ResponseBody(messages=[bot_message])
# Build the response with OpenAI-compatible format plus NeMo-Guardrails extensions
response_kwargs = {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion",
"created": int(time.time()),
"model": config_ids[0] if config_ids else None,
"choices": [
Choice(
index=0,
message=bot_message,
finish_reason="stop",
logprobs=None,
)
],
}

# If we have additional GenerationResponse fields, we return as well
# If we have additional GenerationResponse fields, include them for backward compatibility
if isinstance(res, GenerationResponse):
result.llm_output = res.llm_output
result.output_data = res.output_data
result.log = res.log
result.state = res.state
response_kwargs["llm_output"] = res.llm_output
response_kwargs["output_data"] = res.output_data
response_kwargs["log"] = res.log
response_kwargs["state"] = res.state

return result
return ResponseBody(**response_kwargs)

except Exception as ex:
log.exception(ex)
return ResponseBody(
messages=[{"role": "assistant", "content": "Internal server error."}]
id=f"chatcmpl-{uuid.uuid4()}",
object="chat.completion",
created=int(time.time()),
model=None,
choices=[
Choice(
index=0,
message={
"content": "Internal server error",
"role": "assistant",
},
finish_reason="error",
logprobs=None,
)
],
)
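Because the response now carries the chat.completion fields (id, object, created, model, choices), the server can also be driven with the official OpenAI Python client. A sketch under the same assumptions as above (local server on port 8000, config `my_config`); the API key is a placeholder, on the assumption that the guardrails server does not enforce one:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
completion = client.chat.completions.create(
    model="my_config",
    messages=[{"role": "user", "content": "Hi there"}],
)
print(completion.choices[0].message.content)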

