Skip to content

Commit 5c67f39

Browse files
Enhancement/gpt4o image gen support (onyx-dot-app#4859)
* Initial model switching changes
* Update image generation output format and revise prompt handling
* Add validation for output format in ImageGenerationTool and implement tests
---------
Co-authored-by: Subash <subash@onyx.app>
1 parent 96aaf78 commit 5c67f39

File tree

6 files changed

+169
-10
lines changed

6 files changed

+169
-10
lines changed

backend/onyx/configs/tool_configs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
import os
33

44

5-
IMAGE_GENERATION_OUTPUT_FORMAT = os.environ.get("IMAGE_GENERATION_OUTPUT_FORMAT", "url")
5+
IMAGE_GENERATION_OUTPUT_FORMAT = os.environ.get(
6+
"IMAGE_GENERATION_OUTPUT_FORMAT", "b64_json"
7+
)
68

79
# if specified, will pass through request headers to the call to API calls made by custom tools
810
CUSTOM_TOOL_PASS_THROUGH_HEADERS: list[str] | None = None

backend/onyx/tools/built_in_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class InCodeToolInfo(TypedDict):
3939
InCodeToolInfo(
4040
cls=ImageGenerationTool,
4141
description=(
42-
"The Image Generation Action allows the assistant to use DALL-E 3 to generate images. "
42+
"The Image Generation Action allows the assistant to use DALL-E 3 or GPT-IMAGE-1 to generate images. "
4343
"The action will be used when the user asks the assistant to generate an image."
4444
),
4545
in_code_tool_id=ImageGenerationTool.__name__,

backend/onyx/tools/tool_constructor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
5353
if llm and llm.config.api_key and llm.config.model_provider == "openai":
5454
return LLMConfig(
5555
model_provider=llm.config.model_provider,
56-
model_name="dall-e-3",
56+
model_name="gpt-image-1",
5757
temperature=GEN_AI_TEMPERATURE,
5858
api_key=llm.config.api_key,
5959
api_base=llm.config.api_base,
@@ -90,7 +90,7 @@ def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
9090

9191
return LLMConfig(
9292
model_provider=openai_provider.provider,
93-
model_name="dall-e-3",
93+
model_name="gpt-image-1",
9494
temperature=GEN_AI_TEMPERATURE,
9595
api_key=openai_provider.api_key,
9696
api_base=openai_provider.api_base,

backend/onyx/tools/tool_implementations/images/image_generation_tool.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,17 @@ def __init__(
9090
api_key: str,
9191
api_base: str | None,
9292
api_version: str | None,
93-
model: str = "dall-e-3",
93+
model: str = "gpt-image-1",
9494
num_imgs: int = 2,
9595
additional_headers: dict[str, str] | None = None,
9696
output_format: ImageFormat = _DEFAULT_OUTPUT_FORMAT,
9797
) -> None:
98+
99+
if model == "gpt-image-1" and output_format == ImageFormat.URL:
100+
raise ValueError(
101+
"gpt-image-1 does not support URL format. Please use BASE64 format."
102+
)
103+
98104
self.api_key = api_key
99105
self.api_base = api_base
100106
self.api_version = api_version
@@ -198,12 +204,20 @@ def _generate_image(
198204
self, prompt: str, shape: ImageShape, format: ImageFormat
199205
) -> ImageGenerationResponse:
200206
if shape == ImageShape.LANDSCAPE:
201-
size = "1792x1024"
207+
if self.model == "gpt-image-1":
208+
size = "1536x1024"
209+
else:
210+
size = "1792x1024"
202211
elif shape == ImageShape.PORTRAIT:
203-
size = "1024x1792"
212+
if self.model == "gpt-image-1":
213+
size = "1024x1536"
214+
else:
215+
size = "1024x1792"
204216
else:
205217
size = "1024x1024"
206-
218+
logger.debug(
219+
f"Generating image with model: {self.model}, size: {size}, format: {format}"
220+
)
207221
try:
208222
response = image_generation(
209223
prompt=prompt,
@@ -224,8 +238,12 @@ def _generate_image(
224238
url = None
225239
image_data = response.data[0]["b64_json"]
226240

241+
revised_prompt = response.data[0].get("revised_prompt")
242+
if revised_prompt is None:
243+
revised_prompt = prompt
244+
227245
return ImageGenerationResponse(
228-
revised_prompt=response.data[0]["revised_prompt"],
246+
revised_prompt=revised_prompt,
229247
url=url,
230248
image_data=image_data,
231249
)

backend/requirements/default.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ langchainhub==0.1.21
3939
langgraph==0.2.72
4040
langgraph-checkpoint==2.0.13
4141
langgraph-sdk==0.1.44
42-
litellm==1.69.0
42+
litellm==1.72.2
4343
lxml==5.3.0
4444
lxml_html_clean==0.2.2
4545
llama-index==0.12.28
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import os
2+
3+
import pytest
4+
5+
from onyx.tools.tool_implementations.images.image_generation_tool import (
6+
IMAGE_GENERATION_RESPONSE_ID,
7+
)
8+
from onyx.tools.tool_implementations.images.image_generation_tool import ImageFormat
9+
from onyx.tools.tool_implementations.images.image_generation_tool import (
10+
ImageGenerationResponse,
11+
)
12+
from onyx.tools.tool_implementations.images.image_generation_tool import (
13+
ImageGenerationTool,
14+
)
15+
16+
17+
@pytest.fixture
def dalle3_tool() -> ImageGenerationTool:
    """Build a ``dall-e-3`` ImageGenerationTool configured for URL output.

    The tool is created with ``num_imgs=1``. The requesting test is skipped
    when ``OPENAI_API_KEY`` is missing or empty in the environment.
    """
    openai_key = os.environ.get("OPENAI_API_KEY")
    if not openai_key:
        pytest.skip("OPENAI_API_KEY environment variable not set")

    return ImageGenerationTool(
        model="dall-e-3",
        output_format=ImageFormat.URL,
        num_imgs=1,
        api_key=openai_key,
        api_base=None,
        api_version=None,
    )
31+
32+
33+
@pytest.fixture
def gpt_image_tool() -> ImageGenerationTool:
    """Build a ``gpt-image-1`` ImageGenerationTool configured for base64 output.

    The tool is created with ``num_imgs=1``. The requesting test is skipped
    when ``OPENAI_API_KEY`` is missing or empty in the environment.
    """
    openai_key = os.environ.get("OPENAI_API_KEY")
    if not openai_key:
        pytest.skip("OPENAI_API_KEY environment variable not set")

    return ImageGenerationTool(
        model="gpt-image-1",
        output_format=ImageFormat.BASE64,
        num_imgs=1,
        api_key=openai_key,
        api_base=None,
        api_version=None,
    )
47+
48+
49+
def test_dalle3_generates_image_url_successfully(
    dalle3_tool: ImageGenerationTool,
) -> None:
    """Running the URL-mode dall-e-3 tool yields one response with one HTTPS image URL."""
    responses = list(dalle3_tool.run(prompt="A simple red circle"))

    # The run is expected to produce exactly one tool response.
    assert len(responses) == 1
    packet = responses[0]

    # The packet carries the image-generation id and a one-element list payload.
    assert packet.id == IMAGE_GENERATION_RESPONSE_ID
    assert isinstance(packet.response, list)
    assert len(packet.response) == 1

    # URL mode: a non-empty revised prompt, an https URL, and no inline image data.
    generated = packet.response[0]
    assert isinstance(generated, ImageGenerationResponse)
    assert generated.revised_prompt is not None
    assert len(generated.revised_prompt) > 0
    assert generated.url is not None
    assert generated.url.startswith("https://")
    assert generated.image_data is None
72+
73+
74+
def test_dalle3_with_base64_format() -> None:
    """A dall-e-3 tool built with BASE64 output returns inline image data and no URL."""
    openai_key = os.environ.get("OPENAI_API_KEY")
    if not openai_key:
        pytest.skip("OPENAI_API_KEY environment variable not set")

    # Built inline rather than via a fixture: dall-e-3 with base64 output.
    b64_tool = ImageGenerationTool(
        model="dall-e-3",
        output_format=ImageFormat.BASE64,
        num_imgs=1,
        api_key=openai_key,
        api_base=None,
        api_version=None,
    )

    responses = list(b64_tool.run(prompt="A simple blue square", shape="square"))

    # One response whose first image has data but no URL.
    assert len(responses) == 1
    generated = responses[0].response[0]
    assert generated.url is None
    assert generated.image_data is not None
    assert len(generated.image_data) > 0
98+
99+
100+
def test_gpt_image_1_generates_base64_successfully(
    gpt_image_tool: ImageGenerationTool,
) -> None:
    """Running the base64-mode gpt-image-1 tool yields one response with inline image data."""
    responses = list(gpt_image_tool.run(prompt="A simple red circle"))

    # The run is expected to produce exactly one tool response.
    assert len(responses) == 1
    packet = responses[0]

    # The packet carries the image-generation id and a one-element list payload.
    assert packet.id == IMAGE_GENERATION_RESPONSE_ID
    assert isinstance(packet.response, list)
    assert len(packet.response) == 1

    # Base64 mode: a non-empty revised prompt, no URL, and non-empty image data.
    generated = packet.response[0]
    assert isinstance(generated, ImageGenerationResponse)
    assert generated.revised_prompt is not None
    assert len(generated.revised_prompt) > 0
    assert generated.url is None
    assert generated.image_data is not None
    assert len(generated.image_data) > 0
123+
124+
125+
def test_gpt_image_1_with_url_format_fails() -> None:
    """Constructing a gpt-image-1 tool with URL output must raise ``ValueError``.

    The model/format check is the first statement of
    ``ImageGenerationTool.__init__`` and raises before the API key is stored
    or used. A placeholder key is therefore sufficient, and the test runs
    even when ``OPENAI_API_KEY`` is not set instead of being skipped.
    """
    with pytest.raises(ValueError, match="gpt-image-1 does not support URL format"):
        ImageGenerationTool(
            # Never used: construction fails before the key is read.
            api_key="placeholder-api-key",
            api_base=None,
            api_version=None,
            model="gpt-image-1",
            output_format=ImageFormat.URL,
            num_imgs=1,
        )

0 commit comments

Comments
 (0)