7
7
import signal
8
8
import sys
9
9
import uuid
10
+ import time
11
+ import functools
10
12
11
13
import boto3
14
+ from botocore .exceptions import EndpointConnectionError
12
15
from aws_xray_sdk .core import patch_all , xray_recorder
13
16
from aws_xray_sdk .core .models .trace_header import TraceHeader
14
17
from modules import s3_action , sns_action , sqs_action
15
18
from runtimes import comfyui , sdwebui
16
19
17
- patch_all ()
18
-
19
- # Logging configuration
20
+ # Initialize logging first so we can log X-Ray initialization attempts
20
21
logging .basicConfig ()
21
22
logging .getLogger ().setLevel (logging .ERROR )
22
23
24
+ # Configure the queue-agent logger only once
23
25
logger = logging .getLogger ("queue-agent" )
24
26
logger .propagate = False
25
27
logger .setLevel (os .environ .get ('LOGLEVEL' , 'INFO' ).upper ())
28
+
29
+ # Remove any existing handlers to prevent duplicate logs
30
+ if logger .handlers :
31
+ logger .handlers .clear ()
32
+
33
+ # Add a single handler
26
34
handler = logging .StreamHandler (sys .stdout )
27
35
handler .setFormatter (logging .Formatter ('%(asctime)s - %(levelname)s - %(message)s' ))
28
36
logger .addHandler (handler )
29
37
30
- # Set current logger as global
31
- logger = logging .getLogger ("queue-agent" )
38
+ # Check if X-Ray is manually disabled via environment variable
39
+ DISABLE_XRAY = os .environ .get ('DISABLE_XRAY' , 'false' ).lower () == 'true'
40
+ if DISABLE_XRAY :
41
+ logger .info ("X-Ray tracing manually disabled via DISABLE_XRAY environment variable" )
42
+ xray_enabled = False
43
+ else :
44
+ # Try to initialize X-Ray SDK with retries, as the daemon might be starting up
45
+ MAX_XRAY_INIT_ATTEMPTS = 5
46
+ XRAY_RETRY_DELAY = 3 # seconds
47
+ xray_enabled = False
48
+
49
+ for attempt in range (MAX_XRAY_INIT_ATTEMPTS ):
50
+ try :
51
+ logger .info (f"Attempting to initialize X-Ray SDK (attempt { attempt + 1 } /{ MAX_XRAY_INIT_ATTEMPTS } )" )
52
+ patch_all ()
53
+ xray_enabled = True
54
+ logger .info ("X-Ray SDK initialized successfully" )
55
+ break
56
+ except EndpointConnectionError :
57
+ logger .warning (f"Could not connect to X-Ray daemon (attempt { attempt + 1 } /{ MAX_XRAY_INIT_ATTEMPTS } )" )
58
+ if attempt < MAX_XRAY_INIT_ATTEMPTS - 1 :
59
+ logger .info (f"Retrying in { XRAY_RETRY_DELAY } seconds..." )
60
+ time .sleep (XRAY_RETRY_DELAY )
61
+ except Exception as e :
62
+ logger .warning (f"Error initializing X-Ray: { str (e )} (attempt { attempt + 1 } /{ MAX_XRAY_INIT_ATTEMPTS } )" )
63
+ if attempt < MAX_XRAY_INIT_ATTEMPTS - 1 :
64
+ logger .info (f"Retrying in { XRAY_RETRY_DELAY } seconds..." )
65
+ time .sleep (XRAY_RETRY_DELAY )
66
+
67
+ if not xray_enabled :
68
+ logger .warning ("X-Ray initialization failed after all attempts. Tracing will be disabled." )
69
+
70
+ # Create a decorator for safe X-Ray instrumentation
71
+ def safe_xray_capture (name ):
72
+ """Decorator that safely applies X-Ray instrumentation if available"""
73
+ def decorator (func ):
74
+ @functools .wraps (func )
75
+ def wrapper (* args , ** kwargs ):
76
+ if xray_enabled :
77
+ try :
78
+ # Try to use X-Ray instrumentation
79
+ with xray_recorder .in_segment (name ):
80
+ return func (* args , ** kwargs )
81
+ except Exception as e :
82
+ logger .warning (f"X-Ray instrumentation failed for { name } : { str (e )} " )
83
+ # Fall back to non-instrumented execution
84
+ return func (* args , ** kwargs )
85
+ else :
86
+ # X-Ray is disabled, just call the function directly
87
+ return func (* args , ** kwargs )
88
+ return wrapper
89
+ return decorator
32
90
33
91
# Get base environment variable
34
92
aws_default_region = os .getenv ("AWS_DEFAULT_REGION" )
@@ -104,101 +162,115 @@ def main():
104
162
received_messages = sqs_action .receive_messages (queue , 1 , SQS_WAIT_TIME_SECONDS )
105
163
106
164
for message in received_messages :
107
- with xray_recorder .in_segment (runtime_name + "-queue-agent" ) as segment :
108
- # Retrieve x-ray trace header from SQS message
109
- if "AWSTraceHeader" in message .attributes .keys ():
110
- traceHeaderStr = message .attributes ['AWSTraceHeader' ]
111
- sqsTraceHeader = TraceHeader .from_header_str (traceHeaderStr )
112
- # Update current segment to link with SQS
113
- segment .trace_id = sqsTraceHeader .root
114
- segment .parent_id = sqsTraceHeader .parent
115
- segment .sampled = sqsTraceHeader .sampled
116
-
117
- # Process received message
165
+ # Process with X-Ray if enabled, otherwise just process the message directly
166
+ if xray_enabled :
118
167
try :
119
- payload = json .loads (json .loads (message .body )['Message' ])
120
- metadata = payload ["metadata" ]
121
- task_id = metadata ["id" ]
122
-
123
- logger .info (f"Received task { task_id } , processing" )
124
-
125
- if "prefix" in metadata .keys ():
126
- if metadata ["prefix" ][- 1 ] == '/' :
127
- prefix = metadata ["prefix" ] + str (task_id )
128
- else :
129
- prefix = metadata ["prefix" ] + "/" + str (task_id )
130
- else :
131
- prefix = str (task_id )
132
-
133
- if "tasktype" in metadata .keys ():
134
- tasktype = metadata ["tasktype" ]
135
-
136
- if "context" in metadata .keys ():
137
- context = metadata ["context" ]
138
- else :
139
- context = {}
140
-
141
- body = payload ["content" ]
142
- logger .debug (body )
168
+ with xray_recorder .in_segment (runtime_name + "-queue-agent" ) as segment :
169
+ # Retrieve x-ray trace header from SQS message
170
+ if "AWSTraceHeader" in message .attributes .keys ():
171
+ traceHeaderStr = message .attributes ['AWSTraceHeader' ]
172
+ sqsTraceHeader = TraceHeader .from_header_str (traceHeaderStr )
173
+ # Update current segment to link with SQS
174
+ segment .trace_id = sqsTraceHeader .root
175
+ segment .parent_id = sqsTraceHeader .parent
176
+ segment .sampled = sqsTraceHeader .sampled
177
+
178
+ # Process the message within the X-Ray segment
179
+ process_message (message , topic , s3_bucket , runtime_type , runtime_name , api_base_url , dynamic_sd_model if runtime_type == "sdwebui" else None )
143
180
except Exception as e :
144
- logger .error (f"Error parsing message: { e } , skipping" )
145
- logger .debug (payload )
146
- sqs_action .delete_message (message )
147
- continue
148
-
149
- if (exp_callback_when_running .lower () == "true" ):
150
- sns_response = {"runtime" : runtime_name ,
151
- 'id' : task_id ,
152
- 'status' : "running" ,
153
- 'context' : context }
154
-
155
- sns_action .publish_message (topic , json .dumps (sns_response ))
156
-
157
- # Start handling message
158
- response = {}
181
+ logger .error (f"Error with X-Ray tracing: { str (e )} . Processing message without tracing." )
182
+ process_message (message , topic , s3_bucket , runtime_type , runtime_name , api_base_url , dynamic_sd_model if runtime_type == "sdwebui" else None )
183
+ else :
184
+ # Process without X-Ray tracing
185
+ process_message (message , topic , s3_bucket , runtime_type , runtime_name , api_base_url , dynamic_sd_model if runtime_type == "sdwebui" else None )
186
+
187
+ def process_message (message , topic , s3_bucket , runtime_type , runtime_name , api_base_url , dynamic_sd_model = None ):
188
+ """Process a single SQS message"""
189
+ # Process received message
190
+ try :
191
+ payload = json .loads (json .loads (message .body )['Message' ])
192
+ metadata = payload ["metadata" ]
193
+ task_id = metadata ["id" ]
194
+
195
+ logger .info (f"Received task { task_id } , processing" )
196
+
197
+ if "prefix" in metadata .keys ():
198
+ if metadata ["prefix" ][- 1 ] == '/' :
199
+ prefix = metadata ["prefix" ] + str (task_id )
200
+ else :
201
+ prefix = metadata ["prefix" ] + "/" + str (task_id )
202
+ else :
203
+ prefix = str (task_id )
204
+
205
+ if "tasktype" in metadata .keys ():
206
+ tasktype = metadata ["tasktype" ]
207
+
208
+ if "context" in metadata .keys ():
209
+ context = metadata ["context" ]
210
+ else :
211
+ context = {}
212
+
213
+ body = payload ["content" ]
214
+ logger .debug (body )
215
+ except Exception as e :
216
+ logger .error (f"Error parsing message: { e } , skipping" )
217
+ logger .debug (payload )
218
+ sqs_action .delete_message (message )
219
+ return
220
+
221
+ if (exp_callback_when_running .lower () == "true" ):
222
+ sns_response = {"runtime" : runtime_name ,
223
+ 'id' : task_id ,
224
+ 'status' : "running" ,
225
+ 'context' : context }
226
+
227
+ sns_action .publish_message (topic , json .dumps (sns_response ))
228
+
229
+ # Start handling message
230
+ response = {}
231
+
232
+ try :
233
+ if runtime_type == "sdwebui" :
234
+ response = sdwebui .handler (api_base_url , tasktype , task_id , body , dynamic_sd_model )
235
+
236
+ if runtime_type == "comfyui" :
237
+ response = comfyui .handler (api_base_url , task_id , body )
238
+ except Exception as e :
239
+ logger .error (f"Error calling handler for task { task_id } : { str (e )} " )
240
+ response = {
241
+ "success" : False ,
242
+ "image" : [],
243
+ "content" : '{"code": 500, "error": "Runtime handler failed"}'
244
+ }
245
+
246
+ result = []
247
+ rand = str (uuid .uuid4 ())[0 :4 ]
248
+
249
+ if response ["success" ]:
250
+ idx = 0
251
+ if len (response ["image" ]) > 0 :
252
+ for i in response ["image" ]:
253
+ idx += 1
254
+ result .append (s3_action .upload_file (i , s3_bucket , prefix , str (task_id )+ "-" + rand + "-" + str (idx )))
255
+
256
+ output_url = s3_action .upload_file (response ["content" ], s3_bucket , prefix , str (task_id )+ "-" + rand , ".out" )
257
+
258
+ if response ["success" ]:
259
+ status = "completed"
260
+ else :
261
+ status = "failed"
159
262
160
- try :
161
- if runtime_type == "sdwebui" :
162
- response = sdwebui .handler (api_base_url , tasktype , task_id , body , dynamic_sd_model )
263
+ sns_response = {"runtime" : runtime_name ,
264
+ 'id' : task_id ,
265
+ 'result' : response ["success" ],
266
+ 'status' : status ,
267
+ 'image_url' : result ,
268
+ 'output_url' : output_url ,
269
+ 'context' : context }
163
270
164
- if runtime_type == "comfyui" :
165
- response = comfyui .handler (api_base_url , task_id , body )
166
- except Exception as e :
167
- logger .error (f"Error calling handler for task { task_id } : { str (e )} " )
168
- response = {
169
- "success" : False ,
170
- "image" : [],
171
- "content" : '{"code": 500, "error": "Runtime handler failed"}'
172
- }
173
-
174
- result = []
175
- rand = str (uuid .uuid4 ())[0 :4 ]
176
-
177
- if response ["success" ]:
178
- idx = 0
179
- if len (response ["image" ]) > 0 :
180
- for i in response ["image" ]:
181
- idx += 1
182
- result .append (s3_action .upload_file (i , s3_bucket , prefix , str (task_id )+ "-" + rand + "-" + str (idx )))
183
-
184
- output_url = s3_action .upload_file (response ["content" ], s3_bucket , prefix , str (task_id )+ "-" + rand , ".out" )
185
-
186
- if response ["success" ]:
187
- status = "completed"
188
- else :
189
- status = "failed"
190
-
191
- sns_response = {"runtime" : runtime_name ,
192
- 'id' : task_id ,
193
- 'result' : response ["success" ],
194
- 'status' : status ,
195
- 'image_url' : result ,
196
- 'output_url' : output_url ,
197
- 'context' : context }
198
-
199
- # Put response handler to SNS and delete message
200
- sns_action .publish_message (topic , json .dumps (sns_response ))
201
- sqs_action .delete_message (message )
271
+ # Put response handler to SNS and delete message
272
+ sns_action .publish_message (topic , json .dumps (sns_response ))
273
+ sqs_action .delete_message (message )
202
274
203
275
def print_env () -> None :
204
276
logger .info (f'AWS_DEFAULT_REGION={ aws_default_region } ' )
@@ -207,6 +279,8 @@ def print_env() -> None:
207
279
logger .info (f'S3_BUCKET={ s3_bucket } ' )
208
280
logger .info (f'RUNTIME_TYPE={ runtime_type } ' )
209
281
logger .info (f'RUNTIME_NAME={ runtime_name } ' )
282
+ logger .info (f'X-Ray Tracing: { "Disabled" if DISABLE_XRAY else "Enabled" } ' )
283
+ logger .info (f'X-Ray Status: { "Active" if xray_enabled else "Inactive" } ' )
210
284
211
285
def signalHandler (signum , frame ):
212
286
global shutdown
0 commit comments