Skip to content

Commit 47d859a

Browse files
authored
Merge pull request #56 from aws-solutions-library-samples/xray-fix
[Break Glass] Add error handling and fallback to X-ray
2 parents 3bd163f + 5898ebe commit 47d859a

File tree

12 files changed

+243
-138
lines changed

12 files changed

+243
-138
lines changed

lib/addons/s3CSIDriver.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ export class s3CSIDriverAddOn extends blueprints.addons.HelmAddOn {
3030
const cluster = clusterInfo.cluster;
3131
const serviceAccount = cluster.addServiceAccount('s3-csi-driver-sa', {
3232
name: 's3-csi-driver-sa',
33-
namespace: this.options.namespace,
34-
identityType: eks.IdentityType.POD_IDENTITY
33+
namespace: this.options.namespace
3534
});
3635

3736
// new IAM policy to grand access to S3 bucket

lib/runtime/sdRuntime.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ export const defaultProps: blueprints.addons.HelmAddOnProps & SDRuntimeAddOnProp
2828
name: 'sdRuntimeAddOn',
2929
namespace: 'sdruntime',
3030
release: 'sdruntime',
31-
version: '1.1.1',
31+
version: '1.1.3',
3232
repository: 'oci://public.ecr.aws/bingjiao/charts/sd-on-eks',
3333
values: {},
3434
type: "sdwebui"

src/backend/queue_agent/src/main.py

Lines changed: 170 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,86 @@
77
import signal
88
import sys
99
import uuid
10+
import time
11+
import functools
1012

1113
import boto3
14+
from botocore.exceptions import EndpointConnectionError
1215
from aws_xray_sdk.core import patch_all, xray_recorder
1316
from aws_xray_sdk.core.models.trace_header import TraceHeader
1417
from modules import s3_action, sns_action, sqs_action
1518
from runtimes import comfyui, sdwebui
1619

17-
patch_all()
18-
19-
# Logging configuration
20+
# Initialize logging first so we can log X-Ray initialization attempts
2021
logging.basicConfig()
2122
logging.getLogger().setLevel(logging.ERROR)
2223

24+
# Configure the queue-agent logger only once
2325
logger = logging.getLogger("queue-agent")
2426
logger.propagate = False
2527
logger.setLevel(os.environ.get('LOGLEVEL', 'INFO').upper())
28+
29+
# Remove any existing handlers to prevent duplicate logs
30+
if logger.handlers:
31+
logger.handlers.clear()
32+
33+
# Add a single handler
2634
handler = logging.StreamHandler(sys.stdout)
2735
handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
2836
logger.addHandler(handler)
2937

30-
# Set current logger as global
31-
logger = logging.getLogger("queue-agent")
38+
# Check if X-Ray is manually disabled via environment variable
39+
DISABLE_XRAY = os.environ.get('DISABLE_XRAY', 'false').lower() == 'true'
40+
if DISABLE_XRAY:
41+
logger.info("X-Ray tracing manually disabled via DISABLE_XRAY environment variable")
42+
xray_enabled = False
43+
else:
44+
# Try to initialize X-Ray SDK with retries, as the daemon might be starting up
45+
MAX_XRAY_INIT_ATTEMPTS = 5
46+
XRAY_RETRY_DELAY = 3 # seconds
47+
xray_enabled = False
48+
49+
for attempt in range(MAX_XRAY_INIT_ATTEMPTS):
50+
try:
51+
logger.info(f"Attempting to initialize X-Ray SDK (attempt {attempt+1}/{MAX_XRAY_INIT_ATTEMPTS})")
52+
patch_all()
53+
xray_enabled = True
54+
logger.info("X-Ray SDK initialized successfully")
55+
break
56+
except EndpointConnectionError:
57+
logger.warning(f"Could not connect to X-Ray daemon (attempt {attempt+1}/{MAX_XRAY_INIT_ATTEMPTS})")
58+
if attempt < MAX_XRAY_INIT_ATTEMPTS - 1:
59+
logger.info(f"Retrying in {XRAY_RETRY_DELAY} seconds...")
60+
time.sleep(XRAY_RETRY_DELAY)
61+
except Exception as e:
62+
logger.warning(f"Error initializing X-Ray: {str(e)} (attempt {attempt+1}/{MAX_XRAY_INIT_ATTEMPTS})")
63+
if attempt < MAX_XRAY_INIT_ATTEMPTS - 1:
64+
logger.info(f"Retrying in {XRAY_RETRY_DELAY} seconds...")
65+
time.sleep(XRAY_RETRY_DELAY)
66+
67+
if not xray_enabled:
68+
logger.warning("X-Ray initialization failed after all attempts. Tracing will be disabled.")
69+
70+
# Create a decorator for safe X-Ray instrumentation
71+
def safe_xray_capture(name):
72+
"""Decorator that safely applies X-Ray instrumentation if available"""
73+
def decorator(func):
74+
@functools.wraps(func)
75+
def wrapper(*args, **kwargs):
76+
if xray_enabled:
77+
try:
78+
# Try to use X-Ray instrumentation
79+
with xray_recorder.in_segment(name):
80+
return func(*args, **kwargs)
81+
except Exception as e:
82+
logger.warning(f"X-Ray instrumentation failed for {name}: {str(e)}")
83+
# Fall back to non-instrumented execution
84+
return func(*args, **kwargs)
85+
else:
86+
# X-Ray is disabled, just call the function directly
87+
return func(*args, **kwargs)
88+
return wrapper
89+
return decorator
3290

3391
# Get base environment variable
3492
aws_default_region = os.getenv("AWS_DEFAULT_REGION")
@@ -104,101 +162,115 @@ def main():
104162
received_messages = sqs_action.receive_messages(queue, 1, SQS_WAIT_TIME_SECONDS)
105163

106164
for message in received_messages:
107-
with xray_recorder.in_segment(runtime_name+"-queue-agent") as segment:
108-
# Retrieve x-ray trace header from SQS message
109-
if "AWSTraceHeader" in message.attributes.keys():
110-
traceHeaderStr = message.attributes['AWSTraceHeader']
111-
sqsTraceHeader = TraceHeader.from_header_str(traceHeaderStr)
112-
# Update current segment to link with SQS
113-
segment.trace_id = sqsTraceHeader.root
114-
segment.parent_id = sqsTraceHeader.parent
115-
segment.sampled = sqsTraceHeader.sampled
116-
117-
# Process received message
165+
# Process with X-Ray if enabled, otherwise just process the message directly
166+
if xray_enabled:
118167
try:
119-
payload = json.loads(json.loads(message.body)['Message'])
120-
metadata = payload["metadata"]
121-
task_id = metadata["id"]
122-
123-
logger.info(f"Received task {task_id}, processing")
124-
125-
if "prefix" in metadata.keys():
126-
if metadata["prefix"][-1] == '/':
127-
prefix = metadata["prefix"] + str(task_id)
128-
else:
129-
prefix = metadata["prefix"] + "/" + str(task_id)
130-
else:
131-
prefix = str(task_id)
132-
133-
if "tasktype" in metadata.keys():
134-
tasktype = metadata["tasktype"]
135-
136-
if "context" in metadata.keys():
137-
context = metadata["context"]
138-
else:
139-
context = {}
140-
141-
body = payload["content"]
142-
logger.debug(body)
168+
with xray_recorder.in_segment(runtime_name+"-queue-agent") as segment:
169+
# Retrieve x-ray trace header from SQS message
170+
if "AWSTraceHeader" in message.attributes.keys():
171+
traceHeaderStr = message.attributes['AWSTraceHeader']
172+
sqsTraceHeader = TraceHeader.from_header_str(traceHeaderStr)
173+
# Update current segment to link with SQS
174+
segment.trace_id = sqsTraceHeader.root
175+
segment.parent_id = sqsTraceHeader.parent
176+
segment.sampled = sqsTraceHeader.sampled
177+
178+
# Process the message within the X-Ray segment
179+
process_message(message, topic, s3_bucket, runtime_type, runtime_name, api_base_url, dynamic_sd_model if runtime_type == "sdwebui" else None)
143180
except Exception as e:
144-
logger.error(f"Error parsing message: {e}, skipping")
145-
logger.debug(payload)
146-
sqs_action.delete_message(message)
147-
continue
148-
149-
if (exp_callback_when_running.lower() == "true"):
150-
sns_response = {"runtime": runtime_name,
151-
'id': task_id,
152-
'status': "running",
153-
'context': context}
154-
155-
sns_action.publish_message(topic, json.dumps(sns_response))
156-
157-
# Start handling message
158-
response = {}
181+
logger.error(f"Error with X-Ray tracing: {str(e)}. Processing message without tracing.")
182+
process_message(message, topic, s3_bucket, runtime_type, runtime_name, api_base_url, dynamic_sd_model if runtime_type == "sdwebui" else None)
183+
else:
184+
# Process without X-Ray tracing
185+
process_message(message, topic, s3_bucket, runtime_type, runtime_name, api_base_url, dynamic_sd_model if runtime_type == "sdwebui" else None)
186+
187+
def process_message(message, topic, s3_bucket, runtime_type, runtime_name, api_base_url, dynamic_sd_model=None):
188+
"""Process a single SQS message"""
189+
# Process received message
190+
try:
191+
payload = json.loads(json.loads(message.body)['Message'])
192+
metadata = payload["metadata"]
193+
task_id = metadata["id"]
194+
195+
logger.info(f"Received task {task_id}, processing")
196+
197+
if "prefix" in metadata.keys():
198+
if metadata["prefix"][-1] == '/':
199+
prefix = metadata["prefix"] + str(task_id)
200+
else:
201+
prefix = metadata["prefix"] + "/" + str(task_id)
202+
else:
203+
prefix = str(task_id)
204+
205+
if "tasktype" in metadata.keys():
206+
tasktype = metadata["tasktype"]
207+
208+
if "context" in metadata.keys():
209+
context = metadata["context"]
210+
else:
211+
context = {}
212+
213+
body = payload["content"]
214+
logger.debug(body)
215+
except Exception as e:
216+
logger.error(f"Error parsing message: {e}, skipping")
217+
logger.debug(payload)
218+
sqs_action.delete_message(message)
219+
return
220+
221+
if (exp_callback_when_running.lower() == "true"):
222+
sns_response = {"runtime": runtime_name,
223+
'id': task_id,
224+
'status': "running",
225+
'context': context}
226+
227+
sns_action.publish_message(topic, json.dumps(sns_response))
228+
229+
# Start handling message
230+
response = {}
231+
232+
try:
233+
if runtime_type == "sdwebui":
234+
response = sdwebui.handler(api_base_url, tasktype, task_id, body, dynamic_sd_model)
235+
236+
if runtime_type == "comfyui":
237+
response = comfyui.handler(api_base_url, task_id, body)
238+
except Exception as e:
239+
logger.error(f"Error calling handler for task {task_id}: {str(e)}")
240+
response = {
241+
"success": False,
242+
"image": [],
243+
"content": '{"code": 500, "error": "Runtime handler failed"}'
244+
}
245+
246+
result = []
247+
rand = str(uuid.uuid4())[0:4]
248+
249+
if response["success"]:
250+
idx = 0
251+
if len(response["image"]) > 0:
252+
for i in response["image"]:
253+
idx += 1
254+
result.append(s3_action.upload_file(i, s3_bucket, prefix, str(task_id)+"-"+rand+"-"+str(idx)))
255+
256+
output_url = s3_action.upload_file(response["content"], s3_bucket, prefix, str(task_id)+"-"+rand, ".out")
257+
258+
if response["success"]:
259+
status = "completed"
260+
else:
261+
status = "failed"
159262

160-
try:
161-
if runtime_type == "sdwebui":
162-
response = sdwebui.handler(api_base_url, tasktype, task_id, body, dynamic_sd_model)
263+
sns_response = {"runtime": runtime_name,
264+
'id': task_id,
265+
'result': response["success"],
266+
'status': status,
267+
'image_url': result,
268+
'output_url': output_url,
269+
'context': context}
163270

164-
if runtime_type == "comfyui":
165-
response = comfyui.handler(api_base_url, task_id, body)
166-
except Exception as e:
167-
logger.error(f"Error calling handler for task {task_id}: {str(e)}")
168-
response = {
169-
"success": False,
170-
"image": [],
171-
"content": '{"code": 500, "error": "Runtime handler failed"}'
172-
}
173-
174-
result = []
175-
rand = str(uuid.uuid4())[0:4]
176-
177-
if response["success"]:
178-
idx = 0
179-
if len(response["image"]) > 0:
180-
for i in response["image"]:
181-
idx += 1
182-
result.append(s3_action.upload_file(i, s3_bucket, prefix, str(task_id)+"-"+rand+"-"+str(idx)))
183-
184-
output_url = s3_action.upload_file(response["content"], s3_bucket, prefix, str(task_id)+"-"+rand, ".out")
185-
186-
if response["success"]:
187-
status = "completed"
188-
else:
189-
status = "failed"
190-
191-
sns_response = {"runtime": runtime_name,
192-
'id': task_id,
193-
'result': response["success"],
194-
'status': status,
195-
'image_url': result,
196-
'output_url': output_url,
197-
'context': context}
198-
199-
# Put response handler to SNS and delete message
200-
sns_action.publish_message(topic, json.dumps(sns_response))
201-
sqs_action.delete_message(message)
271+
# Put response handler to SNS and delete message
272+
sns_action.publish_message(topic, json.dumps(sns_response))
273+
sqs_action.delete_message(message)
202274

203275
def print_env() -> None:
204276
logger.info(f'AWS_DEFAULT_REGION={aws_default_region}')
@@ -207,6 +279,8 @@ def print_env() -> None:
207279
logger.info(f'S3_BUCKET={s3_bucket}')
208280
logger.info(f'RUNTIME_TYPE={runtime_type}')
209281
logger.info(f'RUNTIME_NAME={runtime_name}')
282+
logger.info(f'X-Ray Tracing: {"Disabled" if DISABLE_XRAY else "Enabled"}')
283+
logger.info(f'X-Ray Status: {"Active" if xray_enabled else "Inactive"}')
210284

211285
def signalHandler(signum, frame):
212286
global shutdown

src/backend/queue_agent/src/runtimes/comfyui.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,30 @@
66
import time
77
import traceback
88
import urllib.parse
9-
import urllib.request
109
import uuid
1110
from typing import Optional, Dict, List, Any, Union
1211

1312
import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
14-
from aws_xray_sdk.core import xray_recorder
1513
from modules import http_action
1614

1715
logger = logging.getLogger("queue-agent")
1816

17+
# Import the safe_xray_capture decorator from main module
18+
try:
19+
from src.main import safe_xray_capture, xray_enabled
20+
except ImportError:
21+
try:
22+
# Try alternative import path
23+
from ..main import safe_xray_capture, xray_enabled
24+
except ImportError:
25+
# Fallback if import fails - create a simple pass-through decorator
26+
logger.warning("Failed to import safe_xray_capture from main, using fallback")
27+
def safe_xray_capture(name):
28+
def decorator(func):
29+
return func
30+
return decorator
31+
xray_enabled = False
32+
1933
# Constants for websocket reconnection
2034
MAX_RECONNECT_ATTEMPTS = 5
2135
RECONNECT_DELAY = 2 # seconds
@@ -324,6 +338,7 @@ def handler(api_base_url: str, task_id: str, payload: dict) -> dict:
324338

325339
return response
326340

341+
@safe_xray_capture('comfyui-pipeline')
327342
def invoke_pipeline(api_base_url: str, body) -> str:
328343
cf = comfyuiCaller()
329344
cf.setUrl(api_base_url)

0 commit comments

Comments
 (0)