From f1a5f9aa806e6a04096028c7696a7cba2cf616aa Mon Sep 17 00:00:00 2001 From: yubingjiaocn Date: Wed, 2 Apr 2025 09:02:17 +0000 Subject: [PATCH 1/5] Add retry for ComfyUI websocket connection --- src/backend/queue_agent/src/main.py | 16 +- .../queue_agent/src/runtimes/comfyui.py | 350 +++++++++++++----- 2 files changed, 269 insertions(+), 97 deletions(-) diff --git a/src/backend/queue_agent/src/main.py b/src/backend/queue_agent/src/main.py index d40b497..b81a297 100644 --- a/src/backend/queue_agent/src/main.py +++ b/src/backend/queue_agent/src/main.py @@ -157,11 +157,19 @@ def main(): # Start handling message response = {} - if runtime_type == "sdwebui": - response = sdwebui.handler(api_base_url, tasktype, task_id, body, dynamic_sd_model) + try: + if runtime_type == "sdwebui": + response = sdwebui.handler(api_base_url, tasktype, task_id, body, dynamic_sd_model) - if runtime_type == "comfyui": - response = comfyui.handler(api_base_url, task_id, body) + if runtime_type == "comfyui": + response = comfyui.handler(api_base_url, task_id, body) + except Exception as e: + logger.error(f"Error calling handler for task {task_id}: {str(e)}") + response = { + "success": False, + "image": [], + "content": '{"code": 500, "error": "Runtime handler failed"}' + } result = [] rand = str(uuid.uuid4())[0:4] diff --git a/src/backend/queue_agent/src/runtimes/comfyui.py b/src/backend/queue_agent/src/runtimes/comfyui.py index 3540159..6906595 100644 --- a/src/backend/queue_agent/src/runtimes/comfyui.py +++ b/src/backend/queue_agent/src/runtimes/comfyui.py @@ -8,12 +8,18 @@ import urllib.parse import urllib.request import uuid +from typing import Optional, Dict, List, Any, Union import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client) from aws_xray_sdk.core import xray_recorder +from modules import http_action logger = logging.getLogger("queue-agent") +# Constants for websocket reconnection +MAX_RECONNECT_ATTEMPTS = 5 +RECONNECT_DELAY = 2 # seconds + def singleton(cls): _instance = {} @@ -27,100 +33,241 @@ def inner(): class comfyuiCaller(object): def __init__(self): - # self.wss = wsClient() self.wss = websocket.WebSocket() self.client_id = str(uuid.uuid4()) - pass + self.api_base_url = None + self.connected = False def setUrl(self, api_base_url:str): self.api_base_url = api_base_url def wss_connect(self): - self.wss.connect("ws://{}/ws?clientId={}".format(self.api_base_url, self.client_id)) - # self.wss.Connect("ws://{}/ws?clientId={}".format(self.api_base_url, self.client_id)) + """Connect to websocket with reconnection logic""" + if self.connected: + return True + + attempts = 0 + while attempts < MAX_RECONNECT_ATTEMPTS: + try: + logger.info(f"Connecting to websocket (attempt {attempts+1}/{MAX_RECONNECT_ATTEMPTS})") + self.wss.connect("ws://{}/ws?clientId={}".format(self.api_base_url, self.client_id)) + self.connected = True + logger.info("Successfully connected to websocket") + return True + except Exception as e: + attempts += 1 + logger.warning(f"Failed to connect to websocket: {str(e)}") + if attempts < MAX_RECONNECT_ATTEMPTS: + logger.info(f"Retrying in {RECONNECT_DELAY} seconds...") + time.sleep(RECONNECT_DELAY) + else: + logger.error("Max reconnection attempts reached") + raise ConnectionError(f"Failed to connect to ComfyUI websocket after {MAX_RECONNECT_ATTEMPTS} attempts") from e + + return False + + def wss_recv(self) -> Optional[str]: + """Receive data from websocket with reconnection logic""" + attempts = 0 + while attempts < MAX_RECONNECT_ATTEMPTS: + try: + return self.wss.recv() + except websocket.WebSocketConnectionClosedException: + attempts += 1 + logger.warning(f"WebSocket connection closed, attempting to reconnect (attempt {attempts}/{MAX_RECONNECT_ATTEMPTS})...") + self.connected = False + + if attempts < MAX_RECONNECT_ATTEMPTS: + if self.wss_connect(): + logger.info("Reconnected successfully, retrying receive operation") + continue + else: + logger.warning(f"Failed to reconnect, waiting {RECONNECT_DELAY} seconds before retry...") + time.sleep(RECONNECT_DELAY) + else: + logger.error("Max reconnection attempts reached in wss_recv") + return None + except Exception as e: + attempts += 1 + logger.error(f"Error receiving data from websocket: {str(e)}") + self.connected = False + + if attempts < MAX_RECONNECT_ATTEMPTS: + logger.info(f"Waiting {RECONNECT_DELAY} seconds before retry...") + time.sleep(RECONNECT_DELAY) + if self.wss_connect(): + logger.info("Reconnected successfully, retrying receive operation") + continue + else: + logger.error("Max reconnection attempts reached in wss_recv") + return None + + return None def get_history(self, prompt_id): - with urllib.request.urlopen("http://{}/history/{}".format(self.api_base_url, prompt_id)) as response: - return json.loads(response.read()) + try: + url = f"http://{self.api_base_url}/history/{prompt_id}" + # Use the http_action module with built-in retry logic + return http_action.do_invocations(url) + except Exception as e: + logger.error(f"Error in get_history: {str(e)}") + return {} def queue_prompt(self, prompt): try: p = {"prompt": prompt, "client_id": self.client_id} - data = json.dumps(p).encode('utf-8') - logger.debug(data) - req = urllib.request.Request("http://{}/prompt".format(self.api_base_url), data=data) - output = urllib.request.urlopen(req) - return json.loads(output.read()) + url = f"http://{self.api_base_url}/prompt" + + # Use the http_action module with built-in retry logic + response = http_action.do_invocations(url, p) + return response except Exception as e: - logger.error(e) + logger.error(f"Error in queue_prompt: {str(e)}") return None - def get_image(self, filename, subfolder, folder_type): - data = {"filename": filename, "subfolder": subfolder, "type": folder_type} - url_values = urllib.parse.urlencode(data) - with urllib.request.urlopen("http://{}/view?{}".format(self.api_base_url, url_values)) as response: - return response.read() + def get_image(self, filename, subfolder, folder_type): + try: + data = {"filename": filename, "subfolder": subfolder, "type": folder_type} + url_values = urllib.parse.urlencode(data) + url = f"http://{self.api_base_url}/view?{url_values}" + + # Use http_action.get which returns bytes directly + return http_action.get(url) + except Exception as e: + logger.error(f"Error getting image {filename}: {str(e)}") + return b'' # Return empty bytes on error def track_progress(self, prompt, prompt_id): logger.info("Task received, prompt ID:" + prompt_id) node_ids = list(prompt.keys()) finished_nodes = [] + max_errors = 5 + error_count = 0 while True: - out = self.wss.recv() - if isinstance(out, str): - message = json.loads(out) - logger.debug(out) - if message['type'] == 'progress': - data = message['data'] - current_step = data['value'] - logger.info(f"In K-Sampler -> Step: {current_step} of: {data['max']}") - if message['type'] == 'execution_cached': - data = message['data'] - for itm in data['nodes']: - if itm not in finished_nodes: - finished_nodes.append(itm) - logger.info(f"Progess: {len(finished_nodes)} / {len(node_ids)} tasks done") - if message['type'] == 'executing': - data = message['data'] - if data['node'] not in finished_nodes: - finished_nodes.append(data['node']) - logger.info(f"Progess: {len(finished_nodes)} / {len(node_ids)} tasks done") - - if data['node'] is None and data['prompt_id'] == prompt_id: - break #Execution is done - else: - continue - return + try: + out = self.wss_recv() # Using our new method with reconnection logic + if out is None: + error_count += 1 + logger.warning(f"Failed to receive data from websocket (error {error_count}/{max_errors})") + if error_count >= max_errors: + logger.error("Too many errors receiving websocket data, aborting track_progress") + return False + time.sleep(1) + continue + + error_count = 0 # Reset error count on successful receive + + if isinstance(out, str): + try: + message = json.loads(out) + logger.debug(out) + if message['type'] == 'progress': + data = message['data'] + current_step = data['value'] + logger.info(f"In K-Sampler -> Step: {current_step} of: {data['max']}") + if message['type'] == 'execution_cached': + data = message['data'] + for itm in data['nodes']: + if itm not in finished_nodes: + finished_nodes.append(itm) + logger.info(f"Progress: {len(finished_nodes)} / {len(node_ids)} tasks done") + if message['type'] == 'executing': + data = message['data'] + if data['node'] not in finished_nodes: + finished_nodes.append(data['node']) + logger.info(f"Progress: {len(finished_nodes)} / {len(node_ids)} tasks done") + + if data['node'] is None and data['prompt_id'] == prompt_id: + return True # Execution is done successfully + except json.JSONDecodeError as e: + logger.warning(f"Error parsing websocket message: {str(e)}, skipping message") + continue + except KeyError as e: + logger.warning(f"Missing key in websocket message: {str(e)}, skipping message") + continue + else: + continue + except Exception as e: + error_count += 1 + logger.warning(f"Unexpected error in track_progress: {str(e)} (error {error_count}/{max_errors})") + if error_count >= max_errors: + logger.error("Too many errors in track_progress, aborting") + return False + time.sleep(1) + + return True def get_images(self, prompt): - output = self.queue_prompt(prompt) - if output is None: - raise("internal error") - prompt_id = output['prompt_id'] - output_images = {} - self.track_progress(prompt, prompt_id) - - history = self.get_history(prompt_id)[prompt_id] - for o in history['outputs']: - for node_id in history['outputs']: - node_output = history['outputs'][node_id] - # image branch - if 'images' in node_output: - images_output = [] - for image in node_output['images']: - image_data = self.get_image(image['filename'], image['subfolder'], image['type']) - images_output.append(image_data) - output_images[node_id] = images_output - # video branch - if 'videos' in node_output: - videos_output = [] - for video in node_output['videos']: - video_data = self.get_image(video['filename'], video['subfolder'], video['type']) - videos_output.append(video_data) - output_images[node_id] = videos_output - - return output_images + max_retries = 3 + retry_count = 0 + + while retry_count < max_retries: + try: + output = self.queue_prompt(prompt) + if output is None: + raise RuntimeError("Failed to queue prompt - internal error") + + prompt_id = output['prompt_id'] + output_images = {} + + self.track_progress(prompt, prompt_id) + + history = self.get_history(prompt_id)[prompt_id] + for o in history['outputs']: + for node_id in history['outputs']: + node_output = history['outputs'][node_id] + # image branch + if 'images' in node_output: + images_output = [] + for image in node_output['images']: + image_data = self.get_image(image['filename'], image['subfolder'], image['type']) + images_output.append(image_data) + output_images[node_id] = images_output + # video branch + if 'videos' in node_output: + videos_output = [] + for video in node_output['videos']: + video_data = self.get_image(video['filename'], video['subfolder'], video['type']) + videos_output.append(video_data) + output_images[node_id] = videos_output + + # If we got here, everything worked + return output_images + + except websocket.WebSocketConnectionClosedException as e: + retry_count += 1 + logger.warning(f"WebSocket connection closed during processing (attempt {retry_count}/{max_retries})") + + # Try to reconnect before retrying + self.connected = False + if retry_count < max_retries: + logger.info("Attempting to reconnect websocket...") + if self.wss_connect(): + logger.info("Reconnected successfully, retrying operation") + time.sleep(1) # Small delay before retry + else: + logger.error("Failed to reconnect websocket") + else: + logger.error(f"Failed after {max_retries} attempts") + raise RuntimeError(f"Failed to process images after {max_retries} attempts") from e + + except Exception as e: + logger.error(f"Error processing images: {str(e)}") + retry_count += 1 + + # For non-websocket errors, we might still want to try reconnecting the websocket + if not self.connected and retry_count < max_retries: + logger.info("Attempting to reconnect websocket...") + self.wss_connect() + time.sleep(1) # Small delay before retry + else: + # If it's not a connection issue or we've tried enough times, re-raise + if retry_count >= max_retries: + raise + + # This should not be reached, but just in case + raise RuntimeError(f"Failed to process images after {max_retries} attempts") def parse_worflow(self, prompt_data): logger.debug(prompt_data) @@ -131,43 +278,60 @@ def check_readiness(api_base_url: str) -> bool: cf = comfyuiCaller() cf.setUrl(api_base_url) logger.info("Init health check... ") - while True: - try: - logger.info(f"Try to connect to ComfyUI backend {api_base_url} ... ") - cf.wss_connect() - break - except Exception as e: - time.sleep(1) - continue - logger.info(f"ComfyUI backend {api_base_url} connected. ") - return True + try: + logger.info(f"Try to connect to ComfyUI backend {api_base_url} ... ") + if cf.wss_connect(): + logger.info(f"ComfyUI backend {api_base_url} connected.") + return True + else: + logger.error(f"Failed to connect to ComfyUI backend {api_base_url}") + return False + except Exception as e: + logger.error(f"Error during health check: {str(e)}") + return False -def handler(api_base_url: str, task_id: str, payload: dict) -> str: - response = {} +def handler(api_base_url: str, task_id: str, payload: dict) -> dict: + response = { + "success": False, + "image": [], + "content": '{"code": 500}' + } try: logger.info(f"Processing pipeline task with ID: {task_id}") - images = invoke_pipeline(api_base_url, payload) - # write to s3 - imgOutputs = post_invocations(images) - logger.info(f"Received {len(imgOutputs)} images") - content = '{"code": 200}' - response["success"] = True - response["image"] = imgOutputs - response["content"] = content - logger.info(f"End process pipeline task with ID: {task_id}") + + # Attempt to invoke the pipeline + try: + images = invoke_pipeline(api_base_url, payload) + + # Process images if available + imgOutputs = post_invocations(images) + logger.info(f"Received {len(imgOutputs)} images") + + # Set success response + response["success"] = True + response["image"] = imgOutputs + response["content"] = '{"code": 200}' + logger.info(f"End process pipeline task with ID: {task_id}") + except Exception as e: + logger.error(f"Error processing pipeline: {str(e)}") + # Keep default failure response except Exception as e: - logger.error(f"Pipeline task with ID: {task_id} finished with error") + # This is a catch-all for any unexpected errors + logger.error(f"Unexpected error in handler for task ID {task_id}: {str(e)}") traceback.print_exc() - response["success"] = False - response["content"] = '{"code": 500}' return response def invoke_pipeline(api_base_url: str, body) -> str: cf = comfyuiCaller() cf.setUrl(api_base_url) + + # Ensure websocket connection is established before proceeding + if not cf.wss_connect(): + raise ConnectionError(f"Failed to establish websocket connection to {api_base_url}") + return cf.parse_worflow(body) def post_invocations(image): From ea0f9cc7f6098266ceb06061995d3f85a1c2b9bb Mon Sep 17 00:00:00 2001 From: yubingjiaocn Date: Wed, 2 Apr 2025 09:02:36 +0000 Subject: [PATCH 2/5] Fix model switch bug for sd web ui --- src/backend/queue_agent/src/runtimes/sdwebui.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/backend/queue_agent/src/runtimes/sdwebui.py b/src/backend/queue_agent/src/runtimes/sdwebui.py index ffc4913..965810c 100644 --- a/src/backend/queue_agent/src/runtimes/sdwebui.py +++ b/src/backend/queue_agent/src/runtimes/sdwebui.py @@ -199,17 +199,17 @@ def switch_model(api_base_url: str, name: str) -> str: # refresh then check from model list invoke_refresh_checkpoints(api_base_url) models = invoke_get_model_names(api_base_url) - try: - invoke_unload_checkpoints(api_base_url) - except HTTPError: - logger.info(f"No model is currently loaded.") if name in models: + try: + invoke_unload_checkpoints(api_base_url) + except HTTPError: + logger.info(f"No model is currently loaded. Loading new model... ") options = {} options["sd_model_checkpoint"] = name invoke_set_options(api_base_url, options) current_model_name = name else: - logger.error(f"Model {name} not found.") + logger.error(f"Model {name} not found, keeping current model.") return None return current_model_name From 107d2286aacab3c7dadda05b4b4aae243aaf27f6 Mon Sep 17 00:00:00 2001 From: yubingjiaocn Date: Thu, 3 Apr 2025 05:38:02 +0000 Subject: [PATCH 3/5] Bump queue agent to Python 3.12 --- src/backend/queue_agent/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/backend/queue_agent/Dockerfile b/src/backend/queue_agent/Dockerfile index 71fe52c..083fe44 100644 --- a/src/backend/queue_agent/Dockerfile +++ b/src/backend/queue_agent/Dockerfile @@ -1,4 +1,4 @@ -FROM public.ecr.aws/docker/library/python:3.10-slim +FROM public.ecr.aws/docker/library/python:3.12-slim RUN apt-get update; apt-get install libmagic1 -y; rm -rf /var/lib/apt/lists/* From 87e61d3aca8c3b91c0f5e9d68f285e0a5e301abe Mon Sep 17 00:00:00 2001 From: yubingjiaocn Date: Wed, 9 Apr 2025 05:49:58 +0000 Subject: [PATCH 4/5] Add a check for current model to avoid accidently unload model --- src/backend/queue_agent/src/runtimes/sdwebui.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/backend/queue_agent/src/runtimes/sdwebui.py b/src/backend/queue_agent/src/runtimes/sdwebui.py index 965810c..27be50d 100644 --- a/src/backend/queue_agent/src/runtimes/sdwebui.py +++ b/src/backend/queue_agent/src/runtimes/sdwebui.py @@ -200,10 +200,12 @@ def switch_model(api_base_url: str, name: str) -> str: invoke_refresh_checkpoints(api_base_url) models = invoke_get_model_names(api_base_url) if name in models: - try: - invoke_unload_checkpoints(api_base_url) - except HTTPError: - logger.info(f"No model is currently loaded. Loading new model... ") + if ((current_model_name != None) or (current_model_name != "")): + logger.info(f"Model {current_model_name} is currently loaded, unloading... ") + try: + invoke_unload_checkpoints(api_base_url) + except HTTPError: + logger.info(f"No model is currently loaded. Loading new model... ") options = {} options["sd_model_checkpoint"] = name invoke_set_options(api_base_url, options) From 6ebe805ae67e63b508d0fdd4ae13412930813b20 Mon Sep 17 00:00:00 2001 From: yubingjiaocn Date: Wed, 9 Apr 2025 06:25:50 +0000 Subject: [PATCH 5/5] Bug fix --- src/backend/queue_agent/src/runtimes/sdwebui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/backend/queue_agent/src/runtimes/sdwebui.py b/src/backend/queue_agent/src/runtimes/sdwebui.py index 27be50d..9f9578c 100644 --- a/src/backend/queue_agent/src/runtimes/sdwebui.py +++ b/src/backend/queue_agent/src/runtimes/sdwebui.py @@ -200,7 +200,7 @@ def switch_model(api_base_url: str, name: str) -> str: invoke_refresh_checkpoints(api_base_url) models = invoke_get_model_names(api_base_url) if name in models: - if ((current_model_name != None) or (current_model_name != "")): + if (current_model_name != None): logger.info(f"Model {current_model_name} is currently loaded, unloading... ") try: invoke_unload_checkpoints(api_base_url)