diff --git a/backend/onyx/agents/agent_search/dr/sub_agents/image_generation/dr_image_generation_2_act.py b/backend/onyx/agents/agent_search/dr/sub_agents/image_generation/dr_image_generation_2_act.py index c5f352acf91..3970f87699a 100644 --- a/backend/onyx/agents/agent_search/dr/sub_agents/image_generation/dr_image_generation_2_act.py +++ b/backend/onyx/agents/agent_search/dr/sub_agents/image_generation/dr_image_generation_2_act.py @@ -1,3 +1,4 @@ +import json from datetime import datetime from typing import cast @@ -28,11 +29,76 @@ from onyx.tools.tool_implementations.images.image_generation_tool import ( ImageGenerationTool, ) +from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape from onyx.utils.logger import setup_logger logger = setup_logger() +def _extract_tool_args( + raw_query: str | dict[str, object], metadata: dict[str, str] | None +) -> tuple[str, ImageShape | None]: + prompt = str(raw_query).strip() + shape: ImageShape | None = None + parsed: dict[str, object] | None = None + + if isinstance(raw_query, dict): + parsed = raw_query + elif isinstance(raw_query, str): + try: + parsed_candidate = json.loads(raw_query) + except json.JSONDecodeError: + parsed_candidate = None + if isinstance(parsed_candidate, dict): + parsed = parsed_candidate + + if isinstance(parsed, dict): + prompt_value = parsed.get("prompt") + if isinstance(prompt_value, str) and prompt_value.strip(): + prompt = prompt_value.strip() + + shape_value = parsed.get("shape") + if isinstance(shape_value, str): + try: + shape = ImageShape(shape_value.lower()) + except ValueError: + logger.debug("Unsupported image shape requested: %s", shape_value) + + if not shape and metadata: + metadata_shape = metadata.get("shape") if metadata else None + if isinstance(metadata_shape, str): + try: + shape = ImageShape(metadata_shape.lower()) + except ValueError: + logger.debug("Unsupported image shape in metadata: %s", metadata_shape) + + if not prompt: + prompt = str(raw_query) + + return prompt, shape + + +def _expected_dimensions( + shape: ImageShape | None, model: str +) -> tuple[int | None, int | None]: + if shape is None: + return None, None + + if shape == ImageShape.LANDSCAPE: + size = "1536x1024" if model == "gpt-image-1" else "1792x1024" + elif shape == ImageShape.PORTRAIT: + size = "1024x1536" if model == "gpt-image-1" else "1024x1792" + else: + size = "1024x1024" + + try: + width_str, height_str = size.split("x") + return int(width_str), int(height_str) + except ValueError: + logger.debug("Unable to parse expected size '%s'", size) + return None, None + + def image_generation( state: BranchInput, config: RunnableConfig, @@ -69,12 +135,30 @@ def image_generation( # Generate images using the image generation tool image_generation_responses: list[ImageGenerationResponse] = [] - for tool_response in image_tool.run(prompt=branch_query): + image_prompt, image_shape_enum = _extract_tool_args( + branch_query, image_tool_info.metadata + ) + + expected_width, expected_height = _expected_dimensions( + image_shape_enum, image_tool.model + ) + + shape_for_text = image_shape_enum.value if image_shape_enum else None + + run_kwargs: dict[str, str] = {"prompt": image_prompt} + if image_shape_enum: + run_kwargs["shape"] = image_shape_enum.value + + for tool_response in image_tool.run(**run_kwargs): if tool_response.id == IMAGE_GENERATION_HEARTBEAT_ID: # Stream heartbeat to frontend write_custom_event( state.current_step_nr, - ImageGenerationToolHeartbeat(), + ImageGenerationToolHeartbeat( + shape=image_shape_enum.value if image_shape_enum else None, + width=expected_width, + height=expected_height, + ), writer, ) elif tool_response.id == IMAGE_GENERATION_RESPONSE_ID: @@ -90,32 +174,58 @@ def image_generation( ], ) - final_generated_images = [ - GeneratedImage( - file_id=file_id, - url=build_frontend_file_url(file_id), - revised_prompt=img.revised_prompt, + final_generated_images: list[GeneratedImage] = [] + for file_id, img in zip(file_ids, image_generation_responses): + response_shape = img.shape + if isinstance(response_shape, ImageShape): + response_shape_value = response_shape.value + else: + response_shape_value = response_shape + + if not response_shape_value and image_shape_enum: + response_shape_value = image_shape_enum.value + + width = img.width or expected_width + height = img.height or expected_height + + final_generated_images.append( + GeneratedImage( + file_id=file_id, + url=build_frontend_file_url(file_id), + revised_prompt=img.revised_prompt, + width=width, + height=height, + shape=response_shape_value, + ) ) - for file_id, img in zip(file_ids, image_generation_responses) - ] logger.debug( f"Image generation complete for {iteration_nr}.{parallelization_nr} at {datetime.now()}" ) + if not shape_for_text and final_generated_images: + shape_for_text = final_generated_images[0].shape + # Create answer string describing the generated images if final_generated_images: image_descriptions = [] for i, img in enumerate(final_generated_images, 1): image_descriptions.append(f"Image {i}: {img.revised_prompt}") + request_details = image_prompt + if shape_for_text: + request_details = f"{image_prompt} (shape: {shape_for_text})" + answer_string = ( - f"Generated {len(final_generated_images)} image(s) based on the request: {branch_query}\n\n" + f"Generated {len(final_generated_images)} image(s) based on the request: {request_details}\n\n" + "\n".join(image_descriptions) ) - reasoning = f"Used image generation tool to create {len(final_generated_images)} image(s) based on the user's request." + reasoning = ( + f"Used image generation tool to create {len(final_generated_images)} image(s)" + f" based on the user's request{' with shape ' + shape_for_text if shape_for_text else ''}." + ) else: - answer_string = f"Failed to generate images for request: {branch_query}" + answer_string = f"Failed to generate images for request: {image_prompt}" reasoning = "Image generation tool did not return any results." return BranchUpdate( @@ -125,7 +235,7 @@ def image_generation( tool_id=image_tool_info.tool_id, iteration_nr=iteration_nr, parallelization_nr=parallelization_nr, - question=branch_query, + question=image_prompt, answer=answer_string, claims=[], cited_documents={}, diff --git a/backend/onyx/agents/agent_search/dr/sub_agents/image_generation/models.py b/backend/onyx/agents/agent_search/dr/sub_agents/image_generation/models.py index ed854c93416..0ac62039e53 100644 --- a/backend/onyx/agents/agent_search/dr/sub_agents/image_generation/models.py +++ b/backend/onyx/agents/agent_search/dr/sub_agents/image_generation/models.py @@ -5,6 +5,9 @@ class GeneratedImage(BaseModel): file_id: str url: str revised_prompt: str + width: int | None = None + height: int | None = None + shape: str | None = None # Needed for PydanticType diff --git a/backend/onyx/prompts/dr_prompts.py b/backend/onyx/prompts/dr_prompts.py index 11df26a2d65..48fa3ad49b2 100644 --- a/backend/onyx/prompts/dr_prompts.py +++ b/backend/onyx/prompts/dr_prompts.py @@ -171,6 +171,14 @@ """, DRPath.CLOSER.value: f"""if the tool is {CLOSER}, the list of questions should simply be \ ['Answer the original question with the information you have.']. +""", + DRPath.IMAGE_GENERATION.value: """ +if the tool is Image Generation, respond with a list that contains exactly one JSON object +string describing the tool call. The JSON must include a "prompt" field with the text to +render. When the user specifies or implies an orientation, also include a "shape" field whose +value is one of "square", "landscape", or "portrait" (use "landscape" for wide/horizontal +requests and "portrait" for tall/vertical ones). Example: {"prompt": "Create a poster of a +coral reef", "shape": "landscape"}. Do not surround the JSON with backticks or narration. """, } diff --git a/backend/onyx/server/query_and_chat/streaming_models.py b/backend/onyx/server/query_and_chat/streaming_models.py index 1998e4c3489..84d0f8eed50 100644 --- a/backend/onyx/server/query_and_chat/streaming_models.py +++ b/backend/onyx/server/query_and_chat/streaming_models.py @@ -72,6 +72,9 @@ class ImageGenerationToolDelta(BaseObj): class ImageGenerationToolHeartbeat(BaseObj): type: Literal["image_generation_tool_heartbeat"] = "image_generation_tool_heartbeat" + shape: str | None = None + width: int | None = None + height: int | None = None class CustomToolStart(BaseObj): diff --git a/backend/onyx/tools/tool_implementations/images/image_generation_tool.py b/backend/onyx/tools/tool_implementations/images/image_generation_tool.py index 08adc3ff5bd..5dc6c66bcf1 100644 --- a/backend/onyx/tools/tool_implementations/images/image_generation_tool.py +++ b/backend/onyx/tools/tool_implementations/images/image_generation_tool.py @@ -76,10 +76,13 @@ class ImageFormat(str, Enum): _DEFAULT_OUTPUT_FORMAT = ImageFormat(IMAGE_GENERATION_OUTPUT_FORMAT) -class ImageGenerationResponse(BaseModel): - revised_prompt: str - url: str | None - image_data: str | None +def _parse_size(size: str) -> tuple[int | None, int | None]: + try: + width_str, height_str = size.split("x") + return int(width_str), int(height_str) + except (ValueError, AttributeError): + logger.debug("Unable to parse image size '%s'", size) + return None, None class ImageShape(str, Enum): @@ -88,6 +91,15 @@ class ImageShape(str, Enum): LANDSCAPE = "landscape" +class ImageGenerationResponse(BaseModel): + revised_prompt: str + url: str | None + image_data: str | None + width: int | None = None + height: int | None = None + shape: ImageShape | None = None + + # override_kwargs is not supported for image generation tools class ImageGenerationTool(Tool[None]): _NAME = "run_image_generation" @@ -257,6 +269,8 @@ def _generate_image( size = "1024x1792" else: size = "1024x1024" + + width, height = _parse_size(size) logger.debug( f"Generating image with model: {self.model}, size: {size}, format: {format}" ) @@ -293,6 +307,9 @@ def _generate_image( revised_prompt=revised_prompt, url=url, image_data=image_data, + width=width, + height=height, + shape=shape, ) except requests.RequestException as e: diff --git a/web/src/app/chat/components/files/images/InMessageImage.tsx b/web/src/app/chat/components/files/images/InMessageImage.tsx index e8de1fb40f8..337640e5874 100644 --- a/web/src/app/chat/components/files/images/InMessageImage.tsx +++ b/web/src/app/chat/components/files/images/InMessageImage.tsx @@ -1,11 +1,21 @@ -import { useState } from "react"; +import { useMemo, useState } from "react"; import { FiDownload } from "react-icons/fi"; import { FullImageModal } from "./FullImageModal"; import { buildImgUrl } from "./utils"; -export function InMessageImage({ fileId }: { fileId: string }) { +type InMessageImageProps = { + fileId: string; + width?: number | null; + height?: number | null; +}; + +export function InMessageImage({ fileId, width, height }: InMessageImageProps) { const [fullImageShowing, setFullImageShowing] = useState(false); const [imageLoaded, setImageLoaded] = useState(false); + const [naturalDimensions, setNaturalDimensions] = useState<{ + width: number; + height: number; + } | null>(null); const handleDownload = async (e: React.MouseEvent) => { e.stopPropagation(); // Prevent opening the full image modal @@ -26,6 +36,20 @@ export function InMessageImage({ fileId }: { fileId: string }) { } }; + const resolvedDimensions = useMemo(() => { + if (width && height) { + return { width, height }; + } + if (naturalDimensions) { + return naturalDimensions; + } + return null; + }, [width, height, naturalDimensions]); + + const aspectRatio = resolvedDimensions + ? `${resolvedDimensions.width} / ${resolvedDimensions.height}` + : "1 / 1"; + return ( <> setFullImageShowing(open)} /> -
+
{!imageLoaded && (
)} @@ -43,7 +67,18 @@ export function InMessageImage({ fileId }: { fileId: string }) { width={1200} height={1200} alt="Chat Message Image" - onLoad={() => setImageLoaded(true)} + onLoad={(event) => { + setImageLoaded(true); + if (!width || !height) { + const { naturalWidth, naturalHeight } = event.currentTarget; + if (naturalWidth && naturalHeight) { + setNaturalDimensions({ + width: naturalWidth, + height: naturalHeight, + }); + } + } + }} className={` object-contain object-left @@ -51,8 +86,6 @@ export function InMessageImage({ fileId }: { fileId: string }) { rounded-lg w-full h-full - max-w-96 - max-h-96 transition-opacity duration-300 cursor-pointer diff --git a/web/src/app/chat/components/tools/GeneratingImageDisplay.tsx b/web/src/app/chat/components/tools/GeneratingImageDisplay.tsx index 8457551481e..759d834f761 100644 --- a/web/src/app/chat/components/tools/GeneratingImageDisplay.tsx +++ b/web/src/app/chat/components/tools/GeneratingImageDisplay.tsx @@ -1,6 +1,16 @@ import React, { useState, useEffect, useRef } from "react"; -export default function GeneratingImageDisplay({ isCompleted = false }) { +type GeneratingImageDisplayProps = { + isCompleted?: boolean; + width?: number | null; + height?: number | null; +}; + +export default function GeneratingImageDisplay({ + isCompleted = false, + width, + height, +}: GeneratingImageDisplayProps) { const [progress, setProgress] = useState(0); const progressRef = useRef(0); const animationRef = useRef(); @@ -55,8 +65,16 @@ export default function GeneratingImageDisplay({ isCompleted = false }) { } }, [isCompleted]); + const aspectWidth = width && width > 0 ? width : null; + const aspectHeight = height && height > 0 ? height : null; + const aspectRatio = + aspectWidth && aspectHeight ? `${aspectWidth} / ${aspectHeight}` : "1 / 1"; + return ( -
+
packet.obj.type === PacketType.IMAGE_GENERATION_TOOL_DELTA ) .map((packet) => packet.obj as ImageGenerationToolDelta); + const imageHeartbeats = packets + .filter( + (packet) => packet.obj.type === PacketType.IMAGE_GENERATION_TOOL_HEARTBEAT + ) + .map((packet) => packet.obj as ImageGenerationToolHeartbeat); const imageEnd = packets.find( (packet) => packet.obj.type === PacketType.SECTION_END )?.obj as SectionEnd | null; @@ -29,6 +35,14 @@ function constructCurrentImageState(packets: ImageGenerationToolPacket[]) { const images = imageDeltas.flatMap((delta) => delta?.images || []); const isGenerating = imageStart && !imageEnd; const isComplete = imageStart && imageEnd; + const latestHeartbeat = + imageHeartbeats.length > 0 + ? imageHeartbeats[imageHeartbeats.length - 1] + : null; + const firstImage = images[0]; + const width = firstImage?.width ?? latestHeartbeat?.width ?? null; + const height = firstImage?.height ?? latestHeartbeat?.height ?? null; + const shape = firstImage?.shape ?? latestHeartbeat?.shape ?? null; return { prompt, @@ -36,6 +50,9 @@ function constructCurrentImageState(packets: ImageGenerationToolPacket[]) { isGenerating, isComplete, error: false, // For now, we don't have error state in the packets + width, + height, + shape, }; } @@ -43,7 +60,7 @@ export const ImageToolRenderer: MessageRenderer< ImageGenerationToolPacket, {} > = ({ packets, onComplete, renderType, children }) => { - const { prompt, images, isGenerating, isComplete, error } = + const { prompt, images, isGenerating, isComplete, error, width, height } = constructCurrentImageState(packets); useEffect(() => { @@ -73,7 +90,11 @@ export const ImageToolRenderer: MessageRenderer< content: (
- +
), @@ -96,7 +117,13 @@ export const ImageToolRenderer: MessageRenderer< key={image.file_id || index} className="transition-all group" > - {image.file_id && } + {image.file_id && ( + + )}
))}
diff --git a/web/src/app/chat/services/streamingModels.ts b/web/src/app/chat/services/streamingModels.ts index 10738872bfb..e2e940f21dd 100644 --- a/web/src/app/chat/services/streamingModels.ts +++ b/web/src/app/chat/services/streamingModels.ts @@ -18,6 +18,7 @@ export enum PacketType { SEARCH_TOOL_DELTA = "internal_search_tool_delta", IMAGE_GENERATION_TOOL_START = "image_generation_tool_start", IMAGE_GENERATION_TOOL_DELTA = "image_generation_tool_delta", + IMAGE_GENERATION_TOOL_HEARTBEAT = "image_generation_tool_heartbeat", // Custom tool packets CUSTOM_TOOL_START = "custom_tool_start", @@ -76,6 +77,9 @@ interface GeneratedImage { file_id: string; url: string; revised_prompt: string; + width?: number | null; + height?: number | null; + shape?: string | null; } export interface ImageGenerationToolStart extends BaseObj { @@ -87,6 +91,13 @@ export interface ImageGenerationToolDelta extends BaseObj { images: GeneratedImage[]; } +export interface ImageGenerationToolHeartbeat extends BaseObj { + type: "image_generation_tool_heartbeat"; + shape?: string | null; + width?: number | null; + height?: number | null; +} + // Custom Tool Packets export interface CustomToolStart extends BaseObj { type: "custom_tool_start"; @@ -137,6 +148,7 @@ export type SearchToolObj = SearchToolStart | SearchToolDelta | SectionEnd; export type ImageGenerationToolObj = | ImageGenerationToolStart | ImageGenerationToolDelta + | ImageGenerationToolHeartbeat | SectionEnd; export type CustomToolObj = CustomToolStart | CustomToolDelta | SectionEnd; export type NewToolObj = SearchToolObj | ImageGenerationToolObj | CustomToolObj;