backend/onyx/agents/agent_search/dr/sub_agents/image_generation/dr_image_generation_2_act.py
@@ -1,3 +1,4 @@
import json
from datetime import datetime
from typing import cast

@@ -28,11 +29,76 @@
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape
from onyx.utils.logger import setup_logger

logger = setup_logger()


def _extract_tool_args(
raw_query: str | dict[str, object], metadata: dict[str, str] | None
) -> tuple[str, ImageShape | None]:
prompt = str(raw_query).strip()
shape: ImageShape | None = None
parsed: dict[str, object] | None = None

if isinstance(raw_query, dict):
parsed = raw_query
elif isinstance(raw_query, str):
try:
parsed_candidate = json.loads(raw_query)
except json.JSONDecodeError:
parsed_candidate = None
if isinstance(parsed_candidate, dict):
parsed = parsed_candidate

if isinstance(parsed, dict):
prompt_value = parsed.get("prompt")
if isinstance(prompt_value, str) and prompt_value.strip():
prompt = prompt_value.strip()

shape_value = parsed.get("shape")
if isinstance(shape_value, str):
try:
shape = ImageShape(shape_value.lower())
except ValueError:
logger.debug("Unsupported image shape requested: %s", shape_value)

if not shape and metadata:
metadata_shape = metadata.get("shape") if metadata else None
if isinstance(metadata_shape, str):
try:
shape = ImageShape(metadata_shape.lower())
except ValueError:
logger.debug("Unsupported image shape in metadata: %s", metadata_shape)

if not prompt:
prompt = str(raw_query)

return prompt, shape


def _expected_dimensions(
shape: ImageShape | None, model: str
) -> tuple[int | None, int | None]:
if shape is None:
return None, None

if shape == ImageShape.LANDSCAPE:
size = "1536x1024" if model == "gpt-image-1" else "1792x1024"
elif shape == ImageShape.PORTRAIT:
size = "1024x1536" if model == "gpt-image-1" else "1024x1792"
else:
size = "1024x1024"

try:
width_str, height_str = size.split("x")
return int(width_str), int(height_str)
except ValueError:
logger.debug("Unable to parse expected size '%s'", size)
return None, None


def image_generation(
state: BranchInput,
config: RunnableConfig,
@@ -69,12 +135,30 @@ def image_generation(
# Generate images using the image generation tool
image_generation_responses: list[ImageGenerationResponse] = []

for tool_response in image_tool.run(prompt=branch_query):
image_prompt, image_shape_enum = _extract_tool_args(
branch_query, image_tool_info.metadata
)

expected_width, expected_height = _expected_dimensions(
image_shape_enum, image_tool.model
)

shape_for_text = image_shape_enum.value if image_shape_enum else None

run_kwargs: dict[str, str] = {"prompt": image_prompt}
if image_shape_enum:
run_kwargs["shape"] = image_shape_enum.value
Comment on lines +148 to +150 (Contributor):
**logic:** Type annotation inconsistency - `run_kwargs` is declared as `dict[str, str]`, but `image_shape_enum.value` could be None, causing a type mismatch if shape is added.
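A minimal sketch of one way to keep the annotation narrow, assuming the guard shown above is the only place a shape is added; the helper name `_build_run_kwargs` is illustrative and not part of this diff:

```python
from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape


def _build_run_kwargs(prompt: str, shape: ImageShape | None) -> dict[str, str]:
    # ImageShape is a str-valued enum, so .value is always a plain string here,
    # and the None case never reaches the dict.
    kwargs: dict[str, str] = {"prompt": prompt}
    if shape is not None:
        kwargs["shape"] = shape.value
    return kwargs
```

Since the existing `if image_shape_enum:` guard already ensures only strings reach the dict, another option is simply to leave the annotation as written.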


for tool_response in image_tool.run(**run_kwargs):
if tool_response.id == IMAGE_GENERATION_HEARTBEAT_ID:
# Stream heartbeat to frontend
write_custom_event(
state.current_step_nr,
ImageGenerationToolHeartbeat(),
ImageGenerationToolHeartbeat(
shape=image_shape_enum.value if image_shape_enum else None,
width=expected_width,
height=expected_height,
),
writer,
)
elif tool_response.id == IMAGE_GENERATION_RESPONSE_ID:
@@ -90,32 +174,58 @@ def image_generation(
],
)

final_generated_images = [
GeneratedImage(
file_id=file_id,
url=build_frontend_file_url(file_id),
revised_prompt=img.revised_prompt,
final_generated_images: list[GeneratedImage] = []
for file_id, img in zip(file_ids, image_generation_responses):
response_shape = img.shape
if isinstance(response_shape, ImageShape):
response_shape_value = response_shape.value
else:
response_shape_value = response_shape

if not response_shape_value and image_shape_enum:
response_shape_value = image_shape_enum.value

width = img.width or expected_width
height = img.height or expected_height

final_generated_images.append(
GeneratedImage(
file_id=file_id,
url=build_frontend_file_url(file_id),
revised_prompt=img.revised_prompt,
width=width,
height=height,
shape=response_shape_value,
)
)
for file_id, img in zip(file_ids, image_generation_responses)
]

logger.debug(
f"Image generation complete for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)

if not shape_for_text and final_generated_images:
shape_for_text = final_generated_images[0].shape

# Create answer string describing the generated images
if final_generated_images:
image_descriptions = []
for i, img in enumerate(final_generated_images, 1):
image_descriptions.append(f"Image {i}: {img.revised_prompt}")

request_details = image_prompt
if shape_for_text:
request_details = f"{image_prompt} (shape: {shape_for_text})"

answer_string = (
f"Generated {len(final_generated_images)} image(s) based on the request: {branch_query}\n\n"
f"Generated {len(final_generated_images)} image(s) based on the request: {request_details}\n\n"
+ "\n".join(image_descriptions)
)
reasoning = f"Used image generation tool to create {len(final_generated_images)} image(s) based on the user's request."
reasoning = (
f"Used image generation tool to create {len(final_generated_images)} image(s)"
f" based on the user's request{' with shape ' + shape_for_text if shape_for_text else ''}."
)
else:
answer_string = f"Failed to generate images for request: {branch_query}"
answer_string = f"Failed to generate images for request: {image_prompt}"
reasoning = "Image generation tool did not return any results."

return BranchUpdate(
@@ -125,7 +235,7 @@ def image_generation(
tool_id=image_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=branch_query,
question=image_prompt,
answer=answer_string,
claims=[],
cited_documents={},
backend/onyx/agents/agent_search/dr/sub_agents/image_generation/models.py
@@ -5,6 +5,9 @@ class GeneratedImage(BaseModel):
file_id: str
url: str
revised_prompt: str
width: int | None = None
height: int | None = None
shape: str | None = None
Comment on line 10 (Contributor):
**style:** Consider using an enum for shape values instead of string to maintain consistency with the ImageShape enum used elsewhere in the codebase.
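A minimal sketch of the suggested change, assuming ImageShape can be imported into this models module without creating a circular dependency (not verified in this diff):

```python
from pydantic import BaseModel

from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape


class GeneratedImage(BaseModel):
    file_id: str
    url: str
    revised_prompt: str
    width: int | None = None
    height: int | None = None
    # Enum-typed instead of str | None; ImageShape subclasses str, so the field
    # still serializes to plain strings such as "landscape" or "portrait".
    shape: ImageShape | None = None
```

Call sites in dr_image_generation_2_act.py that currently assign plain strings would then construct `ImageShape(value)` instead.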



# Needed for PydanticType
8 changes: 8 additions & 0 deletions backend/onyx/prompts/dr_prompts.py
@@ -171,6 +171,14 @@
""",
DRPath.CLOSER.value: f"""if the tool is {CLOSER}, the list of questions should simply be \
['Answer the original question with the information you have.'].
""",
DRPath.IMAGE_GENERATION.value: """
if the tool is Image Generation, respond with a list that contains exactly one JSON object
string describing the tool call. The JSON must include a "prompt" field with the text to
render. When the user specifies or implies an orientation, also include a "shape" field whose
value is one of "square", "landscape", or "portrait" (use "landscape" for wide/horizontal
requests and "portrait" for tall/vertical ones). Example: {"prompt": "Create a poster of a
coral reef", "shape": "landscape"}. Do not surround the JSON with backticks or narration.
""",
}
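For illustration only, a small worked example of how a tool call in this format flows through the parsing added in dr_image_generation_2_act.py above, assuming the model emits the JSON exactly as instructed:

```python
import json

from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape

raw_query = '{"prompt": "Create a poster of a coral reef", "shape": "landscape"}'

# Mirrors what _extract_tool_args does with this string:
parsed = json.loads(raw_query)
prompt = parsed["prompt"].strip()            # "Create a poster of a coral reef"
shape = ImageShape(parsed["shape"].lower())  # ImageShape.LANDSCAPE

# _expected_dimensions(shape, "gpt-image-1") then yields (1536, 1024),
# which is what the heartbeat events stream as width/height.
```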

3 changes: 3 additions & 0 deletions backend/onyx/server/query_and_chat/streaming_models.py
@@ -72,6 +72,9 @@ class ImageGenerationToolDelta(BaseObj):

class ImageGenerationToolHeartbeat(BaseObj):
type: Literal["image_generation_tool_heartbeat"] = "image_generation_tool_heartbeat"
shape: str | None = None
width: int | None = None
height: int | None = None
Comment on lines +75 to +77 (Contributor):
**style:** Consider using the ImageShape enum from the tool implementation instead of `str | None` for the shape field to ensure type safety and consistency across the codebase.



class CustomToolStart(BaseObj):
backend/onyx/tools/tool_implementations/images/image_generation_tool.py
@@ -76,10 +76,13 @@ class ImageFormat(str, Enum):
_DEFAULT_OUTPUT_FORMAT = ImageFormat(IMAGE_GENERATION_OUTPUT_FORMAT)


class ImageGenerationResponse(BaseModel):
revised_prompt: str
url: str | None
image_data: str | None
def _parse_size(size: str) -> tuple[int | None, int | None]:
try:
width_str, height_str = size.split("x")
return int(width_str), int(height_str)
except (ValueError, AttributeError):
logger.debug("Unable to parse image size '%s'", size)
return None, None


class ImageShape(str, Enum):
@@ -88,6 +91,15 @@ class ImageShape(str, Enum):
LANDSCAPE = "landscape"


class ImageGenerationResponse(BaseModel):
revised_prompt: str
url: str | None
image_data: str | None
width: int | None = None
height: int | None = None
shape: ImageShape | None = None


# override_kwargs is not supported for image generation tools
class ImageGenerationTool(Tool[None]):
_NAME = "run_image_generation"
@@ -257,6 +269,8 @@ def _generate_image(
size = "1024x1792"
else:
size = "1024x1024"

width, height = _parse_size(size)
logger.debug(
f"Generating image with model: {self.model}, size: {size}, format: {format}"
)
@@ -293,6 +307,9 @@ def _generate_image(
revised_prompt=revised_prompt,
url=url,
image_data=image_data,
width=width,
height=height,
shape=shape,
)

except requests.RequestException as e:
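As a reference point, the size strings chosen per shape restate the mapping in `_expected_dimensions` earlier in this PR; the dictionary names below are illustrative, not code from the PR, and the non-gpt-image-1 branch is left unnamed because the diff does not name those models:

```python
from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape

# gpt-image-1 branch
GPT_IMAGE_1_SIZES = {
    ImageShape.LANDSCAPE: "1536x1024",
    ImageShape.PORTRAIT: "1024x1536",
}
# any other model
OTHER_MODEL_SIZES = {
    ImageShape.LANDSCAPE: "1792x1024",
    ImageShape.PORTRAIT: "1024x1792",
}
# the remaining (square) case is "1024x1024" for every model

# _parse_size splits on "x" and falls back to (None, None) for malformed input
width, height = (int(part) for part in GPT_IMAGE_1_SIZES[ImageShape.LANDSCAPE].split("x"))
assert (width, height) == (1536, 1024)
```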
45 changes: 39 additions & 6 deletions web/src/app/chat/components/files/images/InMessageImage.tsx
@@ -1,11 +1,21 @@
import { useState } from "react";
import { useMemo, useState } from "react";
import { FiDownload } from "react-icons/fi";
import { FullImageModal } from "./FullImageModal";
import { buildImgUrl } from "./utils";

export function InMessageImage({ fileId }: { fileId: string }) {
type InMessageImageProps = {
fileId: string;
width?: number | null;
height?: number | null;
};

export function InMessageImage({ fileId, width, height }: InMessageImageProps) {
const [fullImageShowing, setFullImageShowing] = useState(false);
const [imageLoaded, setImageLoaded] = useState(false);
const [naturalDimensions, setNaturalDimensions] = useState<{
width: number;
height: number;
} | null>(null);

const handleDownload = async (e: React.MouseEvent) => {
e.stopPropagation(); // Prevent opening the full image modal
@@ -26,6 +36,20 @@ export function InMessageImage({ fileId }: { fileId: string }) {
}
};

const resolvedDimensions = useMemo(() => {
if (width && height) {
return { width, height };
}
if (naturalDimensions) {
return naturalDimensions;
}
return null;
}, [width, height, naturalDimensions]);

const aspectRatio = resolvedDimensions
? `${resolvedDimensions.width} / ${resolvedDimensions.height}`
: "1 / 1";

return (
<>
<FullImageModal
onOpenChange={(open) => setFullImageShowing(open)}
/>

<div className="relative w-full h-full max-w-96 max-h-96 group">
<div className="relative w-full max-w-96 group" style={{ aspectRatio }}>
{!imageLoaded && (
<div className="absolute inset-0 bg-background-200 animate-pulse rounded-lg" />
)}
width={1200}
height={1200}
alt="Chat Message Image"
onLoad={() => setImageLoaded(true)}
onLoad={(event) => {
setImageLoaded(true);
if (!width || !height) {
Comment on line 72 (Contributor):
**logic:** The condition `!width || !height` could be problematic if one dimension is 0 but the other is valid. Consider using `!width && !height` or checking for falsy values more explicitly.

const { naturalWidth, naturalHeight } = event.currentTarget;
if (naturalWidth && naturalHeight) {
setNaturalDimensions({
width: naturalWidth,
height: naturalHeight,
});
}
}
}}
className={`
object-contain
object-left
overflow-hidden
rounded-lg
w-full
h-full
max-w-96
max-h-96
transition-opacity
duration-300
cursor-pointer
22 changes: 20 additions & 2 deletions web/src/app/chat/components/tools/GeneratingImageDisplay.tsx
@@ -1,6 +1,16 @@
import React, { useState, useEffect, useRef } from "react";

export default function GeneratingImageDisplay({ isCompleted = false }) {
type GeneratingImageDisplayProps = {
isCompleted?: boolean;
width?: number | null;
height?: number | null;
};

export default function GeneratingImageDisplay({
isCompleted = false,
width,
height,
}: GeneratingImageDisplayProps) {
const [progress, setProgress] = useState(0);
const progressRef = useRef(0);
const animationRef = useRef<number>();
@@ -55,8 +65,16 @@ export default function GeneratingImageDisplay({ isCompleted = false }) {
}
}, [isCompleted]);

const aspectWidth = width && width > 0 ? width : null;
const aspectHeight = height && height > 0 ? height : null;
const aspectRatio =
aspectWidth && aspectHeight ? `${aspectWidth} / ${aspectHeight}` : "1 / 1";

return (
<div className="object-cover object-center border border-background-200 bg-background-100 items-center justify-center overflow-hidden flex rounded-lg w-96 h-96 transition-opacity duration-300 opacity-100">
<div
className="border border-background-200 bg-background-100 items-center justify-center overflow-hidden flex rounded-lg w-full max-w-96 transition-opacity duration-300 opacity-100"
style={{ aspectRatio }}
>
<div className="m-auto relative flex">
<svg className="w-16 h-16 transform -rotate-90" viewBox="0 0 100 100">
<circle