From dbc908e00897cb8fcc6d930c692ef3efc3fd3192 Mon Sep 17 00:00:00 2001 From: Subash Date: Sat, 7 Jun 2025 10:10:42 +0530 Subject: [PATCH 1/3] initial model switching changes --- backend/onyx/tools/built_in_tools.py | 2 +- backend/onyx/tools/tool_constructor.py | 4 ++-- .../images/image_generation_tool.py | 12 +++++++++--- backend/requirements/default.txt | 2 +- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/backend/onyx/tools/built_in_tools.py b/backend/onyx/tools/built_in_tools.py index 5b2f8eab06d..40fa0691e11 100644 --- a/backend/onyx/tools/built_in_tools.py +++ b/backend/onyx/tools/built_in_tools.py @@ -39,7 +39,7 @@ class InCodeToolInfo(TypedDict): InCodeToolInfo( cls=ImageGenerationTool, description=( - "The Image Generation Action allows the assistant to use DALL-E 3 to generate images. " + "The Image Generation Action allows the assistant to use DALL-E 3 or GPT-IMAGE-1 to generate images. " "The action will be used when the user asks the assistant to generate an image." ), in_code_tool_id=ImageGenerationTool.__name__, diff --git a/backend/onyx/tools/tool_constructor.py b/backend/onyx/tools/tool_constructor.py index f9ef0055396..ba4b1b3b196 100644 --- a/backend/onyx/tools/tool_constructor.py +++ b/backend/onyx/tools/tool_constructor.py @@ -53,7 +53,7 @@ def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig: if llm and llm.config.api_key and llm.config.model_provider == "openai": return LLMConfig( model_provider=llm.config.model_provider, - model_name="dall-e-3", + model_name="gpt-image-1", temperature=GEN_AI_TEMPERATURE, api_key=llm.config.api_key, api_base=llm.config.api_base, @@ -90,7 +90,7 @@ def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig: return LLMConfig( model_provider=openai_provider.provider, - model_name="dall-e-3", + model_name="gpt-image-1", temperature=GEN_AI_TEMPERATURE, api_key=openai_provider.api_key, api_base=openai_provider.api_base, diff --git a/backend/onyx/tools/tool_implementations/images/image_generation_tool.py b/backend/onyx/tools/tool_implementations/images/image_generation_tool.py index 3185b4a001d..0e31fe4f7f4 100644 --- a/backend/onyx/tools/tool_implementations/images/image_generation_tool.py +++ b/backend/onyx/tools/tool_implementations/images/image_generation_tool.py @@ -90,7 +90,7 @@ def __init__( api_key: str, api_base: str | None, api_version: str | None, - model: str = "dall-e-3", + model: str = "gpt-image-1", num_imgs: int = 2, additional_headers: dict[str, str] | None = None, output_format: ImageFormat = _DEFAULT_OUTPUT_FORMAT, @@ -198,9 +198,15 @@ def _generate_image( self, prompt: str, shape: ImageShape, format: ImageFormat ) -> ImageGenerationResponse: if shape == ImageShape.LANDSCAPE: - size = "1792x1024" + if self.model == "gpt-image-1": + size = "1536x1024" + else: + size = "1792x1024" elif shape == ImageShape.PORTRAIT: - size = "1024x1792" + if self.model == "gpt-image-1": + size = "1024x1536" + else: + size = "1024x1792" else: size = "1024x1024" diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index 870c1da56ab..657318f9dee 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -39,7 +39,7 @@ langchainhub==0.1.21 langgraph==0.2.72 langgraph-checkpoint==2.0.13 langgraph-sdk==0.1.44 -litellm==1.69.0 +litellm==1.72.1 lxml==5.3.0 lxml_html_clean==0.2.2 llama-index==0.12.28 From 9d994a16a7e8e9a9a30987135c8b27c28906218a Mon Sep 17 00:00:00 2001 From: Subash Date: Tue, 10 Jun 2025 07:17:43 +0530 Subject: [PATCH 2/3] Update image generation output format and revise prompt handling --- backend/onyx/configs/tool_configs.py | 4 +++- .../tool_implementations/images/image_generation_tool.py | 6 +++++- backend/requirements/default.txt | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/backend/onyx/configs/tool_configs.py b/backend/onyx/configs/tool_configs.py index 955d3af366e..78da1fc201e 100644 --- a/backend/onyx/configs/tool_configs.py +++ b/backend/onyx/configs/tool_configs.py @@ -2,7 +2,9 @@ import os -IMAGE_GENERATION_OUTPUT_FORMAT = os.environ.get("IMAGE_GENERATION_OUTPUT_FORMAT", "url") +IMAGE_GENERATION_OUTPUT_FORMAT = os.environ.get( + "IMAGE_GENERATION_OUTPUT_FORMAT", "b64_json" +) # if specified, will pass through request headers to the call to API calls made by custom tools CUSTOM_TOOL_PASS_THROUGH_HEADERS: list[str] | None = None diff --git a/backend/onyx/tools/tool_implementations/images/image_generation_tool.py b/backend/onyx/tools/tool_implementations/images/image_generation_tool.py index 0e31fe4f7f4..48642270e55 100644 --- a/backend/onyx/tools/tool_implementations/images/image_generation_tool.py +++ b/backend/onyx/tools/tool_implementations/images/image_generation_tool.py @@ -230,8 +230,12 @@ def _generate_image( url = None image_data = response.data[0]["b64_json"] + revised_prompt = response.data[0].get("revised_prompt") + if revised_prompt is None: + revised_prompt = prompt + return ImageGenerationResponse( - revised_prompt=response.data[0]["revised_prompt"], + revised_prompt=revised_prompt, url=url, image_data=image_data, ) diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index 657318f9dee..930a63dd3a5 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -39,7 +39,7 @@ langchainhub==0.1.21 langgraph==0.2.72 langgraph-checkpoint==2.0.13 langgraph-sdk==0.1.44 -litellm==1.72.1 +litellm==1.72.2 lxml==5.3.0 lxml_html_clean==0.2.2 llama-index==0.12.28 From 79508fa3cb1956b51f4af1e1c03dd84b73c746ec Mon Sep 17 00:00:00 2001 From: Subash Date: Tue, 10 Jun 2025 09:02:56 +0530 Subject: [PATCH 3/3] Add validation for output format in ImageGenerationTool and implement tests --- .../images/image_generation_tool.py | 10 +- .../tests/tools/test_image_generation_tool.py | 139 ++++++++++++++++++ 2 files changed, 148 insertions(+), 1 deletion(-) create mode 100644 backend/tests/integration/tests/tools/test_image_generation_tool.py diff --git a/backend/onyx/tools/tool_implementations/images/image_generation_tool.py b/backend/onyx/tools/tool_implementations/images/image_generation_tool.py index 48642270e55..89cf431e384 100644 --- a/backend/onyx/tools/tool_implementations/images/image_generation_tool.py +++ b/backend/onyx/tools/tool_implementations/images/image_generation_tool.py @@ -95,6 +95,12 @@ def __init__( additional_headers: dict[str, str] | None = None, output_format: ImageFormat = _DEFAULT_OUTPUT_FORMAT, ) -> None: + + if model == "gpt-image-1" and output_format == ImageFormat.URL: + raise ValueError( + "gpt-image-1 does not support URL format. Please use BASE64 format." + ) + self.api_key = api_key self.api_base = api_base self.api_version = api_version @@ -209,7 +215,9 @@ def _generate_image( size = "1024x1792" else: size = "1024x1024" - + logger.debug( + f"Generating image with model: {self.model}, size: {size}, format: {format}" + ) try: response = image_generation( prompt=prompt, diff --git a/backend/tests/integration/tests/tools/test_image_generation_tool.py b/backend/tests/integration/tests/tools/test_image_generation_tool.py new file mode 100644 index 00000000000..1ad46211e23 --- /dev/null +++ b/backend/tests/integration/tests/tools/test_image_generation_tool.py @@ -0,0 +1,139 @@ +import os + +import pytest + +from onyx.tools.tool_implementations.images.image_generation_tool import ( + IMAGE_GENERATION_RESPONSE_ID, +) +from onyx.tools.tool_implementations.images.image_generation_tool import ImageFormat +from onyx.tools.tool_implementations.images.image_generation_tool import ( + ImageGenerationResponse, +) +from onyx.tools.tool_implementations.images.image_generation_tool import ( + ImageGenerationTool, +) + + +@pytest.fixture +def dalle3_tool() -> ImageGenerationTool: + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + pytest.skip("OPENAI_API_KEY environment variable not set") + + return ImageGenerationTool( + api_key=api_key, + api_base=None, + api_version=None, + model="dall-e-3", + num_imgs=1, + output_format=ImageFormat.URL, + ) + + +@pytest.fixture +def gpt_image_tool() -> ImageGenerationTool: + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + pytest.skip("OPENAI_API_KEY environment variable not set") + + return ImageGenerationTool( + api_key=api_key, + api_base=None, + api_version=None, + model="gpt-image-1", + output_format=ImageFormat.BASE64, + num_imgs=1, + ) + + +def test_dalle3_generates_image_url_successfully( + dalle3_tool: ImageGenerationTool, +) -> None: + # Run the tool with a simple prompt + results = list(dalle3_tool.run(prompt="A simple red circle")) + + # Verify we get a response + assert len(results) == 1 + tool_response = results[0] + + # Check response structure + assert tool_response.id == IMAGE_GENERATION_RESPONSE_ID + assert isinstance(tool_response.response, list) + assert len(tool_response.response) == 1 + + # Check ImageGenerationResponse content + image_response = tool_response.response[0] + assert isinstance(image_response, ImageGenerationResponse) + assert image_response.revised_prompt is not None + assert len(image_response.revised_prompt) > 0 + assert image_response.url is not None + assert image_response.url.startswith("https://") + assert image_response.image_data is None + + +def test_dalle3_with_base64_format() -> None: + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + pytest.skip("OPENAI_API_KEY environment variable not set") + + # Create tool with base64 format + tool = ImageGenerationTool( + api_key=api_key, + api_base=None, + api_version=None, + model="dall-e-3", + output_format=ImageFormat.BASE64, + num_imgs=1, + ) + + # Run the tool + results = list(tool.run(prompt="A simple blue square", shape="square")) + + # Verify response + assert len(results) == 1 + image_response = results[0].response[0] + assert image_response.url is None + assert image_response.image_data is not None + assert len(image_response.image_data) > 0 + + +def test_gpt_image_1_generates_base64_successfully( + gpt_image_tool: ImageGenerationTool, +) -> None: + # Run the tool with a simple prompt + results = list(gpt_image_tool.run(prompt="A simple red circle")) + + # Verify we get a response + assert len(results) == 1 + tool_response = results[0] + + # Check response structure + assert tool_response.id == IMAGE_GENERATION_RESPONSE_ID + assert isinstance(tool_response.response, list) + assert len(tool_response.response) == 1 + + # Check ImageGenerationResponse content + image_response = tool_response.response[0] + assert isinstance(image_response, ImageGenerationResponse) + assert image_response.revised_prompt is not None + assert len(image_response.revised_prompt) > 0 + assert image_response.url is None + assert image_response.image_data is not None + assert len(image_response.image_data) > 0 + + +def test_gpt_image_1_with_url_format_fails() -> None: + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + pytest.skip("OPENAI_API_KEY environment variable not set") + + # This should fail during tool creation since gpt-image-1 doesn't support URL format + with pytest.raises(ValueError, match="gpt-image-1 does not support URL format"): + ImageGenerationTool( + api_key=api_key, + api_base=None, + api_version=None, + model="gpt-image-1", + output_format=ImageFormat.URL, + num_imgs=1, + )