
Commit 6414d2d

feat: Update Tile Pre-Processor to support more modes
1 parent 10076fb commit 6414d2d

File tree

3 files changed: +493 -166 lines changed


invokeai/app/invocations/controlnet_image_processors.py

Lines changed: 83 additions & 44 deletions
@@ -1,51 +1,47 @@
 # Invocations for ControlNet image preprocessors
 # initial implementation by Gregg Helt, 2023
 # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
+import random
 from builtins import bool, float
 from pathlib import Path
-from typing import Dict, List, Literal, Union
+from typing import Any, Dict, List, Literal, Union
 
 import cv2
 import numpy as np
-from controlnet_aux import (
-    ContentShuffleDetector,
-    LeresDetector,
-    MediapipeFaceDetector,
-    MidasDetector,
-    MLSDdetector,
-    NormalBaeDetector,
-    PidiNetDetector,
-    SamDetector,
-    ZoeDetector,
-)
+from controlnet_aux import (ContentShuffleDetector, LeresDetector,
+                            MediapipeFaceDetector, MidasDetector, MLSDdetector,
+                            NormalBaeDetector, PidiNetDetector, SamDetector,
+                            ZoeDetector)
 from controlnet_aux.util import HWC3, ade_palette
 from PIL import Image
 from pydantic import BaseModel, Field, field_validator, model_validator
 
-from invokeai.app.invocations.fields import (
-    FieldDescriptions,
-    ImageField,
-    InputField,
-    OutputField,
-    UIType,
-    WithBoard,
-    WithMetadata,
-)
+from invokeai.app.invocations.fields import (FieldDescriptions, ImageField,
+                                             InputField, OutputField, UIType,
+                                             WithBoard, WithMetadata)
 from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.invocations.primitives import ImageOutput
-from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
+from invokeai.app.invocations.util import (validate_begin_end_step,
+                                           validate_weights)
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
+from invokeai.app.util.controlnet_utils import (CONTROLNET_MODE_VALUES,
+                                                CONTROLNET_RESIZE_VALUES,
+                                                heuristic_resize)
 from invokeai.backend.image_util.canny import get_canny_edges
-from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
-from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
+from invokeai.backend.image_util.depth_anything import (DEPTH_ANYTHING_MODELS,
+                                                        DepthAnythingDetector)
+from invokeai.backend.image_util.dw_openpose import (DWPOSE_MODELS,
+                                                     DWOpenposeDetector)
+from invokeai.backend.image_util.fast_guided_filter.fast_guided_filter import \
+    FastGuidedFilter
 from invokeai.backend.image_util.hed import HEDProcessor
 from invokeai.backend.image_util.lineart import LineartProcessor
 from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
 from invokeai.backend.image_util.util import np_to_pil, pil_to_np
 from invokeai.backend.util.devices import TorchDevice
 
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output
+from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
+                             Classification, invocation, invocation_output)
 
 
 class ControlField(BaseModel):
@@ -483,30 +479,73 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
 
     # res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
     down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
+    mode: Literal["regular", "blur", "var", "super"] = InputField(
+        default="regular", description="The controlnet tile model being used"
+    )
+
+    def apply_gaussian_blur(self, image_np: np.ndarray[Any, Any], ksize: int = 5, sigmaX: float = 1.0):
+        if ksize % 2 == 0:
+            ksize += 1  # ksize must be odd
+        blurred_image = cv2.GaussianBlur(image_np, (ksize, ksize), sigmaX=sigmaX)
+        return blurred_image
+
+    def apply_guided_filter(self, image_np: np.ndarray[Any, Any], radius: int, eps: float, scale: int):
+        filter = FastGuidedFilter(image_np, radius, eps, scale)
+        return filter.filter(image_np)
+
+    # based off https://huggingface.co/TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic
+    def tile_resample(self, np_img: np.ndarray[Any, Any]):
+        height, width, _ = np_img.shape
 
-    # tile_resample copied from sd-webui-controlnet/scripts/processor.py
-    def tile_resample(
-        self,
-        np_img: np.ndarray,
-        res=512,  # never used?
-        down_sampling_rate=1.0,
-    ):
-        np_img = HWC3(np_img)
-        if down_sampling_rate < 1.1:
+        if self.mode == "regular":
+            np_img = HWC3(np_img)
+            if self.down_sampling_rate < 1.1:
+                return np_img
+
+            height = int(float(height) / float(self.down_sampling_rate))
+            width = int(float(width) / float(self.down_sampling_rate))
+            np_img = cv2.resize(np_img, (width, height), interpolation=cv2.INTER_AREA)
             return np_img
-        H, W, C = np_img.shape
-        H = int(float(H) / float(down_sampling_rate))
-        W = int(float(W) / float(down_sampling_rate))
-        np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
+
+        ratio = np.sqrt(1024.0 * 1024.0 / (width * height))
+
+        resize_w, resize_h = int(width * ratio), int(height * ratio)
+
+        if self.mode == "super":
+            resize_w, resize_h = int(width * ratio) // 48 * 48, int(height * ratio) // 48 * 48
+
+        np_img = cv2.resize(np_img, (resize_w, resize_h))
+
+        if self.mode == "blur":
+            blur_strength = random.sample([i / 10.0 for i in range(10, 201, 2)], k=1)[0]
+            radius = random.sample([i for i in range(1, 40, 2)], k=1)[0]
+            eps = random.sample([i / 1000.0 for i in range(1, 101, 2)], k=1)[0]
+            scale_factor = random.sample([i / 10.0 for i in range(10, 181, 5)], k=1)[0]
+
+            if random.random() > 0.5:
+                np_img = self.apply_gaussian_blur(np_img, ksize=int(blur_strength), sigmaX=blur_strength / 2)
+
+            if random.random() > 0.5:
+                np_img = self.apply_guided_filter(np_img, radius, eps, int(scale_factor))
+
+            np_img = cv2.resize(
+                np_img, (int(resize_w / scale_factor), int(resize_h / scale_factor)), interpolation=cv2.INTER_AREA
+            )
+            np_img = cv2.resize(np_img, (resize_w, resize_h), interpolation=cv2.INTER_CUBIC)
+
+        if self.mode == "var":
+            pass
+
+        if self.mode == "super":
+            pass
+
+        np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
+
         return np_img
 
     def run_processor(self, image: Image.Image) -> Image.Image:
         np_img = np.array(image, dtype=np.uint8)
-        processed_np_image = self.tile_resample(
-            np_img,
-            # res=self.tile_size,
-            down_sampling_rate=self.down_sampling_rate,
-        )
+        processed_np_image = self.tile_resample(np_img)
         processed_image = Image.fromarray(processed_np_image)
         return processed_image

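A small arithmetic note on the "super" mode: after scaling to the ~1024x1024 pixel budget, both sides are floored to a multiple of 48 before resizing. A quick sketch of just that sizing rule (the helper name is made up for illustration):

import numpy as np

def super_mode_dims(width: int, height: int) -> tuple[int, int]:
    # Scale the total pixel count to ~1024*1024, then snap each side down to a multiple of 48.
    ratio = np.sqrt(1024.0 * 1024.0 / (width * height))
    return int(width * ratio) // 48 * 48, int(height * ratio) // 48 * 48

print(super_mode_dims(1920, 1080))  # (1344, 768)

For example, a 1920x1080 input comes out at 1344x768 under this rule.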