Commit d014dc9

Merge branch 'main' into stalker7779/modular_rescale_cfg
2 parents 1b359b5 + 154e8f6


45 files changed: +1374 -240 lines
invokeai/app/invocations/spandrel_image_to_image.py

Lines changed: 137 additions & 59 deletions
@@ -1,3 +1,5 @@
+from typing import Callable
+
 import numpy as np
 import torch
 from PIL import Image
@@ -21,7 +23,7 @@
 from invokeai.backend.tiles.utils import TBLR, Tile
 
 
-@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.1.0")
+@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.2.0")
 class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     """Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""
 
@@ -34,8 +36,19 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     tile_size: int = InputField(
         default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
     )
+    scale: float = InputField(
+        default=4.0,
+        gt=0.0,
+        le=16.0,
+        description="The final scale of the output image. If the model does not upscale the image, this will be ignored.",
+    )
+    fit_to_multiple_of_8: bool = InputField(
+        default=False,
+        description="If true, the output image will be resized to the nearest multiple of 8 in both dimensions.",
+    )
 
-    def _scale_tile(self, tile: Tile, scale: int) -> Tile:
+    @classmethod
+    def scale_tile(cls, tile: Tile, scale: int) -> Tile:
         return Tile(
             coords=TBLR(
                 top=tile.coords.top * scale,
@@ -51,20 +64,22 @@ def _scale_tile(self, tile: Tile, scale: int) -> Tile:
             ),
         )
 
-    @torch.inference_mode()
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        # Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
-        # revisit this.
-        image = context.images.get_pil(self.image.image_name, mode="RGB")
-
+    @classmethod
+    def upscale_image(
+        cls,
+        image: Image.Image,
+        tile_size: int,
+        spandrel_model: SpandrelImageToImageModel,
+        is_canceled: Callable[[], bool],
+    ) -> Image.Image:
         # Compute the image tiles.
-        if self.tile_size > 0:
+        if tile_size > 0:
             min_overlap = 20
             tiles = calc_tiles_min_overlap(
                 image_height=image.height,
                 image_width=image.width,
-                tile_height=self.tile_size,
-                tile_width=self.tile_size,
+                tile_height=tile_size,
+                tile_width=tile_size,
                 min_overlap=min_overlap,
             )
         else:
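A note on the new `scale_tile` classmethod above: it maps input-space tile coordinates into output space by multiplying every coordinate by the model's integer scale factor (the hunk elides the remaining `bottom`/`left`/`right` and `overlap` fields, which presumably get the same treatment). A minimal standalone sketch of the arithmetic, using simplified stand-ins for the real `TBLR`/`Tile` types from `invokeai.backend.tiles.utils`:

```python
from dataclasses import dataclass


# Simplified stand-ins for invokeai.backend.tiles.utils.TBLR / Tile.
@dataclass
class TBLR:
    top: int
    bottom: int
    left: int
    right: int


@dataclass
class Tile:
    coords: TBLR
    overlap: TBLR


def scale_tile(tile: Tile, scale: int) -> Tile:
    # Every coordinate is multiplied by the model's scale factor so the tile
    # addresses the matching region of the upscaled output tensor.
    return Tile(
        coords=TBLR(
            top=tile.coords.top * scale,
            bottom=tile.coords.bottom * scale,
            left=tile.coords.left * scale,
            right=tile.coords.right * scale,
        ),
        overlap=TBLR(
            top=tile.overlap.top * scale,
            bottom=tile.overlap.bottom * scale,
            left=tile.overlap.left * scale,
            right=tile.overlap.right * scale,
        ),
    )


# A 512px-tall tile starting at y=492 (20px overlap) in a 4x model's input
# covers rows 1968..4016 of the output.
tile = Tile(coords=TBLR(top=492, bottom=1004, left=0, right=512), overlap=TBLR(20, 0, 0, 20))
scaled = scale_tile(tile, scale=4)
assert (scaled.coords.top, scaled.coords.bottom) == (1968, 4016)
assert scaled.overlap.top == 80
```

Scaling the overlaps as well is what lets the merge step in the next hunk discard half of each scaled overlap when stitching tiles into the output tensor.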
@@ -85,60 +100,123 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
         # Prepare input image for inference.
         image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
 
-        # Load the model.
-        spandrel_model_info = context.models.load(self.image_to_image_model)
-
-        # Run the model on each tile.
-        with spandrel_model_info as spandrel_model:
-            assert isinstance(spandrel_model, SpandrelImageToImageModel)
+        # Scale the tiles for re-assembling the final image.
+        scale = spandrel_model.scale
+        scaled_tiles = [cls.scale_tile(tile, scale=scale) for tile in tiles]
 
-            # Scale the tiles for re-assembling the final image.
-            scale = spandrel_model.scale
-            scaled_tiles = [self._scale_tile(tile, scale=scale) for tile in tiles]
+        # Prepare the output tensor.
+        _, channels, height, width = image_tensor.shape
+        output_tensor = torch.zeros(
+            (height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
+        )
 
-            # Prepare the output tensor.
-            _, channels, height, width = image_tensor.shape
-            output_tensor = torch.zeros(
-                (height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
-            )
+        image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
 
-            image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
-
-            for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
-                # Exit early if the invocation has been canceled.
-                if context.util.is_canceled():
-                    raise CanceledException
-
-                # Extract the current tile from the input tensor.
-                input_tile = image_tensor[
-                    :, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
-                ].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
-
-                # Run the model on the tile.
-                output_tile = spandrel_model.run(input_tile)
-
-                # Convert the output tile into the output tensor's format.
-                # (N, C, H, W) -> (C, H, W)
-                output_tile = output_tile.squeeze(0)
-                # (C, H, W) -> (H, W, C)
-                output_tile = output_tile.permute(1, 2, 0)
-                output_tile = output_tile.clamp(0, 1)
-                output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu"))
-
-                # Merge the output tile into the output tensor.
-                # We only keep half of the overlap on the top and left side of the tile. We do this in case there are
-                # edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers
-                # it seems unnecessary, but we may find a need in the future.
-                top_overlap = scaled_tile.overlap.top // 2
-                left_overlap = scaled_tile.overlap.left // 2
-                output_tensor[
-                    scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom,
-                    scaled_tile.coords.left + left_overlap : scaled_tile.coords.right,
-                    :,
-                ] = output_tile[top_overlap:, left_overlap:, :]
+        # Run the model on each tile.
+        for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
+            # Exit early if the invocation has been canceled.
+            if is_canceled():
+                raise CanceledException
+
+            # Extract the current tile from the input tensor.
+            input_tile = image_tensor[
+                :, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
+            ].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
+
+            # Run the model on the tile.
+            output_tile = spandrel_model.run(input_tile)
+
+            # Convert the output tile into the output tensor's format.
+            # (N, C, H, W) -> (C, H, W)
+            output_tile = output_tile.squeeze(0)
+            # (C, H, W) -> (H, W, C)
+            output_tile = output_tile.permute(1, 2, 0)
+            output_tile = output_tile.clamp(0, 1)
+            output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu"))
+
+            # Merge the output tile into the output tensor.
+            # We only keep half of the overlap on the top and left side of the tile. We do this in case there are
+            # edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers
+            # it seems unnecessary, but we may find a need in the future.
+            top_overlap = scaled_tile.overlap.top // 2
+            left_overlap = scaled_tile.overlap.left // 2
+            output_tensor[
+                scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom,
+                scaled_tile.coords.left + left_overlap : scaled_tile.coords.right,
+                :,
+            ] = output_tile[top_overlap:, left_overlap:, :]
 
         # Convert the output tensor to a PIL image.
         np_image = output_tensor.detach().numpy().astype(np.uint8)
         pil_image = Image.fromarray(np_image)
+
+        return pil_image
+
+    @torch.inference_mode()
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        # Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
+        # revisit this.
+        image = context.images.get_pil(self.image.image_name, mode="RGB")
+
+        # Load the model.
+        spandrel_model_info = context.models.load(self.image_to_image_model)
+
+        # The target size of the image, determined by the provided scale. We'll run the upscaler until we hit this size.
+        # Later, we may mutate this value if the model doesn't upscale the image or if the user requested a multiple of 8.
+        target_width = int(image.width * self.scale)
+        target_height = int(image.height * self.scale)
+
+        # Do the upscaling.
+        with spandrel_model_info as spandrel_model:
+            assert isinstance(spandrel_model, SpandrelImageToImageModel)
+
+            # First pass of upscaling. Note: `pil_image` will be mutated.
+            pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
+
+            # Some models don't upscale the image, but we have no way to know this in advance. We'll check if the model
+            # upscaled the image and run the loop below if it did. We'll require the model to upscale both dimensions
+            # to be considered an upscale model.
+            is_upscale_model = pil_image.width > image.width and pil_image.height > image.height
+
+            if is_upscale_model:
+                # This is an upscale model, so we should keep upscaling until we reach the target size.
+                iterations = 1
+                while pil_image.width < target_width or pil_image.height < target_height:
+                    pil_image = self.upscale_image(pil_image, self.tile_size, spandrel_model, context.util.is_canceled)
+                    iterations += 1
+
+                    # Sanity check to prevent excessive or infinite loops. All known upscaling models are at least 2x.
+                    # Our max scale is 16x, so with a 2x model, we should never exceed 16x == 2^4 -> 4 iterations.
+                    # We'll allow one extra iteration "just in case" and bail at 5 upscaling iterations. In practice,
+                    # we should never reach this limit.
+                    if iterations >= 5:
+                        context.logger.warning(
+                            "Upscale loop reached maximum iteration count of 5, stopping upscaling early."
+                        )
+                        break
+            else:
+                # This model doesn't upscale the image. We should ignore the scale parameter, modifying the output size
+                # to be the same as the processed image size.
+
+                # The output size is now the size of the processed image.
+                target_width = pil_image.width
+                target_height = pil_image.height
+
+                # Warn the user if they requested a scale greater than 1.
+                if self.scale > 1:
+                    context.logger.warning(
+                        "Model does not increase the size of the image, but a greater scale than 1 was requested. Image will not be scaled."
+                    )
+
+        # We may need to resize the image to a multiple of 8. Use floor division to ensure we don't scale the image up
+        # in the final resize
+        if self.fit_to_multiple_of_8:
+            target_width = int(target_width // 8 * 8)
+            target_height = int(target_height // 8 * 8)
+
+        # Final resize. Per PIL documentation, Lanczos provides the best quality for both upscale and downscale.
+        # See: https://pillow.readthedocs.io/en/stable/handbook/concepts.html#filters-comparison-table
+        pil_image = pil_image.resize((target_width, target_height), resample=Image.Resampling.LANCZOS)
+
         image_dto = context.images.save(image=pil_image)
         return ImageOutput.build(image_dto)
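The restructured `invoke()` above separates model execution (`upscale_image`) from size planning: the model's native scale is unknown up front, so the node runs one pass, infers from the result whether the model upscales at all, loops until the requested `scale` is reached (bounded at 5 iterations, since known models are at least 2x and the cap is 16x), optionally floors to a multiple of 8, and finishes with a single Lanczos resize that only ever shrinks. Below is a back-of-the-envelope sketch of just that size bookkeeping; `plan_output_size` is a hypothetical helper for illustration, not part of the commit, and it approximates the real loop's warn-and-break with a simple loop bound:

```python
def plan_output_size(
    width: int,
    height: int,
    requested_scale: float,
    model_scale: int,
    fit_to_multiple_of_8: bool,
) -> tuple[int, int, int]:
    """Return (target_width, target_height, model_passes) for the flow above."""
    # Target size implied by the user's requested scale.
    target_w = int(width * requested_scale)
    target_h = int(height * requested_scale)

    # The first pass always runs; its output reveals whether the model upscales at all.
    w, h = width * model_scale, height * model_scale
    passes = 1

    if w > width and h > height:
        # Upscaling model: keep applying it until both dimensions reach the
        # target, bailing out at 5 passes (the node's sanity limit).
        while (w < target_w or h < target_h) and passes < 5:
            w, h = w * model_scale, h * model_scale
            passes += 1
    else:
        # 1x model (e.g. a pure denoiser): the requested scale is ignored.
        target_w, target_h = w, h

    if fit_to_multiple_of_8:
        # Floor division, so the final resize never has to scale the image up.
        target_w = target_w // 8 * 8
        target_h = target_h // 8 * 8

    return target_w, target_h, passes


# A 2x model asked for 6x on a 512px image: three passes (1024 -> 2048 -> 4096),
# then one Lanczos resize down from 4096px to the 3072px target.
assert plan_output_size(512, 512, 6.0, 2, False) == (3072, 3072, 3)
# A 1.5x request with fit_to_multiple_of_8: one pass to 1024px, then resize to 768px.
assert plan_output_size(512, 512, 1.5, 2, True) == (768, 768, 1)
```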

invokeai/frontend/web/public/locales/en.json

Lines changed: 17 additions & 1 deletion
@@ -1027,6 +1027,7 @@
     "imageActions": "Image Actions",
     "sendToImg2Img": "Send to Image to Image",
     "sendToUnifiedCanvas": "Send To Unified Canvas",
+    "sendToUpscale": "Send To Upscale",
     "showOptionsPanel": "Show Side Panel (O or T)",
     "shuffle": "Shuffle Seed",
     "steps": "Steps",
@@ -1640,6 +1641,19 @@
     "layers_one": "Layer",
     "layers_other": "Layers"
   },
+  "upscaling": {
+    "creativity": "Creativity",
+    "structure": "Structure",
+    "upscaleModel": "Upscale Model",
+    "scale": "Scale",
+    "missingModelsWarning": "Visit the <LinkComponent>Model Manager</LinkComponent> to install the required models:",
+    "mainModelDesc": "Main model (SD1.5 or SDXL architecture)",
+    "tileControlNetModelDesc": "Tile ControlNet model for the chosen main model architecture",
+    "upscaleModelDesc": "Upscale (image to image) model",
+    "missingUpscaleInitialImage": "Missing initial image for upscaling",
+    "missingUpscaleModel": "Missing upscale model",
+    "missingTileControlNetModel": "No valid tile ControlNet models installed"
+  },
   "ui": {
     "tabs": {
       "generation": "Generation",
@@ -1651,7 +1665,9 @@
       "models": "Models",
       "modelsTab": "$t(ui.tabs.models) $t(common.tab)",
       "queue": "Queue",
-      "queueTab": "$t(ui.tabs.queue) $t(common.tab)"
+      "queueTab": "$t(ui.tabs.queue) $t(common.tab)",
+      "upscaling": "Upscaling",
+      "upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)"
     }
   }
 }

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts

Lines changed: 2 additions & 0 deletions
@@ -52,6 +52,7 @@ import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerM
 import type { AppDispatch, RootState } from 'app/store/store';
 
 import { addArchivedOrDeletedBoardListener } from './listeners/addArchivedOrDeletedBoardListener';
+import { addEnqueueRequestedUpscale } from './listeners/enqueueRequestedUpscale';
 
 export const listenerMiddleware = createListenerMiddleware();
 
@@ -85,6 +86,7 @@ addGalleryOffsetChangedListener(startAppListening);
 addEnqueueRequestedCanvasListener(startAppListening);
 addEnqueueRequestedNodes(startAppListening);
 addEnqueueRequestedLinear(startAppListening);
+addEnqueueRequestedUpscale(startAppListening);
 addAnyEnqueuedListener(startAppListening);
 addBatchEnqueuedListener(startAppListening);
 

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedUpscale.ts

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+import { enqueueRequested } from 'app/store/actions';
+import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
+import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
+import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
+import { buildMultidiffusionUpscaleGraph } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph';
+import { queueApi } from 'services/api/endpoints/queue';
+
+export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening) => {
+  startAppListening({
+    predicate: (action): action is ReturnType<typeof enqueueRequested> =>
+      enqueueRequested.match(action) && action.payload.tabName === 'upscaling',
+    effect: async (action, { getState, dispatch }) => {
+      const state = getState();
+      const { shouldShowProgressInViewer } = state.ui;
+      const { prepend } = action.payload;
+
+      const graph = await buildMultidiffusionUpscaleGraph(state);
+
+      const batchConfig = prepareLinearUIBatch(state, graph, prepend);
+
+      const req = dispatch(
+        queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
+          fixedCacheKey: 'enqueueBatch',
+        })
+      );
+      try {
+        await req.unwrap();
+        if (shouldShowProgressInViewer) {
+          dispatch(isImageViewerOpenChanged(true));
+        }
+      } finally {
+        req.reset();
+      }
+    },
+  });
+};

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts

Lines changed: 15 additions & 0 deletions
@@ -23,6 +23,7 @@ import {
 } from 'features/gallery/store/gallerySlice';
 import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
 import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
+import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
 import { imagesApi } from 'services/api/endpoints/images';
 
 export const dndDropped = createAction<{
@@ -243,6 +244,20 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
       return;
     }
 
+    /**
+     * Image dropped on upscale initial image
+     */
+    if (
+      overData.actionType === 'SET_UPSCALE_INITIAL_IMAGE' &&
+      activeData.payloadType === 'IMAGE_DTO' &&
+      activeData.payload.imageDTO
+    ) {
+      const { imageDTO } = activeData.payload;
+
+      dispatch(upscaleInitialImageChanged(imageDTO));
+      return;
+    }
+
    /**
     * Multiple images dropped on user board
     */

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts

Lines changed: 10 additions & 0 deletions
@@ -14,6 +14,7 @@ import {
 import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
 import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
 import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
+import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
 import { toast } from 'features/toast/toast';
 import { t } from 'i18next';
 import { omit } from 'lodash-es';
@@ -89,6 +90,15 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
       return;
     }
 
+    if (postUploadAction?.type === 'SET_UPSCALE_INITIAL_IMAGE') {
+      dispatch(upscaleInitialImageChanged(imageDTO));
+      toast({
+        ...DEFAULT_UPLOADED_TOAST,
+        description: 'set as upscale initial image',
+      });
+      return;
+    }
+
     if (postUploadAction?.type === 'SET_CONTROL_ADAPTER_IMAGE') {
       const { id } = postUploadAction;
       dispatch(
