@@ -61,6 +61,7 @@ def multi_diffusion_denoise(
61
61
# full noise. Investigate the history of why this got commented out.
62
62
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
63
63
latents = self .scheduler .add_noise (latents , noise , batched_init_timestep )
64
+ assert isinstance (latents , torch .Tensor ) # For static type checking.
64
65
65
66
# TODO(ryand): Look into the implications of passing in latents here that are larger than they will be after
66
67
# cropping into regions.
@@ -122,29 +123,52 @@ def multi_diffusion_denoise(
122
123
control_data = region_conditioning .control_data ,
123
124
)
124
125
125
- # Store the results from the region.
126
- # If two tiles overlap by more than the target overlap amount, crop the left and top edges of the
127
- # affected tiles to achieve the target overlap.
126
+ # Build a region_weight matrix that applies gradient blending to the edges of the region.
128
127
region = region_conditioning .region
129
- top_adjustment = max (0 , region .overlap .top - target_overlap )
130
- left_adjustment = max (0 , region .overlap .left - target_overlap )
131
- region_height_slice = slice (region .coords .top + top_adjustment , region .coords .bottom )
132
- region_width_slice = slice (region .coords .left + left_adjustment , region .coords .right )
133
- merged_latents [:, :, region_height_slice , region_width_slice ] += step_output .prev_sample [
134
- :, :, top_adjustment :, left_adjustment :
135
- ]
136
- # For now, we treat every region as having the same weight.
137
- merged_latents_weights [:, :, region_height_slice , region_width_slice ] += 1.0
128
+ _ , _ , region_height , region_width = step_output .prev_sample .shape
129
+ region_weight = torch .ones (
130
+ (1 , 1 , region_height , region_width ),
131
+ dtype = latents .dtype ,
132
+ device = latents .device ,
133
+ )
134
+ if region .overlap .left > 0 :
135
+ left_grad = torch .linspace (
136
+ 0 , 1 , region .overlap .left , device = latents .device , dtype = latents .dtype
137
+ ).view ((1 , 1 , 1 , - 1 ))
138
+ region_weight [:, :, :, : region .overlap .left ] *= left_grad
139
+ if region .overlap .top > 0 :
140
+ top_grad = torch .linspace (
141
+ 0 , 1 , region .overlap .top , device = latents .device , dtype = latents .dtype
142
+ ).view ((1 , 1 , - 1 , 1 ))
143
+ region_weight [:, :, : region .overlap .top , :] *= top_grad
144
+ if region .overlap .right > 0 :
145
+ right_grad = torch .linspace (
146
+ 1 , 0 , region .overlap .right , device = latents .device , dtype = latents .dtype
147
+ ).view ((1 , 1 , 1 , - 1 ))
148
+ region_weight [:, :, :, - region .overlap .right :] *= right_grad
149
+ if region .overlap .bottom > 0 :
150
+ bottom_grad = torch .linspace (
151
+ 1 , 0 , region .overlap .bottom , device = latents .device , dtype = latents .dtype
152
+ ).view ((1 , 1 , - 1 , 1 ))
153
+ region_weight [:, :, - region .overlap .bottom :, :] *= bottom_grad
154
+
155
+ # Update the merged results with the region results.
156
+ merged_latents [
157
+ :, :, region .coords .top : region .coords .bottom , region .coords .left : region .coords .right
158
+ ] += step_output .prev_sample * region_weight
159
+ merged_latents_weights [
160
+ :, :, region .coords .top : region .coords .bottom , region .coords .left : region .coords .right
161
+ ] += region_weight
138
162
139
163
pred_orig_sample = getattr (step_output , "pred_original_sample" , None )
140
164
if pred_orig_sample is not None :
141
165
# If one region has pred_original_sample, then we can assume that all regions will have it, because
142
166
# they all use the same scheduler.
143
167
if merged_pred_original is None :
144
168
merged_pred_original = torch .zeros_like (latents )
145
- merged_pred_original [:, :, region_height_slice , region_width_slice ] += pred_orig_sample [
146
- :, :, top_adjustment :, left_adjustment :
147
- ]
169
+ merged_pred_original [
170
+ :, :, region . coords . top : region . coords . bottom , region . coords . left : region . coords . right
171
+ ] += pred_orig_sample
148
172
149
173
# Normalize the merged results.
150
174
latents = torch .where (merged_latents_weights > 0 , merged_latents / merged_latents_weights , merged_latents )
0 commit comments