Skip to content

Commit e16faa6

Browse files
RyanJDick and hipsterusername
authored and committed
Add gradient blending to tile seams in MultiDiffusion.
1 parent 97a7f51 commit e16faa6

File tree

2 files changed

+43
-15
lines changed

2 files changed

+43
-15
lines changed

invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,10 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
175175
_, _, latent_height, latent_width = latents.shape
176176

177177
# Calculate the tile locations to cover the latent-space image.
178+
# TODO(ryand): In the future, we may want to revisit the tile overlap strategy. Things to consider:
179+
# - How much overlap 'context' to provide for each denoising step.
180+
# - How much overlap to use during merging/blending.
181+
# - Should we 'jitter' the tile locations in each step so that the seams are in different places?
178182
tiles = calc_tiles_min_overlap(
179183
image_height=latent_height,
180184
image_width=latent_width,

invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def multi_diffusion_denoise(
6161
# full noise. Investigate the history of why this got commented out.
6262
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
6363
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
64+
assert isinstance(latents, torch.Tensor) # For static type checking.
6465

6566
# TODO(ryand): Look into the implications of passing in latents here that are larger than they will be after
6667
# cropping into regions.
@@ -122,29 +123,52 @@ def multi_diffusion_denoise(
122123
control_data=region_conditioning.control_data,
123124
)
124125

125-
# Store the results from the region.
126-
# If two tiles overlap by more than the target overlap amount, crop the left and top edges of the
127-
# affected tiles to achieve the target overlap.
126+
# Build a region_weight matrix that applies gradient blending to the edges of the region.
128127
region = region_conditioning.region
129-
top_adjustment = max(0, region.overlap.top - target_overlap)
130-
left_adjustment = max(0, region.overlap.left - target_overlap)
131-
region_height_slice = slice(region.coords.top + top_adjustment, region.coords.bottom)
132-
region_width_slice = slice(region.coords.left + left_adjustment, region.coords.right)
133-
merged_latents[:, :, region_height_slice, region_width_slice] += step_output.prev_sample[
134-
:, :, top_adjustment:, left_adjustment:
135-
]
136-
# For now, we treat every region as having the same weight.
137-
merged_latents_weights[:, :, region_height_slice, region_width_slice] += 1.0
128+
_, _, region_height, region_width = step_output.prev_sample.shape
129+
region_weight = torch.ones(
130+
(1, 1, region_height, region_width),
131+
dtype=latents.dtype,
132+
device=latents.device,
133+
)
134+
if region.overlap.left > 0:
135+
left_grad = torch.linspace(
136+
0, 1, region.overlap.left, device=latents.device, dtype=latents.dtype
137+
).view((1, 1, 1, -1))
138+
region_weight[:, :, :, : region.overlap.left] *= left_grad
139+
if region.overlap.top > 0:
140+
top_grad = torch.linspace(
141+
0, 1, region.overlap.top, device=latents.device, dtype=latents.dtype
142+
).view((1, 1, -1, 1))
143+
region_weight[:, :, : region.overlap.top, :] *= top_grad
144+
if region.overlap.right > 0:
145+
right_grad = torch.linspace(
146+
1, 0, region.overlap.right, device=latents.device, dtype=latents.dtype
147+
).view((1, 1, 1, -1))
148+
region_weight[:, :, :, -region.overlap.right :] *= right_grad
149+
if region.overlap.bottom > 0:
150+
bottom_grad = torch.linspace(
151+
1, 0, region.overlap.bottom, device=latents.device, dtype=latents.dtype
152+
).view((1, 1, -1, 1))
153+
region_weight[:, :, -region.overlap.bottom :, :] *= bottom_grad
154+
155+
# Update the merged results with the region results.
156+
merged_latents[
157+
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
158+
] += step_output.prev_sample * region_weight
159+
merged_latents_weights[
160+
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
161+
] += region_weight
138162

139163
pred_orig_sample = getattr(step_output, "pred_original_sample", None)
140164
if pred_orig_sample is not None:
141165
# If one region has pred_original_sample, then we can assume that all regions will have it, because
142166
# they all use the same scheduler.
143167
if merged_pred_original is None:
144168
merged_pred_original = torch.zeros_like(latents)
145-
merged_pred_original[:, :, region_height_slice, region_width_slice] += pred_orig_sample[
146-
:, :, top_adjustment:, left_adjustment:
147-
]
169+
merged_pred_original[
170+
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
171+
] += pred_orig_sample
148172

149173
# Normalize the merged results.
150174
latents = torch.where(merged_latents_weights > 0, merged_latents / merged_latents_weights, merged_latents)

0 commit comments

Comments (0)