@@ -1,16 +1,15 @@
-import logging
 import sys
 import tqdm
-import dspy
 import signal
+import logging
 import threading
 import traceback
 import contextlib
 
+from contextvars import copy_context
 from tqdm.contrib.logging import logging_redirect_tqdm
 from concurrent.futures import ThreadPoolExecutor, as_completed
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -23,6 +22,8 @@ def __init__(
         provide_traceback=False,
         compare_results=False,
     ):
+        """Offers isolation between tasks (dspy.settings) regardless of whether num_threads == 1 or > 1."""
+
         self.num_threads = num_threads
         self.disable_progress_bar = disable_progress_bar
         self.max_errors = max_errors
@@ -33,34 +34,18 @@ def __init__(
         self.error_lock = threading.Lock()
         self.cancel_jobs = threading.Event()
 
-
     def execute(self, function, data):
         wrapped_function = self._wrap_function(function)
         if self.num_threads == 1:
-            return self._execute_single_thread(wrapped_function, data)
+            return self._execute_isolated_single_thread(wrapped_function, data)
         else:
             return self._execute_multi_thread(wrapped_function, data)
 
-
     def _wrap_function(self, function):
-        # Wrap the function with threading context and error handling
-        def wrapped(item, parent_id=None):
-            thread_stacks = dspy.settings.stack_by_thread
-            current_thread_id = threading.get_ident()
-            creating_new_thread = current_thread_id not in thread_stacks
-
-            assert creating_new_thread or threading.get_ident() == dspy.settings.main_tid
-
-            if creating_new_thread:
-                # If we have a parent thread ID, copy its stack. TODO: Should the caller just pass a copy of the stack?
-                if parent_id and parent_id in thread_stacks:
-                    thread_stacks[current_thread_id] = list(thread_stacks[parent_id])
-                else:
-                    thread_stacks[current_thread_id] = list(dspy.settings.main_stack)
-
-                # TODO: Consider the behavior below.
-                # import copy; thread_stacks[current_thread_id].append(copy.deepcopy(thread_stacks[current_thread_id][-1]))
-
+        # Wrap the function with error handling
+        def wrapped(item):
+            if self.cancel_jobs.is_set():
+                return None
             try:
                 return function(item)
             except Exception as e:
@@ -79,45 +64,54 @@ def wrapped(item, parent_id=None):
7964 f"Error processing item { item } : { e } . Set `provide_traceback=True` to see the stack trace."
8065 )
8166 return None
82- finally :
83- if creating_new_thread :
84- del thread_stacks [threading .get_ident ()]
8567 return wrapped
8668
87-
88- def _execute_single_thread (self , function , data ):
69+ def _execute_isolated_single_thread (self , function , data ):
8970 results = []
9071 pbar = tqdm .tqdm (
9172 total = len (data ),
9273 dynamic_ncols = True ,
9374 disable = self .disable_progress_bar ,
94- file = sys .stdout ,
75+ file = sys .stdout
9576 )
77+
9678 for item in data :
9779 with logging_redirect_tqdm ():
9880 if self .cancel_jobs .is_set ():
9981 break
100- result = function (item )
82+
83+ # Create an isolated context for each task
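+                # copy_context() snapshots the current ContextVars, so state the
+                # task sets (e.g. dspy.settings overrides) stays local to this run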
+                task_ctx = copy_context()
+                result = task_ctx.run(function, item)
                 results.append(result)
+
                 if self.compare_results:
                     # Assumes score is the last element of the result tuple
-                    self._update_progress(pbar, sum([r[-1] for r in results if r is not None]), len([r for r in data if r is not None]))
+                    self._update_progress(
+                        pbar,
+                        sum([r[-1] for r in results if r is not None]),
+                        len([r for r in data if r is not None]),
+                    )
                 else:
                     self._update_progress(pbar, len(results), len(data))
+
         pbar.close()
+
         if self.cancel_jobs.is_set():
             logger.warning("Execution was cancelled due to errors.")
             raise Exception("Execution was cancelled due to errors.")
-        return results
 
+        return results
 
     def _update_progress(self, pbar, nresults, ntotal):
         if self.compare_results:
-            pbar.set_description(f"Average Metric: {nresults:.2f} / {ntotal} ({round(100 * nresults / ntotal, 1) if ntotal > 0 else 0}%)")
+            percentage = round(100 * nresults / ntotal, 1) if ntotal > 0 else 0
+            pbar.set_description(f"Average Metric: {nresults:.2f} / {ntotal} ({percentage}%)")
         else:
             pbar.set_description(f"Processed {nresults} / {ntotal} examples")
-        pbar.update()
 
+        pbar.update()
 
     def _execute_multi_thread(self, function, data):
         results = [None] * len(data)  # Pre-allocate results list to maintain order
@@ -132,6 +125,7 @@ def interrupt_handler_manager():
                 def interrupt_handler(sig, frame):
                     self.cancel_jobs.set()
                     logger.warning("Received SIGINT. Cancelling execution.")
+                    # Re-raise the signal to allow default behavior
                     default_handler(sig, frame)
 
                 signal.signal(signal.SIGINT, interrupt_handler)
@@ -143,37 +137,54 @@ def interrupt_handler(sig, frame):
                 # If not in the main thread, skip setting signal handlers
                 yield
 
-        def cancellable_function(index_item, parent_id=None):
+        def cancellable_function(index_item):
             index, item = index_item
             if self.cancel_jobs.is_set():
                 return index, job_cancelled
-            return index, function(item, parent_id)
-
-        parent_id = threading.get_ident() if threading.current_thread() is not threading.main_thread() else None
+            return index, function(item)
 
         with ThreadPoolExecutor(max_workers=self.num_threads) as executor, interrupt_handler_manager():
-            futures = {executor.submit(cancellable_function, pair, parent_id): pair for pair in enumerate(data)}
+            futures = {}
+            for pair in enumerate(data):
+                # Capture the context for each task
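+                # The snapshot is taken here in the submitting thread, so every
+                # task starts from the caller's settings as of submission time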
+                task_ctx = copy_context()
+                future = executor.submit(task_ctx.run, cancellable_function, pair)
+                futures[future] = pair
+
             pbar = tqdm.tqdm(
                 total=len(data),
                 dynamic_ncols=True,
                 disable=self.disable_progress_bar,
-                file=sys.stdout,
+                file=sys.stdout
             )
 
             for future in as_completed(futures):
                 index, result = future.result()
-
+
                 if result is job_cancelled:
                     continue
+
                 results[index] = result
 
                 if self.compare_results:
                     # Assumes score is the last element of the result tuple
-                    self._update_progress(pbar, sum([r[-1] for r in results if r is not None]), len([r for r in results if r is not None]))
+                    self._update_progress(
+                        pbar,
+                        sum([r[-1] for r in results if r is not None]),
+                        len([r for r in results if r is not None]),
+                    )
                 else:
-                    self._update_progress(pbar, len([r for r in results if r is not None]), len(data))
+                    self._update_progress(
+                        pbar,
+                        len([r for r in results if r is not None]),
+                        len(data),
+                    )
+
             pbar.close()
+
             if self.cancel_jobs.is_set():
                 logger.warning("Execution was cancelled due to errors.")
                 raise Exception("Execution was cancelled due to errors.")
+
         return results
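
A minimal standalone sketch of the isolation pattern this commit adopts (illustrative only; `task_setting` below is a hypothetical stand-in for the ContextVar-backed state behind dspy.settings): each task runs under its own `copy_context()`, so a value set inside one task is invisible to other tasks and to the caller.

    from contextvars import ContextVar, copy_context
    from concurrent.futures import ThreadPoolExecutor

    # Hypothetical per-task setting, standing in for dspy.settings state.
    task_setting = ContextVar("task_setting", default="base")

    def task(value):
        task_setting.set(value)  # mutates only this task's context copy
        return task_setting.get()

    with ThreadPoolExecutor(max_workers=4) as executor:
        # Each submission runs inside a fresh snapshot of the caller's context.
        futures = [executor.submit(copy_context().run, task, v) for v in ("a", "b", "c")]
        print([f.result() for f in futures])  # ['a', 'b', 'c']

    print(task_setting.get())  # 'base' -- no per-task set() leaked out

The same snapshot-per-task mechanism works whether tasks run sequentially (`task_ctx.run` in a loop) or on worker threads, which is why the single-thread and multi-thread paths above get identical isolation.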