1 change: 1 addition & 0 deletions pyproject.toml
@@ -70,6 +70,7 @@ training = [
    "gitpython>=3.1.43",
]
test = [
+    "pytest-rerunfailures>=14.0",
    "diffusers>=0.26.1",
    "transformers>=4.35.2",
    "piq>=0.8.0",
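The new `pytest-rerunfailures` test dependency provides the `flaky` marker that this PR applies to the CLIP image encoder test below. A minimal sketch of how the marker behaves (the test here is an invented example, not part of this PR):

```python
import random

import pytest


@pytest.mark.flaky(reruns=3)  # on failure, pytest re-runs this test up to 3 more times
def test_noisy_metric():
    # stand-in for a numerically noisy assertion
    assert random.random() < 0.9
```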
13 changes: 13 additions & 0 deletions requirements.lock
@@ -67,6 +67,8 @@ dill==0.3.8
    # via multiprocess
docker-pycreds==0.4.0
    # via wandb
+exceptiongroup==1.2.2
+    # via pytest
filelock==3.16.1
    # via datasets
    # via diffusers
@@ -113,6 +115,8 @@ importlib-metadata==8.5.0
    # via diffusers
importlib-resources==6.4.5
    # via swagger-spec-validator
+iniconfig==2.0.0
+    # via pytest
isoduration==20.11.0
    # via jsonschema
jaxtyping==0.2.34
@@ -235,6 +239,8 @@ packaging==24.1
    # via huggingface-hub
    # via mkdocs
    # via neptune
+    # via pytest
+    # via pytest-rerunfailures
    # via refiners
    # via transformers
paginate==0.5.7
@@ -257,6 +263,8 @@ platformdirs==4.3.6
    # via mkdocs-get-deps
    # via mkdocstrings
    # via wandb
+pluggy==1.5.0
+    # via pytest
prodigyopt==1.0
    # via refiners
protobuf==5.28.2
@@ -279,6 +287,10 @@ pymdown-extensions==10.10.2
    # via mkdocstrings
pysocks==1.7.1
    # via requests
+pytest==8.3.3
+    # via pytest-rerunfailures
+pytest-rerunfailures==14.0
+    # via refiners
python-dateutil==2.9.0.post0
    # via arrow
    # via botocore
@@ -377,6 +389,7 @@ tokenizers==0.20.0
    # via transformers
tomli==2.0.1
    # via black
+    # via pytest
    # via refiners
torch==2.4.1
    # via bitsandbytes
6 changes: 3 additions & 3 deletions src/comfyui-refiners/grounding_dino.py
@@ -3,9 +3,9 @@
import torch
from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor  # type: ignore

-from refiners.fluxion.utils import no_grad, tensor_to_image
+from refiners.fluxion.utils import no_grad, str_to_dtype, tensor_to_image

-from .utils import BoundingBox, get_dtype
+from .utils import BoundingBox


class LoadGroundingDino:
@@ -54,7 +54,7 @@ def load(
        processor = GroundingDinoProcessor.from_pretrained(checkpoint)  # type: ignore
        assert isinstance(processor, GroundingDinoProcessor)

-        model = GroundingDinoForObjectDetection.from_pretrained(checkpoint, torch_dtype=get_dtype(dtype))  # type: ignore
+        model = GroundingDinoForObjectDetection.from_pretrained(checkpoint, torch_dtype=str_to_dtype(dtype))  # type: ignore
        model = model.to(device=device)  # type: ignore
        assert isinstance(model, GroundingDinoForObjectDetection)
33 changes: 0 additions & 33 deletions src/comfyui-refiners/utils.py
@@ -48,39 +48,6 @@ def process(
        return (image,)


-def get_dtype(dtype: str) -> torch.dtype:
-    """Converts a string dtype to a torch.dtype.
-
-    See also https://pytorch.org/docs/stable/tensor_attributes.html#torch-dtype"""
-    match dtype:
-        case "float32" | "float":
-            return torch.float32
-        case "float64" | "double":
-            return torch.float64
-        case "complex64" | "cfloat":
-            return torch.complex64
-        case "complex128" | "cdouble":
-            return torch.complex128
-        case "float16" | "half":
-            return torch.float16
-        case "bfloat16":
-            return torch.bfloat16
-        case "uint8":
-            return torch.uint8
-        case "int8":
-            return torch.int8
-        case "int16" | "short":
-            return torch.int16
-        case "int32" | "int":
-            return torch.int32
-        case "int64" | "long":
-            return torch.int64
-        case "bool":
-            return torch.bool
-        case _:
-            raise ValueError(f"Unknown dtype: {dtype}")
-
-
NODE_CLASS_MAPPINGS: dict[str, Any] = {
    "DrawBoundingBox": DrawBoundingBox,
}
34 changes: 34 additions & 0 deletions src/refiners/fluxion/utils.py
@@ -270,3 +270,37 @@ def summarize_tensor(tensor: torch.Tensor, /) -> str:
    )

    return "Tensor(" + ", ".join(info_list) + ")"
+
+
+def str_to_dtype(dtype: str) -> torch.dtype:
+    """Converts a string dtype to a torch.dtype.
+
+    See also https://pytorch.org/docs/stable/tensor_attributes.html#torch-dtype
+    """
+    match dtype.lower():
+        case "float32" | "float":
+            return torch.float32
+        case "float64" | "double":
+            return torch.float64
+        case "complex64" | "cfloat":
+            return torch.complex64
+        case "complex128" | "cdouble":
+            return torch.complex128
+        case "float16" | "half":
+            return torch.float16
+        case "bfloat16":
+            return torch.bfloat16
+        case "uint8":
+            return torch.uint8
+        case "int8":
+            return torch.int8
+        case "int16" | "short":
+            return torch.int16
+        case "int32" | "int":
+            return torch.int32
+        case "int64" | "long":
+            return torch.int64
+        case "bool":
+            return torch.bool
+        case _:
+            raise ValueError(f"Unknown dtype: {dtype}")
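For reference, a quick sketch of how the relocated helper resolves names; the aliases come straight from the match statement above, and `.lower()` makes the lookup case-insensitive:

```python
import torch

from refiners.fluxion.utils import str_to_dtype

assert str_to_dtype("float16") is torch.float16
assert str_to_dtype("half") is torch.float16  # aliases resolve to the same dtype
assert str_to_dtype("BFloat16") is torch.bfloat16  # case-insensitive

# str_to_dtype("float8")  # would raise ValueError: Unknown dtype: float8
```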
22 changes: 21 additions & 1 deletion tests/conftest.py
@@ -1,8 +1,11 @@
import os
from pathlib import Path
+from typing import Callable

import torch
-from pytest import fixture
+from pytest import FixtureRequest, fixture, skip
+
+from refiners.fluxion.utils import str_to_dtype

PARENT_PATH = Path(__file__).parent

@@ -18,6 +21,23 @@ def test_device() -> torch.device:
    return torch.device(test_device)


+def dtype_fixture_factory(params: list[str]) -> Callable[[FixtureRequest], torch.dtype]:
+    @fixture(scope="session", params=params)
+    def dtype_fixture(request: FixtureRequest) -> torch.dtype:
+        torch_dtype = str_to_dtype(request.param)
+        if torch_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
+            skip("bfloat16 is not supported on this test device")
+        return torch_dtype
+
+    return dtype_fixture
+
+
+test_dtype_fp32_bf16_fp16 = dtype_fixture_factory(["float32", "bfloat16", "float16"])
+test_dtype_fp32_fp16 = dtype_fixture_factory(["float32", "float16"])
+test_dtype_fp32_bf16 = dtype_fixture_factory(["float32", "bfloat16"])
+test_dtype_fp16_bf16 = dtype_fixture_factory(["float16", "bfloat16"])
+
+
@fixture(scope="session")
def test_weights_path() -> Path:
    from_env = os.getenv("REFINERS_TEST_WEIGHTS_DIR")
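Each factory call returns a session-scoped, parametrized fixture, so a test requesting one runs once per dtype in its list (with unsupported bfloat16 runs skipped). A hypothetical consumer, with an invented test name and body for illustration:

```python
import torch


def test_ones_dtype(test_dtype_fp32_fp16: torch.dtype) -> None:
    # pytest collects this twice: once with torch.float32, once with torch.float16
    x = torch.ones(2, 2, dtype=test_dtype_fp32_fp16)
    assert x.dtype is test_dtype_fp32_fp16
```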
53 changes: 53 additions & 0 deletions tests/e2e/test_diffusion.py
@@ -92,6 +92,11 @@ def expected_image_std_random_init(ref_path: Path) -> Image.Image:
    return _img_open(ref_path / "expected_std_random_init.png").convert("RGB")


+@pytest.fixture
+def expected_image_std_random_init_bfloat16(ref_path: Path) -> Image.Image:
+    return _img_open(ref_path / "expected_std_random_init_bfloat16.png").convert("RGB")
+
+
@pytest.fixture
def expected_image_std_sde_random_init(ref_path: Path) -> Image.Image:
    return _img_open(ref_path / "expected_std_sde_random_init.png").convert("RGB")
@@ -637,6 +642,26 @@ def sd15_std_float16(
    return sd15


+@pytest.fixture
+def sd15_std_bfloat16(
+    text_encoder_weights: Path,
+    lda_weights: Path,
+    unet_weights_std: Path,
+    test_device: torch.device,
+) -> StableDiffusion_1:
+    if test_device.type == "cpu":
+        warn("not running on CPU, skipping")
+        pytest.skip()
+
+    sd15 = StableDiffusion_1(device=test_device, dtype=torch.bfloat16)
+
+    sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
+    sd15.lda.load_from_safetensors(lda_weights)
+    sd15.unet.load_from_safetensors(unet_weights_std)
+
+    return sd15
+
+
@pytest.fixture
def sd15_inpainting(
    text_encoder_weights: Path, lda_weights: Path, unet_weights_inpainting: Path, test_device: torch.device
@@ -891,6 +916,34 @@ def test_diffusion_std_random_init(
    ensure_similar_images(predicted_image, expected_image_std_random_init)


+@no_grad()
+def test_diffusion_std_random_init_bfloat16(
+    sd15_std_bfloat16: StableDiffusion_1,
+    expected_image_std_random_init_bfloat16: Image.Image,
+):
+    sd15 = sd15_std_bfloat16
+
+    prompt = "a cute cat, detailed high-quality professional image"
+    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
+    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+
+    sd15.set_inference_steps(30)
+
+    manual_seed(2)
+    x = torch.randn(1, 4, 64, 64, device=sd15.device, dtype=sd15.dtype)
+
+    for step in sd15.steps:
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=7.5,
+        )
+    predicted_image = sd15.lda.latents_to_image(x)
+
+    ensure_similar_images(predicted_image, expected_image_std_random_init_bfloat16)
+
+
@no_grad()
def test_diffusion_std_sde_random_init(
    sd15_std_sde: StableDiffusion_1, expected_image_std_sde_random_init: Image.Image, test_device: torch.device
Binary file added: expected_std_random_init_bfloat16.png (reference image for the new bfloat16 test; cannot be displayed).
35 changes: 23 additions & 12 deletions tests/foundationals/clip/test_image_encoder.py
@@ -10,12 +10,16 @@


@pytest.fixture(scope="module")
-def our_encoder(test_weights_path: Path, test_device: torch.device) -> CLIPImageEncoderH:
+def our_encoder(
+    test_weights_path: Path,
+    test_device: torch.device,
+    test_dtype_fp32_bf16_fp16: torch.dtype,
+) -> CLIPImageEncoderH:
    weights = test_weights_path / "CLIPImageEncoderH.safetensors"
    if not weights.is_file():
        warn(f"could not find weights at {weights}, skipping")
        pytest.skip(allow_module_level=True)
-    encoder = CLIPImageEncoderH(device=test_device)
+    encoder = CLIPImageEncoderH(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
    tensors = load_from_safetensors(weights)
    encoder.load_state_dict(tensors)
    return encoder
@@ -31,24 +35,31 @@ def stabilityai_unclip_weights_path(test_weights_path: Path):


@pytest.fixture(scope="module")
-def ref_encoder(stabilityai_unclip_weights_path: Path, test_device: torch.device) -> CLIPVisionModelWithProjection:
-    return CLIPVisionModelWithProjection.from_pretrained(stabilityai_unclip_weights_path, subfolder="image_encoder").to(  # type: ignore
-        test_device  # type: ignore
-    )
+def ref_encoder(
+    stabilityai_unclip_weights_path: Path,
+    test_device: torch.device,
+    test_dtype_fp32_bf16_fp16: torch.dtype,
+) -> CLIPVisionModelWithProjection:
+    return CLIPVisionModelWithProjection.from_pretrained(  # type: ignore
+        stabilityai_unclip_weights_path,
+        subfolder="image_encoder",
+    ).to(device=test_device, dtype=test_dtype_fp32_bf16_fp16)


+@no_grad()
+@pytest.mark.flaky(reruns=3)
def test_encoder(
    ref_encoder: CLIPVisionModelWithProjection,
    our_encoder: CLIPImageEncoderH,
-    test_device: torch.device,
):
-    x = torch.randn(1, 3, 224, 224).to(test_device)
+    assert ref_encoder.dtype == our_encoder.dtype
+    assert ref_encoder.device == our_encoder.device
+    x = torch.randn((1, 3, 224, 224), dtype=ref_encoder.dtype, device=ref_encoder.device)

-    with no_grad():
-        ref_embeddings = ref_encoder(x).image_embeds
-        our_embeddings = our_encoder(x)
+    ref_embeddings = ref_encoder(x).image_embeds
+    our_embeddings = our_encoder(x)

    assert ref_embeddings.shape == (1, 1024)
    assert our_embeddings.shape == (1, 1024)

-    assert (our_embeddings - ref_embeddings).abs().max() < 0.01
+    assert torch.allclose(our_embeddings, ref_embeddings, atol=0.05)
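The assertion also moves from a hand-rolled max-deviation bound to `torch.allclose` with a looser `atol`, which accommodates the half-precision dtypes now covered. A small illustration with invented values (not from this PR); note that `allclose` additionally folds in a relative term, `rtol`, defaulting to 1e-05:

```python
import torch

a = torch.tensor([1.000, 2.000])
b = torch.tensor([1.004, 2.004])

# old style: bound the maximum absolute deviation
assert (a - b).abs().max() < 0.01

# new style: |a - b| <= atol + rtol * |b|, elementwise
assert torch.allclose(a, b, atol=0.05)
```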