Skip to content

feat: Better support for tools_strict=True when using the OpenAIChatGenerator #9382

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 70 additions & 2 deletions haystack/components/generators/chat/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import json
import os
from copy import deepcopy
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

Expand Down Expand Up @@ -371,6 +372,74 @@ async def run_async(

return {"replies": completions}

@staticmethod
def _is_type_object(obj: Any) -> bool:
    """
    Tell whether a schema fragment declares `"type": "object"`.

    Anything that is not a dictionary (strings, lists, None, ...) is never considered an object schema.

    :param obj: The schema fragment to inspect.
    :returns: True if `obj` is a dictionary whose "type" entry equals "object", False otherwise.
    """
    if not isinstance(obj, dict):
        return False
    # A missing "type" key yields None, which never equals "object".
    return obj.get("type") == "object"

def _strictify_object(self, obj: Any) -> Any:
    """
    Recursively rewrite a fragment of a tool specification so it follows OpenAI's strict schema.

    For every `{"type": "object"}` schema that declares "required", at any nesting depth:
    - "properties" is reduced to the required properties, since strict mode demands that every listed property
      is required. For ease, optional properties are dropped rather than rewritten. A missing "properties"
      entry is treated as empty.
    - "required" is reduced to the names that still exist in "properties", so it never refers to a property
      that had to be removed.
    - "additionalProperties" is set to False, which is a requirement of the strict schema.

    A `{"type": "object"}` schema without "required" cannot be expressed in strict mode. It is removed from the
    dictionary or list that contains it. The fragment passed in by the caller is never removed, only its children.

    Known limitation: a "$ref" pointing at a definition that got removed is not repaired.

    Dictionaries are modified in place and returned; lists are rebuilt. Pass a copy if the input must be preserved.

    :param obj: A schema fragment: a dictionary, a list, or a scalar value.
    :returns: The strictified fragment. Scalars are returned unchanged.
    """
    if isinstance(obj, list):
        # Object schemas without "required" are dropped, everything else is processed recursively.
        return [
            self._strictify_object(item)
            for item in obj
            if not (self._is_type_object(item) and "required" not in item)
        ]

    if not isinstance(obj, dict):
        return obj

    is_strict_candidate = self._is_type_object(obj) and "required" in obj
    if is_strict_candidate:
        required = obj["required"]
        # Drop optional properties first so they are not needlessly processed below.
        obj["properties"] = {
            name: schema for name, schema in (obj.get("properties") or {}).items() if name in required
        }

    # Walk every keyword (properties, items, $defs, anyOf, ...), not just "properties".
    for key, value in list(obj.items()):
        if self._is_type_object(value) and "required" not in value:
            del obj[key]
        else:
            obj[key] = self._strictify_object(value)

    if is_strict_candidate:
        surviving = obj.get("properties") or {}
        # Keep "required" in sync with the properties that survived the recursive clean-up.
        obj["required"] = [name for name in obj["required"] if name in surviving]
        obj["additionalProperties"] = False

    return obj

def _strictify_function_schema(self, function_schema: Dict[str, Any]) -> Dict[str, Any]:
    """
    Build a copy of a tool's function specification that follows OpenAI's strict schema.

    The specification is deep-copied before it is rewritten, so the caller's dictionary is left untouched.
    The copy is passed through `_strictify_object` and returned with `"strict": True` set.

    OpenAI's strict schema is equivalent to their Structured Outputs schema. See
    [Structured Outputs](https://platform.openai.com/docs/guides/structured-outputs/supported-schemas?api-mode=responses)
    and the list of
    [supported schemas](https://platform.openai.com/docs/guides/structured-outputs/supported-schemas?api-mode=responses#supported-schemas).

    :param function_schema: The function specification of a tool.
    :returns: A new dictionary holding the strictified specification plus `"strict": True`.
    """
    # TODO Ideally _strictify would also "repair" any function schema. e.g. a required variable had to be
    #   removed b/c it had schema {"type": "object", "additionalProperties": True} which is not allowed
    #   Look at test_strictify_function_schema_chat_message for a real example that is quite messy.
    strict_schema = self._strictify_object(deepcopy(function_schema))
    return {**strict_schema, "strict": True}

def _prepare_api_call( # noqa: PLR0913
self,
*,
Expand Down Expand Up @@ -398,8 +467,7 @@ def _prepare_api_call( # noqa: PLR0913
for t in tools:
function_spec = {**t.tool_spec}
if tools_strict:
function_spec["strict"] = True
function_spec["parameters"]["additionalProperties"] = False
function_spec = self._strictify_function_schema(function_spec)
tool_definitions.append({"type": "function", "function": function_spec})
openai_tools = {"tools": tool_definitions}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
enhancements:
- |
Added `_strictify_function_schema` to enforce OpenAI's strict schema for tool specifications.
This update affects the OpenAIChatGenerator component when tools_strict=True is enabled.
The function ensures compatibility by setting additionalProperties to False on object schemas and removing all non-required fields, as required by OpenAI's Structured Outputs schema.
154 changes: 154 additions & 0 deletions test/components/generators/chat/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.tools.toolset import Toolset

from test.tools.test_parameters_schema_utils import (
CHAT_MESSAGE_SCHEMA,
CHAT_ROLE_SCHEMA,
TEXT_CONTENT_SCHEMA,
TOOL_CALL_SCHEMA,
TOOL_CALL_RESULT_SCHEMA,
)


@pytest.fixture
def chat_messages():
Expand Down Expand Up @@ -908,6 +916,130 @@ def test_convert_usage_chunk_to_streaming_chunk(self):
assert result.meta["model"] == "gpt-4o-mini-2024-07-18"
assert result.meta["received_at"] is not None

def test_strictify_function_schema(self):
    """Optional parameters are dropped, extra properties are forbidden and the spec is flagged as strict."""
    generator = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))

    # "param3" is the only property that is not listed under "required".
    parameters = {
        "type": "object",
        "properties": {
            "param1": {"type": "string"},
            "param2": {"type": "integer"},
            "param3": {"type": "array", "items": {"type": "string"}},
        },
        "required": ["param1", "param2"],
    }
    tool_spec = {"name": "function_name", "description": "function_description", "parameters": parameters}

    expected_parameters = {
        "type": "object",
        "properties": {"param1": {"type": "string"}, "param2": {"type": "integer"}},
        "required": ["param1", "param2"],
        "additionalProperties": False,
    }
    expected = {
        "name": "function_name",
        "description": "function_description",
        "parameters": expected_parameters,
        "strict": True,
    }

    assert generator._strictify_function_schema(tool_spec) == expected

def test_strictify_function_schema_chat_message(self):
    """
    Strictify a realistic schema that references ChatMessage and its related definitions through "$defs".

    The complete expected output is not asserted yet (see the TODO notes below). Until then the test pins the
    parts of the output that are stable, so that it is able to fail.
    """
    component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
    example_schema = {
        "$defs": {
            "ChatMessage": CHAT_MESSAGE_SCHEMA,
            "ChatRole": CHAT_ROLE_SCHEMA,
            "TextContent": TEXT_CONTENT_SCHEMA,
            "ToolCall": TOOL_CALL_SCHEMA,
            "ToolCallResult": TOOL_CALL_RESULT_SCHEMA,
        },
        "description": "A test function",
        "properties": {
            "input_name": {
                "description": "A list of chat messages",
                "items": {"$ref": "#/$defs/ChatMessage"},
                "type": "array",
            }
        },
        "required": ["input_name"],
        "type": "object",
    }
    strict_function_spec = component._strictify_function_schema(example_schema)
    expected_spec = {
        "$defs": {
            "ChatMessage": {
                "type": "object",
                "properties": {
                    "role": {"$ref": "#/$defs/ChatRole", "description": "Field 'role' of 'ChatMessage'."},
                    "content": {
                        "type": "array",
                        "description": "Field 'content' of 'ChatMessage'.",
                        "items": {
                            "anyOf": [
                                {"$ref": "#/$defs/TextContent"}
                                # {"$ref": "#/$defs/ToolCall"},
                                # {"$ref": "#/$defs/ToolCallResult"},
                            ]
                        },
                    },
                },
                "required": ["role", "content"],
            },
            "ChatRole": CHAT_ROLE_SCHEMA,
            "TextContent": TEXT_CONTENT_SCHEMA,
            # TODO `arguments` is the problematic parameter that will be auto-removed but it's also in required
            #   This means ToolCall itself should be removed (if possible). It can be from ChatMessage
            #   since it's contained within an anyOf list.
            #   However, this also affects TOOL_CALL_RESULT_SCHEMA which has ToolCall as a requirement
            #   This means ToolCallResult should also be removed if possible.
            #   Then if not possible to remove all the way up then an error should be thrown.
            # "ToolCall": {
            #     "type": "object",
            #     "properties": {
            #         "tool_name": {"type": "string", "description": "The name of the Tool to call."},
            #         "arguments": {
            #             "type": "object",
            #             "description": "The arguments to call the Tool with.",
            #             "additionalProperties": True,
            #         },
            #         "id": {
            #             "anyOf": [{"type": "string"}, {"type": "null"}],
            #             "default": None,
            #             "description": "The ID of the Tool call.",
            #         },
            #     },
            #     "required": ["tool_name", "arguments"],
            # },
            # "ToolCallResult": {
            #     "type": "object",
            #     "properties": {
            #         "result": {"type": "string", "description": "The result of the Tool invocation."},
            #         "origin": {
            #             "$ref": "#/$defs/ToolCall",
            #             "description": "The Tool call that produced this result.",
            #         },
            #         "error": {
            #             "type": "boolean",
            #             "description": "Whether the Tool invocation resulted in an error.",
            #         },
            #     },
            #     "required": ["result", "origin", "error"],
            # },
        },
        "description": "A test function",
        "properties": {
            "input_name": {
                "description": "A list of chat messages",
                "items": {"$ref": "#/$defs/ChatMessage"},
                "type": "array",
            }
        },
        "required": ["input_name"],
        "type": "object",
    }
    # TODO Compare the complete output once the "$defs" clean-up described above is implemented:
    # assert strict_function_spec == expected_spec

    # Until then, check the top-level keywords that do not depend on how "$defs" is cleaned up.
    assert strict_function_spec["strict"] is True
    for keyword in ("description", "properties", "required", "type"):
        assert strict_function_spec[keyword] == expected_spec[keyword]
    # The schema handed in by the caller must not be modified.
    assert "strict" not in example_schema

@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
Expand Down Expand Up @@ -1044,3 +1176,25 @@ def test_live_run_with_toolset(self, tools):
assert tool_call.tool_name == "weather"
assert tool_call.arguments == {"city": "Paris"}
assert message.meta["finish_reason"] == "tool_calls"

# TODO Re-enable once unit tests are working
# @pytest.mark.skipif(
# not os.environ.get("OPENAI_API_KEY", None),
# reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
# )
# @pytest.mark.integration
# def test_live_run_with_tools_strict(self, tools):
# chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
# component = OpenAIChatGenerator(tools=tools, tools_strict=True)
# results = component.run(chat_messages)
# assert len(results["replies"]) == 1
# message = results["replies"][0]
#
# assert not message.texts
# assert not message.text
# assert message.tool_calls
# tool_call = message.tool_call
# assert isinstance(tool_call, ToolCall)
# assert tool_call.tool_name == "weather"
# assert tool_call.arguments == {"city": "Paris"}
# assert message.meta["finish_reason"] == "tool_calls"
Loading