diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py index 8f0d99cad4..c31bfeeee1 100644 --- a/haystack/components/generators/chat/openai.py +++ b/haystack/components/generators/chat/openai.py @@ -4,6 +4,7 @@ import json import os +from copy import deepcopy from datetime import datetime from typing import Any, Dict, List, Optional, Union @@ -371,6 +372,74 @@ async def run_async( return {"replies": completions} + @staticmethod + def _is_type_object(obj: Any) -> bool: + """ + Check if the object is of type 'object' in OpenAI's schema. + + :param obj: The object to check. + :returns: True if the object is of type 'object', False otherwise. + """ + return isinstance(obj, dict) and "type" in obj and obj["type"] == "object" + + def _strictify_object(self, obj: Any) -> Any: + """ + Recursively updates the sub-objects of the tool specification to follow OpenAI's strict schema. + + This function: + - Sets "additionalProperties" to False in all type = object sections of the tool specification, which is a + requirement for OpenAI's strict schema. + - Removes all non-required fields since all property fields must be required. For ease, we opt to remove all + variables that are not required. + """ + if isinstance(obj, dict): + for key, value in list(obj.items()): + # type = object updates + if self._is_type_object(value): + if "required" not in value: + # If type = object and doesn't have required variables it needs to be removed + del obj[key] + continue + + # If type = object and has required variables, we need to remove all non-required variables + # from the properties + obj[key]["properties"] = { + k: self._strictify_object(v) for k, v in value["properties"].items() if k in value["required"] + } + # Always add and set additionalProperties to False for type = object + obj[key]["additionalProperties"] = False + continue + + obj[key] = self._strictify_object(value) + return obj + + if isinstance(obj, list): + new_items = [] + for item in obj: + # If type = object and doesn't have required variables it needs to be removed + if self._is_type_object(item) and "required" not in item: + continue + new_items.append(self._strictify_object(item)) + return new_items + + return obj + + def _strictify_function_schema(self, function_schema: Dict[str, Any]) -> Dict[str, Any]: + """ + Updates the tool specification object to follow OpenAI's strict schema. + + OpenAI's strict schema is equivalent to their Structured Output schema. + More information on Structured Output can be found + (here)[https://platform.openai.com/docs/guides/structured-outputs/supported-schemas?api-mode=responses]. + + The supported schemas for Structured Outputs can be found + (here)[https://platform.openai.com/docs/guides/structured-outputs/supported-schemas?api-mode=responses#supported-schemas] + """ + # TODO Ideally _strictify would also "repair" any function schema. e.g. a required variable had to be + # removed b/c it had schema {"type": "object", "additionalProperties": True} which is not allowed + # Look at test_strictify_function_schema_chat_message for a real example that is quite messy. + return {**self._strictify_object(deepcopy(function_schema)), **{"strict": True}} + def _prepare_api_call( # noqa: PLR0913 self, *, @@ -398,8 +467,7 @@ def _prepare_api_call( # noqa: PLR0913 for t in tools: function_spec = {**t.tool_spec} if tools_strict: - function_spec["strict"] = True - function_spec["parameters"]["additionalProperties"] = False + function_spec = self._strictify_function_schema(function_spec) tool_definitions.append({"type": "function", "function": function_spec}) openai_tools = {"tools": tool_definitions} diff --git a/releasenotes/notes/better-support-tools-strict-openai-580fc09557785599.yaml b/releasenotes/notes/better-support-tools-strict-openai-580fc09557785599.yaml new file mode 100644 index 0000000000..f6c8539bb8 --- /dev/null +++ b/releasenotes/notes/better-support-tools-strict-openai-580fc09557785599.yaml @@ -0,0 +1,6 @@ +--- +enhancements: + - | + Added _strictify function to enforce OpenAI's strict schema for tool specifications. + This update affects the OpenAIChatGenerator component when tools_strict=True is enabled. + The function ensures compatibility by setting additionalProperties to False, converting oneOf to anyOf, and removing all non-required fields as required by OpenAI's Structured Output schema. diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py index d393e33a1e..b7e6b83a69 100644 --- a/test/components/generators/chat/test_openai.py +++ b/test/components/generators/chat/test_openai.py @@ -26,6 +26,14 @@ from haystack.components.generators.chat.openai import OpenAIChatGenerator from haystack.tools.toolset import Toolset +from test.tools.test_parameters_schema_utils import ( + CHAT_MESSAGE_SCHEMA, + CHAT_ROLE_SCHEMA, + TEXT_CONTENT_SCHEMA, + TOOL_CALL_SCHEMA, + TOOL_CALL_RESULT_SCHEMA, +) + @pytest.fixture def chat_messages(): @@ -908,6 +916,130 @@ def test_convert_usage_chunk_to_streaming_chunk(self): assert result.meta["model"] == "gpt-4o-mini-2024-07-18" assert result.meta["received_at"] is not None + def test_strictify_function_schema(self): + component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key")) + function_spec = { + "name": "function_name", + "description": "function_description", + "parameters": { + "type": "object", + "properties": { + "param1": {"type": "string"}, + "param2": {"type": "integer"}, + "param3": {"type": "array", "items": {"type": "string"}}, + }, + "required": ["param1", "param2"], + }, + } + strict_function_spec = component._strictify_function_schema(function_spec) + assert strict_function_spec == { + "name": "function_name", + "description": "function_description", + "parameters": { + "type": "object", + "properties": {"param1": {"type": "string"}, "param2": {"type": "integer"}}, + "required": ["param1", "param2"], + "additionalProperties": False, + }, + "strict": True, + } + + def test_strictify_function_schema_chat_message(self): + component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key")) + example_schema = { + "$defs": { + "ChatMessage": CHAT_MESSAGE_SCHEMA, + "ChatRole": CHAT_ROLE_SCHEMA, + "TextContent": TEXT_CONTENT_SCHEMA, + "ToolCall": TOOL_CALL_SCHEMA, + "ToolCallResult": TOOL_CALL_RESULT_SCHEMA, + }, + "description": "A test function", + "properties": { + "input_name": { + "description": "A list of chat messages", + "items": {"$ref": "#/$defs/ChatMessage"}, + "type": "array", + } + }, + "required": ["input_name"], + "type": "object", + } + strict_function_spec = component._strictify_function_schema(example_schema) + expected_spec = { + "$defs": { + "ChatMessage": { + "type": "object", + "properties": { + "role": {"$ref": "#/$defs/ChatRole", "description": "Field 'role' of 'ChatMessage'."}, + "content": { + "type": "array", + "description": "Field 'content' of 'ChatMessage'.", + "items": { + "anyOf": [ + {"$ref": "#/$defs/TextContent"} + # {"$ref": "#/$defs/ToolCall"}, + # {"$ref": "#/$defs/ToolCallResult"}, + ] + }, + }, + }, + "required": ["role", "content"], + }, + "ChatRole": CHAT_ROLE_SCHEMA, + "TextContent": TEXT_CONTENT_SCHEMA, + # TODO `arguments` is the problematic parameter that will be auto-removed but it's also in required + # This means ToolCall itself should be removed (if possible). It can be from ChatMessage + # since it's contained within an anyOf list. + # However, this also affects TOOL_CALL_RESULT_SCHEMA which has ToolCall as a requirement + # This means ToolCallResult should also be removed if possible. + # Then if not possible to remove all the way up then an error should be thrown. + # "ToolCall": { + # "type": "object", + # "properties": { + # "tool_name": {"type": "string", "description": "The name of the Tool to call."}, + # "arguments": { + # "type": "object", + # "description": "The arguments to call the Tool with.", + # "additionalProperties": True, + # }, + # "id": { + # "anyOf": [{"type": "string"}, {"type": "null"}], + # "default": None, + # "description": "The ID of the Tool call.", + # }, + # }, + # "required": ["tool_name", "arguments"], + # }, + # "ToolCallResult": { + # "type": "object", + # "properties": { + # "result": {"type": "string", "description": "The result of the Tool invocation."}, + # "origin": { + # "$ref": "#/$defs/ToolCall", + # "description": "The Tool call that produced this result.", + # }, + # "error": { + # "type": "boolean", + # "description": "Whether the Tool invocation resulted in an error.", + # }, + # }, + # "required": ["result", "origin", "error"], + # }, + }, + "description": "A test function", + "properties": { + "input_name": { + "description": "A list of chat messages", + "items": {"$ref": "#/$defs/ChatMessage"}, + "type": "array", + } + }, + "required": ["input_name"], + "type": "object", + } + # assert strict_function_spec == expected_spec + @pytest.mark.skipif( not os.environ.get("OPENAI_API_KEY", None), reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", @@ -1044,3 +1176,25 @@ def test_live_run_with_toolset(self, tools): assert tool_call.tool_name == "weather" assert tool_call.arguments == {"city": "Paris"} assert message.meta["finish_reason"] == "tool_calls" + + # TODO Re-enable once unit tests are working + # @pytest.mark.skipif( + # not os.environ.get("OPENAI_API_KEY", None), + # reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", + # ) + # @pytest.mark.integration + # def test_live_run_with_tools_strict(self, tools): + # chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")] + # component = OpenAIChatGenerator(tools=tools, tools_strict=True) + # results = component.run(chat_messages) + # assert len(results["replies"]) == 1 + # message = results["replies"][0] + # + # assert not message.texts + # assert not message.text + # assert message.tool_calls + # tool_call = message.tool_call + # assert isinstance(tool_call, ToolCall) + # assert tool_call.tool_name == "weather" + # assert tool_call.arguments == {"city": "Paris"} + # assert message.meta["finish_reason"] == "tool_calls"