1313from threading import Thread
1414from typing import List
1515import gradio
16+ from torchvision .transforms import ToTensor
1617import urllib3
1718from PIL import Image
1819from modules import processing
2425from scripts .spartan .control_net import pack_control_net
2526from scripts .spartan .shared import logger
2627from scripts .spartan .ui import UI
27- from scripts .spartan .world import World , State
28+ from scripts .spartan .world import World , State , Job
2829
2930old_sigint_handler = signal .getsignal (signal .SIGINT )
3031old_sigterm_handler = signal .getsignal (signal .SIGTERM )
@@ -61,7 +62,7 @@ def show(self, is_img2img):
6162 return scripts .AlwaysVisible
6263
6364 def ui (self , is_img2img ):
64- extension_ui = UI (world = self .world )
65+ extension_ui = UI (world = self .world , is_img2img = is_img2img )
6566 # root, api_exposed = extension_ui.create_ui()
6667 components = extension_ui .create_ui ()
6768
@@ -71,77 +72,61 @@ def ui(self, is_img2img):
7172 # return some components that should be exposed to the api
7273 return components
7374
74- def add_to_gallery (self , processed , p ):
75- """adds generated images to the image gallery after waiting for all workers to finish"""
75+ def api_to_internal (self , job ) -> ([], [], [], [], [] ):
76+ # takes worker response received from api and returns parsed objects in internal sdwui format. E.g. all_seeds
7677
77- def processed_inject_image (image , info_index , save_path_override = None , grid = False , response = None ):
78- image_params : json = response ['parameters' ]
79- image_info_post : json = json .loads (response ["info" ]) # image info known after processing
80- num_response_images = image_params ["batch_size" ] * image_params ["n_iter" ]
81-
82- seed = None
83- subseed = None
84- negative_prompt = None
85- pos_prompt = None
78+ image_params : json = job .worker .response ['parameters' ]
79+ image_info_post : json = json .loads (job .worker .response ["info" ]) # image info known after processing
80+ all_seeds , all_subseeds , all_negative_prompts , all_prompts , images = [], [], [], [], []
8681
82+ for i in range (len (job .worker .response ["images" ])):
8783 try :
88- if num_response_images > 1 :
89- seed = image_info_post ['all_seeds' ][info_index ]
90- subseed = image_info_post ['all_subseeds' ][info_index ]
91- negative_prompt = image_info_post ['all_negative_prompts' ][info_index ]
92- pos_prompt = image_info_post ['all_prompts' ][info_index ]
93- else :
94- seed = image_info_post ['seed' ]
95- subseed = image_info_post ['subseed' ]
96- negative_prompt = image_info_post ['negative_prompt' ]
97- pos_prompt = image_info_post ['prompt' ]
84+ if image_params [ "batch_size" ] * image_params [ "n_iter" ] > 1 :
85+ all_seeds . append ( image_info_post ['all_seeds' ][i ])
86+ all_subseeds . append ( image_info_post ['all_subseeds' ][i ])
87+ all_negative_prompts . append ( image_info_post ['all_negative_prompts' ][i ])
88+ all_prompts . append ( image_info_post ['all_prompts' ][i ])
89+ else : # only a single image received
90+ all_seeds . append ( image_info_post ['seed' ])
91+ all_subseeds . append ( image_info_post ['subseed' ])
92+ all_negative_prompts . append ( image_info_post ['negative_prompt' ])
93+ all_prompts . append ( image_info_post ['prompt' ])
9894 except IndexError :
99- # like with controlnet masks, there isn't always full post-gen info, so we use the first images'
100- logger .debug (f"Image at index { i } for '{ job .worker .label } ' was missing some post-generation data" )
101- processed_inject_image (image = image , info_index = 0 , response = response )
102- return
103-
104- processed .all_seeds .append (seed )
105- processed .all_subseeds .append (subseed )
106- processed .all_negative_prompts .append (negative_prompt )
107- processed .all_prompts .append (pos_prompt )
108- processed .images .append (image ) # actual received image
109-
110- # generate info-text string
111- # modules.ui_common -> update_generation_info renders to html below gallery
112- images_per_batch = p .n_iter * p .batch_size
113- # zero-indexed position of image in total batch (so including master results)
114- true_image_pos = len (processed .images ) - 1
115- num_remote_images = images_per_batch * p .batch_size
116- if p .n_iter > 1 : # if splitting by batch count
117- num_remote_images *= p .n_iter - 1
95+ # # like with controlnet masks, there isn't always full post-gen info, so we use the first images'
96+ # logger.debug(f"Image at index {info_index} for '{job.worker.label}' was missing some post-generation data")
97+ # self.processed_inject_image(image=image, info_index=0, job=job, p=p)
98+ # return
99+ logger .critical (f"Image at index { i } for '{ job .worker .label } ' was missing some post-generation data" )
100+ continue
118101
119- logger .debug (f"image { true_image_pos + 1 } /{ self .world .p .batch_size * p .n_iter } , "
120- f"info-index: { info_index } " )
102+ # parse image
103+ image_bytes = base64 .b64decode (job .worker .response ["images" ][i ])
104+ image = Image .open (io .BytesIO (image_bytes ))
105+ transform = ToTensor ()
106+ images .append (transform (image ))
121107
122- if self .world .thin_client_mode :
123- p .all_negative_prompts = processed .all_negative_prompts
108+ return all_seeds , all_subseeds , all_negative_prompts , all_prompts , images
124109
125- try :
126- info_text = image_info_post [ 'infotexts' ][ i ]
127- except IndexError :
128- if not grid :
129- logger . warning ( f"image { true_image_pos + 1 } was missing info-text" )
130- info_text = processed . infotexts [ 0 ]
131- info_text += f", Worker Label: { job . worker . label } "
132- processed . infotexts . append ( info_text )
133-
134- # automatically save received image to local disk if desired
135- if cmd_opts . distributed_remotes_autosave :
136- save_image (
137- image = image ,
138- path = p . outpath_samples if save_path_override is None else save_path_override ,
139- basename = "" ,
140- seed = seed ,
141- prompt = pos_prompt ,
142- info = info_text ,
143- extension = opts . samples_format
144- )
110+ def inject_job ( self , job : Job , p , pp ) :
111+ """Adds the work completed by one Job via its worker response to the processing and postprocessing objects"""
112+ all_seeds , all_subseeds , all_negative_prompts , all_prompts , images = self . api_to_internal ( job )
113+
114+ p . seeds . extend ( all_seeds )
115+ p . subseeds . extend ( all_subseeds )
116+ p . negative_prompts . extend ( all_negative_prompts )
117+ p . prompts . extend ( all_prompts )
118+
119+ num_local = self . world . p . n_iter * self . world . p . batch_size + ( opts . return_grid - self . world . thin_client_mode )
120+ num_injected = len ( pp . images ) - self . world . p . batch_size
121+ for i , image in enumerate ( images ):
122+ # modules.ui_common -> update_generation_info renders to html below gallery
123+ gallery_index = num_local + num_injected + i # zero-indexed point of image in total gallery
124+ job . gallery_map . append ( gallery_index ) # so we know where to edit infotext
125+ pp . images . append ( image )
126+ logger . debug ( f"image { gallery_index + 1 + self . world . thin_client_mode } / { self . world . num_gallery () } " )
127+
128+ def update_gallery ( self , pp , p ):
129+ """adds all remotely generated images to the image gallery after waiting for all workers to finish"""
145130
146131 # get master ipm by estimating based on worker speed
147132 master_elapsed = time .time () - self .master_start
@@ -158,8 +143,7 @@ def processed_inject_image(image, info_index, save_path_override=None, grid=Fals
158143 logger .debug ("all worker request threads returned" )
159144 webui_state .textinfo = "Distributed - injecting images"
160145
161- # some worker which we know has a good response that we can use for generating the grid
162- donor_worker = None
146+ received_images = False
163147 for job in self .world .jobs :
164148 if job .worker .response is None or job .batch_size < 1 or job .worker .master :
165149 continue
@@ -170,8 +154,7 @@ def processed_inject_image(image, info_index, save_path_override=None, grid=Fals
170154 if (job .batch_size * p .n_iter ) < len (images ):
171155 logger .debug (f"requested { job .batch_size } image(s) from '{ job .worker .label } ', got { len (images )} " )
172156
173- if donor_worker is None :
174- donor_worker = job .worker
157+ received_images = True
175158 except KeyError :
176159 if job .batch_size > 0 :
177160 logger .warning (f"Worker '{ job .worker .label } ' had no images" )
@@ -185,41 +168,27 @@ def processed_inject_image(image, info_index, save_path_override=None, grid=Fals
185168 logger .exception (e )
186169 continue
187170
188- # visibly add work from workers to the image gallery
189- for i in range (0 , len (images )):
190- image_bytes = base64 .b64decode (images [i ])
191- image = Image .open (io .BytesIO (image_bytes ))
171+ # adding the images in
172+ self .inject_job (job , p , pp )
192173
193- # inject image
194- processed_inject_image (image = image , info_index = i , response = job .worker .response )
195-
196- if donor_worker is None :
174+ # TODO fix controlnet masks returned via api having no generation info
175+ if received_images is False :
197176 logger .critical ("couldn't collect any responses, the extension will have no effect" )
198177 return
199178
200- # generate and inject grid
201- if opts .return_grid and len (processed .images ) > 1 :
202- grid = image_grid (processed .images , len (processed .images ))
203- processed_inject_image (
204- image = grid ,
205- info_index = 0 ,
206- save_path_override = p .outpath_grids ,
207- grid = True ,
208- response = donor_worker .response
209- )
210-
211- # cleanup after we're doing using all the responses
212- for worker in self .world .get_workers ():
213- worker .response = None
214-
215- p .batch_size = len (processed .images )
179+ p .batch_size = len (pp .images )
180+ webui_state .textinfo = ""
216181 return
217182
218183 # p's type is
219184 # "modules.processing.StableDiffusionProcessing*"
220185 def before_process (self , p , * args ):
221- if not self .world .enabled :
222- logger .debug ("extension is disabled" )
186+ is_img2img = getattr (p , 'init_images' , False )
187+ if is_img2img and self .world .enabled_i2i is False :
188+ logger .debug ("extension is disabled for i2i" )
189+ return
190+ elif not is_img2img and self .world .enabled is False :
191+ logger .debug ("extension is disabled for t2i" )
223192 return
224193 self .world .update (p )
225194
@@ -234,6 +203,14 @@ def before_process(self, p, *args):
234203 continue
235204 title = script .title ()
236205
206+ if title == "ADetailer" :
207+ adetailer_args = p .script_args [script .args_from :script .args_to ]
208+
209+ # InputAccordion main toggle, skip img2img toggle
210+ if adetailer_args [0 ] and adetailer_args [1 ]:
211+ logger .debug (f"adetailer is skipping img2img, returning control to wui" )
212+ return
213+
237214 # check for supported scripts
238215 if title == "ControlNet" :
239216 # grab all controlnet units
@@ -346,18 +323,34 @@ def before_process(self, p, *args):
346323 p .batch_size = self .world .master_job ().batch_size
347324 self .master_start = time .time ()
348325
349- # generate images assigned to local machine
350- p .do_not_save_grid = True # don't generate grid from master as we are doing this later.
351326 self .runs_since_init += 1
352327 return
353328
354- def postprocess (self , p , processed , * args ):
355- if not self .world .enabled :
329+ def postprocess_batch_list (self , p , pp , * args , ** kwargs ):
330+ if not self .world .thin_client_mode and p .n_iter != kwargs ['batch_number' ] + 1 : # skip if not the final batch
331+ return
332+
333+ is_img2img = getattr (p , 'init_images' , False )
334+ if is_img2img and self .world .enabled_i2i is False :
335+ return
336+ elif not is_img2img and self .world .enabled is False :
356337 return
357338
358339 if self .master_start is not None :
359- self .add_to_gallery (p = p , processed = processed )
340+ self .update_gallery (p = p , pp = pp )
341+
360342
343+ def postprocess (self , p , processed , * args ):
344+ for job in self .world .jobs :
345+ if job .worker .response is not None :
346+ for i , v in enumerate (job .gallery_map ):
347+ infotext = json .loads (job .worker .response ['info' ])['infotexts' ][i ]
348+ infotext += f", Worker Label: { job .worker .label } "
349+ processed .infotexts [v ] = infotext
350+
351+ # cleanup
352+ for worker in self .world .get_workers ():
353+ worker .response = None
361354 # restore process_images_inner if it was monkey-patched
362355 processing .process_images_inner = self .original_process_images_inner
363356 # save any dangling state to prevent load_config in next iteration overwriting it
0 commit comments