@@ -1,51 +1,47 @@
 # Invocations for ControlNet image preprocessors
 # initial implementation by Gregg Helt, 2023
 # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
+import random
 from builtins import bool, float
 from pathlib import Path
-from typing import Dict, List, Literal, Union
+from typing import Any, Dict, List, Literal, Union

 import cv2
 import numpy as np
-from controlnet_aux import (
-    ContentShuffleDetector,
-    LeresDetector,
-    MediapipeFaceDetector,
-    MidasDetector,
-    MLSDdetector,
-    NormalBaeDetector,
-    PidiNetDetector,
-    SamDetector,
-    ZoeDetector,
-)
+from controlnet_aux import (ContentShuffleDetector, LeresDetector,
+                            MediapipeFaceDetector, MidasDetector, MLSDdetector,
+                            NormalBaeDetector, PidiNetDetector, SamDetector,
+                            ZoeDetector)
 from controlnet_aux.util import HWC3, ade_palette
 from PIL import Image
 from pydantic import BaseModel, Field, field_validator, model_validator

-from invokeai.app.invocations.fields import (
-    FieldDescriptions,
-    ImageField,
-    InputField,
-    OutputField,
-    UIType,
-    WithBoard,
-    WithMetadata,
-)
+from invokeai.app.invocations.fields import (FieldDescriptions, ImageField,
+                                             InputField, OutputField, UIType,
+                                             WithBoard, WithMetadata)
 from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.invocations.primitives import ImageOutput
-from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
+from invokeai.app.invocations.util import (validate_begin_end_step,
+                                           validate_weights)
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
+from invokeai.app.util.controlnet_utils import (CONTROLNET_MODE_VALUES,
+                                                CONTROLNET_RESIZE_VALUES,
+                                                heuristic_resize)
 from invokeai.backend.image_util.canny import get_canny_edges
-from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
-from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
+from invokeai.backend.image_util.depth_anything import (DEPTH_ANYTHING_MODELS,
+                                                         DepthAnythingDetector)
+from invokeai.backend.image_util.dw_openpose import (DWPOSE_MODELS,
+                                                     DWOpenposeDetector)
+from invokeai.backend.image_util.fast_guided_filter.fast_guided_filter import \
+    FastGuidedFilter
 from invokeai.backend.image_util.hed import HEDProcessor
 from invokeai.backend.image_util.lineart import LineartProcessor
 from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
 from invokeai.backend.image_util.util import np_to_pil, pil_to_np
 from invokeai.backend.util.devices import TorchDevice

-from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output
+from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
+                             Classification, invocation, invocation_output)


 class ControlField(BaseModel):
@@ -483,30 +479,73 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):

     # res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
     down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
+    mode: Literal["regular", "blur", "var", "super"] = InputField(
+        default="regular", description="The controlnet tile model being used"
+    )
+
+    def apply_gaussian_blur(self, image_np: np.ndarray[Any, Any], ksize: int = 5, sigmaX: float = 1.0):
+        if ksize % 2 == 0:
+            ksize += 1  # ksize must be odd
+        blurred_image = cv2.GaussianBlur(image_np, (ksize, ksize), sigmaX=sigmaX)
+        return blurred_image
+
+    def apply_guided_filter(self, image_np: np.ndarray[Any, Any], radius: int, eps: float, scale: int):
+        filter = FastGuidedFilter(image_np, radius, eps, scale)
+        return filter.filter(image_np)
+
+    # based off https://huggingface.co/TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic
+    def tile_resample(self, np_img: np.ndarray[Any, Any]):
+        height, width, _ = np_img.shape

-    # tile_resample copied from sd-webui-controlnet/scripts/processor.py
-    def tile_resample(
-        self,
-        np_img: np.ndarray,
-        res=512,  # never used?
-        down_sampling_rate=1.0,
-    ):
-        np_img = HWC3(np_img)
-        if down_sampling_rate < 1.1:
+        if self.mode == "regular":
+            np_img = HWC3(np_img)
+            if self.down_sampling_rate < 1.1:
+                return np_img
+
+            height = int(float(height) / float(self.down_sampling_rate))
+            width = int(float(width) / float(self.down_sampling_rate))
+            np_img = cv2.resize(np_img, (width, height), interpolation=cv2.INTER_AREA)
             return np_img
-        H, W, C = np_img.shape
-        H = int(float(H) / float(down_sampling_rate))
-        W = int(float(W) / float(down_sampling_rate))
-        np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
+
+        ratio = np.sqrt(1024.0 * 1024.0 / (width * height))
+
+        resize_w, resize_h = int(width * ratio), int(height * ratio)
+
+        if self.mode == "super":
+            resize_w, resize_h = int(width * ratio) // 48 * 48, int(height * ratio) // 48 * 48
+
+        np_img = cv2.resize(np_img, (resize_w, resize_h))
+
+        if self.mode == "blur":
+            blur_strength = random.sample([i / 10.0 for i in range(10, 201, 2)], k=1)[0]
+            radius = random.sample([i for i in range(1, 40, 2)], k=1)[0]
+            eps = random.sample([i / 1000.0 for i in range(1, 101, 2)], k=1)[0]
+            scale_factor = random.sample([i / 10.0 for i in range(10, 181, 5)], k=1)[0]
+
+            if random.random() > 0.5:
+                np_img = self.apply_gaussian_blur(np_img, ksize=int(blur_strength), sigmaX=blur_strength / 2)
+
+            if random.random() > 0.5:
+                np_img = self.apply_guided_filter(np_img, radius, eps, int(scale_factor))
+
+            np_img = cv2.resize(
+                np_img, (int(resize_w / scale_factor), int(resize_h / scale_factor)), interpolation=cv2.INTER_AREA
+            )
+            np_img = cv2.resize(np_img, (resize_w, resize_h), interpolation=cv2.INTER_CUBIC)
+
+        if self.mode == "var":
+            pass
+
+        if self.mode == "super":
+            pass
+
+        np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
+
         return np_img

     def run_processor(self, image: Image.Image) -> Image.Image:
         np_img = np.array(image, dtype=np.uint8)
-        processed_np_image = self.tile_resample(
-            np_img,
-            # res=self.tile_size,
-            down_sampling_rate=self.down_sampling_rate,
-        )
+        processed_np_image = self.tile_resample(np_img)
         processed_image = Image.fromarray(processed_np_image)
         return processed_image
