Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion backend/onyx/configs/tool_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import os


IMAGE_GENERATION_OUTPUT_FORMAT = os.environ.get("IMAGE_GENERATION_OUTPUT_FORMAT", "url")
IMAGE_GENERATION_OUTPUT_FORMAT = os.environ.get(
"IMAGE_GENERATION_OUTPUT_FORMAT", "b64_json"
)

# if specified, will pass through request headers to the call to API calls made by custom tools
CUSTOM_TOOL_PASS_THROUGH_HEADERS: list[str] | None = None
Expand Down
2 changes: 1 addition & 1 deletion backend/onyx/tools/built_in_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class InCodeToolInfo(TypedDict):
InCodeToolInfo(
cls=ImageGenerationTool,
description=(
"The Image Generation Action allows the assistant to use DALL-E 3 to generate images. "
"The Image Generation Action allows the assistant to use DALL-E 3 or GPT-IMAGE-1 to generate images. "
"The action will be used when the user asks the assistant to generate an image."
),
in_code_tool_id=ImageGenerationTool.__name__,
Expand Down
4 changes: 2 additions & 2 deletions backend/onyx/tools/tool_constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
if llm and llm.config.api_key and llm.config.model_provider == "openai":
return LLMConfig(
model_provider=llm.config.model_provider,
model_name="dall-e-3",
model_name="gpt-image-1",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Consider adding a constant for 'gpt-image-1' to avoid magic strings, similar to how AZURE_DALLE_DEPLOYMENT_NAME is used

temperature=GEN_AI_TEMPERATURE,
api_key=llm.config.api_key,
api_base=llm.config.api_base,
Expand Down Expand Up @@ -90,7 +90,7 @@ def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:

return LLMConfig(
model_provider=openai_provider.provider,
model_name="dall-e-3",
model_name="gpt-image-1",
temperature=GEN_AI_TEMPERATURE,
api_key=openai_provider.api_key,
api_base=openai_provider.api_base,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,17 @@ def __init__(
api_key: str,
api_base: str | None,
api_version: str | None,
model: str = "dall-e-3",
model: str = "gpt-image-1",
num_imgs: int = 2,
additional_headers: dict[str, str] | None = None,
output_format: ImageFormat = _DEFAULT_OUTPUT_FORMAT,
) -> None:

if model == "gpt-image-1" and output_format == ImageFormat.URL:
raise ValueError(
"gpt-image-1 does not support URL format. Please use BASE64 format."
)

self.api_key = api_key
self.api_base = api_base
self.api_version = api_version
Expand Down Expand Up @@ -198,12 +204,20 @@ def _generate_image(
self, prompt: str, shape: ImageShape, format: ImageFormat
) -> ImageGenerationResponse:
if shape == ImageShape.LANDSCAPE:
size = "1792x1024"
if self.model == "gpt-image-1":
size = "1536x1024"
else:
size = "1792x1024"
elif shape == ImageShape.PORTRAIT:
size = "1024x1792"
if self.model == "gpt-image-1":
size = "1024x1536"
else:
size = "1024x1792"
else:
size = "1024x1024"

logger.debug(
f"Generating image with model: {self.model}, size: {size}, format: {format}"
)
try:
response = image_generation(
prompt=prompt,
Expand All @@ -224,8 +238,12 @@ def _generate_image(
url = None
image_data = response.data[0]["b64_json"]

revised_prompt = response.data[0].get("revised_prompt")
if revised_prompt is None:
revised_prompt = prompt

return ImageGenerationResponse(
revised_prompt=response.data[0]["revised_prompt"],
revised_prompt=revised_prompt,
url=url,
image_data=image_data,
)
Expand Down
2 changes: 1 addition & 1 deletion backend/requirements/default.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ langchainhub==0.1.21
langgraph==0.2.72
langgraph-checkpoint==2.0.13
langgraph-sdk==0.1.44
litellm==1.69.0
litellm==1.72.2
lxml==5.3.0
lxml_html_clean==0.2.2
llama-index==0.12.28
Expand Down
139 changes: 139 additions & 0 deletions backend/tests/integration/tests/tools/test_image_generation_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import os

import pytest

from onyx.tools.tool_implementations.images.image_generation_tool import (
IMAGE_GENERATION_RESPONSE_ID,
)
from onyx.tools.tool_implementations.images.image_generation_tool import ImageFormat
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationResponse,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)


@pytest.fixture
def dalle3_tool() -> ImageGenerationTool:
    """Provide a DALL-E 3 tool configured for URL output.

    Skips the test when no OPENAI_API_KEY is present in the environment,
    since this fixture backs live-API integration tests.
    """
    key = os.environ.get("OPENAI_API_KEY")
    if not key:
        pytest.skip("OPENAI_API_KEY environment variable not set")

    return ImageGenerationTool(
        api_key=key,
        api_base=None,
        api_version=None,
        model="dall-e-3",
        num_imgs=1,
        output_format=ImageFormat.URL,
    )


@pytest.fixture
def gpt_image_tool() -> ImageGenerationTool:
    """Provide a gpt-image-1 tool configured for base64 output.

    gpt-image-1 only supports inline base64 results (not URLs), so the
    fixture pins output_format to BASE64. Skips when OPENAI_API_KEY is
    absent, since this backs live-API integration tests.
    """
    key = os.environ.get("OPENAI_API_KEY")
    if not key:
        pytest.skip("OPENAI_API_KEY environment variable not set")

    return ImageGenerationTool(
        api_key=key,
        api_base=None,
        api_version=None,
        model="gpt-image-1",
        output_format=ImageFormat.BASE64,
        num_imgs=1,
    )


def test_dalle3_generates_image_url_successfully(
    dalle3_tool: ImageGenerationTool,
) -> None:
    """DALL-E 3 with URL output yields one https URL and no inline data."""
    tool_responses = list(dalle3_tool.run(prompt="A simple red circle"))

    # Exactly one packet, tagged as an image-generation response.
    assert len(tool_responses) == 1
    packet = tool_responses[0]
    assert packet.id == IMAGE_GENERATION_RESPONSE_ID

    payload = packet.response
    assert isinstance(payload, list)
    assert len(payload) == 1

    image = payload[0]
    assert isinstance(image, ImageGenerationResponse)
    # A non-empty revised prompt always accompanies the image.
    assert image.revised_prompt is not None
    assert len(image.revised_prompt) > 0
    # URL-format responses carry a link and never inline base64 data.
    assert image.url is not None
    assert image.url.startswith("https://")
    assert image.image_data is None


def test_dalle3_with_base64_format() -> None:
    """DALL-E 3 can also return inline base64 image data instead of a URL."""
    key = os.getenv("OPENAI_API_KEY")
    if not key:
        pytest.skip("OPENAI_API_KEY environment variable not set")

    # Same model as the dalle3_tool fixture, but requesting BASE64 output.
    base64_tool = ImageGenerationTool(
        api_key=key,
        api_base=None,
        api_version=None,
        model="dall-e-3",
        output_format=ImageFormat.BASE64,
        num_imgs=1,
    )

    results = list(base64_tool.run(prompt="A simple blue square", shape="square"))

    assert len(results) == 1
    image = results[0].response[0]
    # BASE64 format: inline data present, no URL.
    assert image.url is None
    assert image.image_data is not None
    assert len(image.image_data) > 0


def test_gpt_image_1_generates_base64_successfully(
    gpt_image_tool: ImageGenerationTool,
) -> None:
    """gpt-image-1 returns inline base64 data (its only supported format)."""
    tool_responses = list(gpt_image_tool.run(prompt="A simple red circle"))

    # Exactly one packet, tagged as an image-generation response.
    assert len(tool_responses) == 1
    packet = tool_responses[0]
    assert packet.id == IMAGE_GENERATION_RESPONSE_ID

    payload = packet.response
    assert isinstance(payload, list)
    assert len(payload) == 1

    image = payload[0]
    assert isinstance(image, ImageGenerationResponse)
    # A non-empty revised prompt always accompanies the image.
    assert image.revised_prompt is not None
    assert len(image.revised_prompt) > 0
    # gpt-image-1 never returns URLs — only inline base64 data.
    assert image.url is None
    assert image.image_data is not None
    assert len(image.image_data) > 0


def test_gpt_image_1_with_url_format_fails() -> None:
    """Constructing a gpt-image-1 tool with URL output must raise immediately.

    The model/format validation happens in ImageGenerationTool.__init__
    before any API call is made, so no real OPENAI_API_KEY is required —
    the previous skip-without-key guard made this pure-validation test
    silently vanish from CI runs that lack credentials.
    """
    with pytest.raises(ValueError, match="gpt-image-1 does not support URL format"):
        ImageGenerationTool(
            api_key="fake-api-key",  # never used: __init__ rejects the config first
            api_base=None,
            api_version=None,
            model="gpt-image-1",
            output_format=ImageFormat.URL,
            num_imgs=1,
        )
Loading