Skip to content

[Break Glass] Queue Agent fix #57

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions lib/addons/s3CSIDriver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ export class s3CSIDriverAddOn extends blueprints.addons.HelmAddOn {
const cluster = clusterInfo.cluster;
const serviceAccount = cluster.addServiceAccount('s3-csi-driver-sa', {
name: 's3-csi-driver-sa',
namespace: this.options.namespace,
identityType: eks.IdentityType.POD_IDENTITY
namespace: this.options.namespace
});

// new IAM policy to grand access to S3 bucket
Expand Down
2 changes: 1 addition & 1 deletion lib/runtime/sdRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ export const defaultProps: blueprints.addons.HelmAddOnProps & SDRuntimeAddOnProp
name: 'sdRuntimeAddOn',
namespace: 'sdruntime',
release: 'sdruntime',
version: '1.1.1',
version: '1.1.3',
repository: 'oci://public.ecr.aws/bingjiao/charts/sd-on-eks',
values: {},
type: "sdwebui"
Expand Down
266 changes: 170 additions & 96 deletions src/backend/queue_agent/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,86 @@
import signal
import sys
import uuid
import time
import functools

import boto3
from botocore.exceptions import EndpointConnectionError
from aws_xray_sdk.core import patch_all, xray_recorder
from aws_xray_sdk.core.models.trace_header import TraceHeader
from modules import s3_action, sns_action, sqs_action
from runtimes import comfyui, sdwebui

patch_all()

# Logging configuration
# Initialize logging first so we can log X-Ray initialization attempts
logging.basicConfig()
logging.getLogger().setLevel(logging.ERROR)

# Configure the queue-agent logger only once
logger = logging.getLogger("queue-agent")
logger.propagate = False
logger.setLevel(os.environ.get('LOGLEVEL', 'INFO').upper())

# Remove any existing handlers to prevent duplicate logs
if logger.handlers:
logger.handlers.clear()

# Add a single handler
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)

# Set current logger as global
logger = logging.getLogger("queue-agent")
# Check if X-Ray is manually disabled via environment variable
DISABLE_XRAY = os.environ.get('DISABLE_XRAY', 'false').lower() == 'true'
if DISABLE_XRAY:
logger.info("X-Ray tracing manually disabled via DISABLE_XRAY environment variable")
xray_enabled = False
else:
# Try to initialize X-Ray SDK with retries, as the daemon might be starting up
MAX_XRAY_INIT_ATTEMPTS = 5
XRAY_RETRY_DELAY = 3 # seconds
xray_enabled = False

for attempt in range(MAX_XRAY_INIT_ATTEMPTS):
try:
logger.info(f"Attempting to initialize X-Ray SDK (attempt {attempt+1}/{MAX_XRAY_INIT_ATTEMPTS})")
patch_all()
xray_enabled = True
logger.info("X-Ray SDK initialized successfully")
break
except EndpointConnectionError:
logger.warning(f"Could not connect to X-Ray daemon (attempt {attempt+1}/{MAX_XRAY_INIT_ATTEMPTS})")
if attempt < MAX_XRAY_INIT_ATTEMPTS - 1:
logger.info(f"Retrying in {XRAY_RETRY_DELAY} seconds...")
time.sleep(XRAY_RETRY_DELAY)
except Exception as e:
logger.warning(f"Error initializing X-Ray: {str(e)} (attempt {attempt+1}/{MAX_XRAY_INIT_ATTEMPTS})")
if attempt < MAX_XRAY_INIT_ATTEMPTS - 1:
logger.info(f"Retrying in {XRAY_RETRY_DELAY} seconds...")
time.sleep(XRAY_RETRY_DELAY)

if not xray_enabled:
logger.warning("X-Ray initialization failed after all attempts. Tracing will be disabled.")

# Create a decorator for safe X-Ray instrumentation
def safe_xray_capture(name):
"""Decorator that safely applies X-Ray instrumentation if available"""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if xray_enabled:
try:
# Try to use X-Ray instrumentation
with xray_recorder.in_segment(name):
return func(*args, **kwargs)
except Exception as e:
logger.warning(f"X-Ray instrumentation failed for {name}: {str(e)}")
# Fall back to non-instrumented execution
return func(*args, **kwargs)
else:
# X-Ray is disabled, just call the function directly
return func(*args, **kwargs)
return wrapper
return decorator

# Get base environment variable
aws_default_region = os.getenv("AWS_DEFAULT_REGION")
Expand Down Expand Up @@ -104,101 +162,115 @@ def main():
received_messages = sqs_action.receive_messages(queue, 1, SQS_WAIT_TIME_SECONDS)

for message in received_messages:
with xray_recorder.in_segment(runtime_name+"-queue-agent") as segment:
# Retrieve x-ray trace header from SQS message
if "AWSTraceHeader" in message.attributes.keys():
traceHeaderStr = message.attributes['AWSTraceHeader']
sqsTraceHeader = TraceHeader.from_header_str(traceHeaderStr)
# Update current segment to link with SQS
segment.trace_id = sqsTraceHeader.root
segment.parent_id = sqsTraceHeader.parent
segment.sampled = sqsTraceHeader.sampled

# Process received message
# Process with X-Ray if enabled, otherwise just process the message directly
if xray_enabled:
try:
payload = json.loads(json.loads(message.body)['Message'])
metadata = payload["metadata"]
task_id = metadata["id"]

logger.info(f"Received task {task_id}, processing")

if "prefix" in metadata.keys():
if metadata["prefix"][-1] == '/':
prefix = metadata["prefix"] + str(task_id)
else:
prefix = metadata["prefix"] + "/" + str(task_id)
else:
prefix = str(task_id)

if "tasktype" in metadata.keys():
tasktype = metadata["tasktype"]

if "context" in metadata.keys():
context = metadata["context"]
else:
context = {}

body = payload["content"]
logger.debug(body)
with xray_recorder.in_segment(runtime_name+"-queue-agent") as segment:
# Retrieve x-ray trace header from SQS message
if "AWSTraceHeader" in message.attributes.keys():
traceHeaderStr = message.attributes['AWSTraceHeader']
sqsTraceHeader = TraceHeader.from_header_str(traceHeaderStr)
# Update current segment to link with SQS
segment.trace_id = sqsTraceHeader.root
segment.parent_id = sqsTraceHeader.parent
segment.sampled = sqsTraceHeader.sampled

# Process the message within the X-Ray segment
process_message(message, topic, s3_bucket, runtime_type, runtime_name, api_base_url, dynamic_sd_model if runtime_type == "sdwebui" else None)
except Exception as e:
logger.error(f"Error parsing message: {e}, skipping")
logger.debug(payload)
sqs_action.delete_message(message)
continue

if (exp_callback_when_running.lower() == "true"):
sns_response = {"runtime": runtime_name,
'id': task_id,
'status': "running",
'context': context}

sns_action.publish_message(topic, json.dumps(sns_response))

# Start handling message
response = {}
logger.error(f"Error with X-Ray tracing: {str(e)}. Processing message without tracing.")
process_message(message, topic, s3_bucket, runtime_type, runtime_name, api_base_url, dynamic_sd_model if runtime_type == "sdwebui" else None)
else:
# Process without X-Ray tracing
process_message(message, topic, s3_bucket, runtime_type, runtime_name, api_base_url, dynamic_sd_model if runtime_type == "sdwebui" else None)

def process_message(message, topic, s3_bucket, runtime_type, runtime_name, api_base_url, dynamic_sd_model=None):
"""Process a single SQS message"""
# Process received message
try:
payload = json.loads(json.loads(message.body)['Message'])
metadata = payload["metadata"]
task_id = metadata["id"]

logger.info(f"Received task {task_id}, processing")

if "prefix" in metadata.keys():
if metadata["prefix"][-1] == '/':
prefix = metadata["prefix"] + str(task_id)
else:
prefix = metadata["prefix"] + "/" + str(task_id)
else:
prefix = str(task_id)

if "tasktype" in metadata.keys():
tasktype = metadata["tasktype"]

if "context" in metadata.keys():
context = metadata["context"]
else:
context = {}

body = payload["content"]
logger.debug(body)
except Exception as e:
logger.error(f"Error parsing message: {e}, skipping")
logger.debug(payload)
sqs_action.delete_message(message)
return

if (exp_callback_when_running.lower() == "true"):
sns_response = {"runtime": runtime_name,
'id': task_id,
'status': "running",
'context': context}

sns_action.publish_message(topic, json.dumps(sns_response))

# Start handling message
response = {}

try:
if runtime_type == "sdwebui":
response = sdwebui.handler(api_base_url, tasktype, task_id, body, dynamic_sd_model)

if runtime_type == "comfyui":
response = comfyui.handler(api_base_url, task_id, body)
except Exception as e:
logger.error(f"Error calling handler for task {task_id}: {str(e)}")
response = {
"success": False,
"image": [],
"content": '{"code": 500, "error": "Runtime handler failed"}'
}

result = []
rand = str(uuid.uuid4())[0:4]

if response["success"]:
idx = 0
if len(response["image"]) > 0:
for i in response["image"]:
idx += 1
result.append(s3_action.upload_file(i, s3_bucket, prefix, str(task_id)+"-"+rand+"-"+str(idx)))

output_url = s3_action.upload_file(response["content"], s3_bucket, prefix, str(task_id)+"-"+rand, ".out")

if response["success"]:
status = "completed"
else:
status = "failed"

try:
if runtime_type == "sdwebui":
response = sdwebui.handler(api_base_url, tasktype, task_id, body, dynamic_sd_model)
sns_response = {"runtime": runtime_name,
'id': task_id,
'result': response["success"],
'status': status,
'image_url': result,
'output_url': output_url,
'context': context}

if runtime_type == "comfyui":
response = comfyui.handler(api_base_url, task_id, body)
except Exception as e:
logger.error(f"Error calling handler for task {task_id}: {str(e)}")
response = {
"success": False,
"image": [],
"content": '{"code": 500, "error": "Runtime handler failed"}'
}

result = []
rand = str(uuid.uuid4())[0:4]

if response["success"]:
idx = 0
if len(response["image"]) > 0:
for i in response["image"]:
idx += 1
result.append(s3_action.upload_file(i, s3_bucket, prefix, str(task_id)+"-"+rand+"-"+str(idx)))

output_url = s3_action.upload_file(response["content"], s3_bucket, prefix, str(task_id)+"-"+rand, ".out")

if response["success"]:
status = "completed"
else:
status = "failed"

sns_response = {"runtime": runtime_name,
'id': task_id,
'result': response["success"],
'status': status,
'image_url': result,
'output_url': output_url,
'context': context}

# Put response handler to SNS and delete message
sns_action.publish_message(topic, json.dumps(sns_response))
sqs_action.delete_message(message)
# Put response handler to SNS and delete message
sns_action.publish_message(topic, json.dumps(sns_response))
sqs_action.delete_message(message)

def print_env() -> None:
logger.info(f'AWS_DEFAULT_REGION={aws_default_region}')
Expand All @@ -207,6 +279,8 @@ def print_env() -> None:
logger.info(f'S3_BUCKET={s3_bucket}')
logger.info(f'RUNTIME_TYPE={runtime_type}')
logger.info(f'RUNTIME_NAME={runtime_name}')
logger.info(f'X-Ray Tracing: {"Disabled" if DISABLE_XRAY else "Enabled"}')
logger.info(f'X-Ray Status: {"Active" if xray_enabled else "Inactive"}')

def signalHandler(signum, frame):
global shutdown
Expand Down
19 changes: 17 additions & 2 deletions src/backend/queue_agent/src/runtimes/comfyui.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,30 @@
import time
import traceback
import urllib.parse
import urllib.request
import uuid
from typing import Optional, Dict, List, Any, Union

import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
from aws_xray_sdk.core import xray_recorder
from modules import http_action

logger = logging.getLogger("queue-agent")

# Import the safe_xray_capture decorator from main module
try:
from src.main import safe_xray_capture, xray_enabled
except ImportError:
try:
# Try alternative import path
from ..main import safe_xray_capture, xray_enabled
except ImportError:
# Fallback if import fails - create a simple pass-through decorator
logger.warning("Failed to import safe_xray_capture from main, using fallback")
def safe_xray_capture(name):
def decorator(func):
return func
return decorator
xray_enabled = False

# Constants for websocket reconnection
MAX_RECONNECT_ATTEMPTS = 5
RECONNECT_DELAY = 2 # seconds
Expand Down Expand Up @@ -324,6 +338,7 @@ def handler(api_base_url: str, task_id: str, payload: dict) -> dict:

return response

@safe_xray_capture('comfyui-pipeline')
def invoke_pipeline(api_base_url: str, body) -> str:
cf = comfyuiCaller()
cf.setUrl(api_base_url)
Expand Down
Loading
Loading