diff --git a/pyproject.toml b/pyproject.toml
index 249a0b88c..a29d5efbd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -70,6 +70,7 @@ training = [
     "gitpython>=3.1.43",
 ]
 test = [
+    "pytest-rerunfailures>=14.0",
     "diffusers>=0.26.1",
     "transformers>=4.35.2",
     "piq>=0.8.0",
diff --git a/requirements.lock b/requirements.lock
index 9db491edd..e56d25afd 100644
--- a/requirements.lock
+++ b/requirements.lock
@@ -67,6 +67,8 @@ dill==0.3.8
     # via multiprocess
 docker-pycreds==0.4.0
     # via wandb
+exceptiongroup==1.2.2
+    # via pytest
 filelock==3.16.1
     # via datasets
     # via diffusers
@@ -113,6 +115,8 @@ importlib-metadata==8.5.0
     # via diffusers
 importlib-resources==6.4.5
     # via swagger-spec-validator
+iniconfig==2.0.0
+    # via pytest
 isoduration==20.11.0
     # via jsonschema
 jaxtyping==0.2.34
@@ -235,6 +239,8 @@ packaging==24.1
     # via huggingface-hub
     # via mkdocs
     # via neptune
+    # via pytest
+    # via pytest-rerunfailures
     # via refiners
     # via transformers
 paginate==0.5.7
@@ -257,6 +263,8 @@ platformdirs==4.3.6
     # via mkdocs-get-deps
     # via mkdocstrings
     # via wandb
+pluggy==1.5.0
+    # via pytest
 prodigyopt==1.0
     # via refiners
 protobuf==5.28.2
@@ -279,6 +287,10 @@ pymdown-extensions==10.10.2
     # via mkdocstrings
 pysocks==1.7.1
     # via requests
+pytest==8.3.3
+    # via pytest-rerunfailures
+pytest-rerunfailures==14.0
+    # via refiners
 python-dateutil==2.9.0.post0
     # via arrow
     # via botocore
@@ -377,6 +389,7 @@ tokenizers==0.20.0
     # via transformers
 tomli==2.0.1
     # via black
+    # via pytest
     # via refiners
 torch==2.4.1
     # via bitsandbytes
diff --git a/src/comfyui-refiners/grounding_dino.py b/src/comfyui-refiners/grounding_dino.py
index 46758723a..e32588695 100644
--- a/src/comfyui-refiners/grounding_dino.py
+++ b/src/comfyui-refiners/grounding_dino.py
@@ -3,9 +3,9 @@
 import torch
 from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor  # type: ignore
 
-from refiners.fluxion.utils import no_grad, tensor_to_image
+from refiners.fluxion.utils import no_grad, str_to_dtype, tensor_to_image
 
-from .utils import BoundingBox, get_dtype
+from .utils import BoundingBox
 
 
 class LoadGroundingDino:
@@ -54,7 +54,7 @@ def load(
         processor = GroundingDinoProcessor.from_pretrained(checkpoint)  # type: ignore
         assert isinstance(processor, GroundingDinoProcessor)
 
-        model = GroundingDinoForObjectDetection.from_pretrained(checkpoint, torch_dtype=get_dtype(dtype))  # type: ignore
+        model = GroundingDinoForObjectDetection.from_pretrained(checkpoint, torch_dtype=str_to_dtype(dtype))  # type: ignore
         model = model.to(device=device)  # type: ignore
         assert isinstance(model, GroundingDinoForObjectDetection)
 
diff --git a/src/comfyui-refiners/utils.py b/src/comfyui-refiners/utils.py
index 959537614..e1e4613b4 100644
--- a/src/comfyui-refiners/utils.py
+++ b/src/comfyui-refiners/utils.py
@@ -48,39 +48,6 @@ def process(
         return (image,)
 
 
-def get_dtype(dtype: str) -> torch.dtype:
-    """Converts a string dtype to a torch.dtype.
-
-    See also https://pytorch.org/docs/stable/tensor_attributes.html#torch-dtype"""
-    match dtype:
-        case "float32" | "float":
-            return torch.float32
-        case "float64" | "double":
-            return torch.float64
-        case "complex64" | "cfloat":
-            return torch.complex64
-        case "complex128" | "cdouble":
-            return torch.complex128
-        case "float16" | "half":
-            return torch.float16
-        case "bfloat16":
-            return torch.bfloat16
-        case "uint8":
-            return torch.uint8
-        case "int8":
-            return torch.int8
-        case "int16" | "short":
-            return torch.int16
-        case "int32" | "int":
-            return torch.int32
-        case "int64" | "long":
-            return torch.int64
-        case "bool":
-            return torch.bool
-        case _:
-            raise ValueError(f"Unknown dtype: {dtype}")
-
-
 NODE_CLASS_MAPPINGS: dict[str, Any] = {
     "DrawBoundingBox": DrawBoundingBox,
 }
diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py
index f80589dfd..5de2f9964 100644
--- a/src/refiners/fluxion/utils.py
+++ b/src/refiners/fluxion/utils.py
@@ -270,3 +270,37 @@ def summarize_tensor(tensor: torch.Tensor, /) -> str:
     )
 
     return "Tensor(" + ", ".join(info_list) + ")"
+
+
+def str_to_dtype(dtype: str) -> torch.dtype:
+    """Converts a string dtype to a torch.dtype.
+
+    See also https://pytorch.org/docs/stable/tensor_attributes.html#torch-dtype
+    """
+    match dtype.lower():
+        case "float32" | "float":
+            return torch.float32
+        case "float64" | "double":
+            return torch.float64
+        case "complex64" | "cfloat":
+            return torch.complex64
+        case "complex128" | "cdouble":
+            return torch.complex128
+        case "float16" | "half":
+            return torch.float16
+        case "bfloat16":
+            return torch.bfloat16
+        case "uint8":
+            return torch.uint8
+        case "int8":
+            return torch.int8
+        case "int16" | "short":
+            return torch.int16
+        case "int32" | "int":
+            return torch.int32
+        case "int64" | "long":
+            return torch.int64
+        case "bool":
+            return torch.bool
+        case _:
+            raise ValueError(f"Unknown dtype: {dtype}")
diff --git a/tests/conftest.py b/tests/conftest.py
index ba56acb12..ef5bad8aa 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,8 +1,11 @@
 import os
 from pathlib import Path
+from typing import Callable
 
 import torch
-from pytest import fixture
+from pytest import FixtureRequest, fixture, skip
+
+from refiners.fluxion.utils import str_to_dtype
 
 PARENT_PATH = Path(__file__).parent
 
@@ -18,6 +21,23 @@ def test_device() -> torch.device:
     return torch.device(test_device)
 
 
+def dtype_fixture_factory(params: list[str]) -> Callable[[FixtureRequest], torch.dtype]:
+    @fixture(scope="session", params=params)
+    def dtype_fixture(request: FixtureRequest) -> torch.dtype:
+        torch_dtype = str_to_dtype(request.param)
+        if torch_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
+            skip("bfloat16 is not supported on this test device")
+        return torch_dtype
+
+    return dtype_fixture
+
+
+test_dtype_fp32_bf16_fp16 = dtype_fixture_factory(["float32", "bfloat16", "float16"])
+test_dtype_fp32_fp16 = dtype_fixture_factory(["float32", "float16"])
+test_dtype_fp32_bf16 = dtype_fixture_factory(["float32", "bfloat16"])
+test_dtype_fp16_bf16 = dtype_fixture_factory(["float16", "bfloat16"])
+
+
 @fixture(scope="session")
 def test_weights_path() -> Path:
     from_env = os.getenv("REFINERS_TEST_WEIGHTS_DIR")
diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py
index 4fdf4348f..a9e92dfdd 100644
--- a/tests/e2e/test_diffusion.py
+++ b/tests/e2e/test_diffusion.py
@@ -92,6 +92,11 @@ def expected_image_std_random_init(ref_path: Path) -> Image.Image:
     return _img_open(ref_path / "expected_std_random_init.png").convert("RGB")
 
 
+@pytest.fixture
+def expected_image_std_random_init_bfloat16(ref_path: Path) -> Image.Image:
+    return _img_open(ref_path / "expected_std_random_init_bfloat16.png").convert("RGB")
+
+
 @pytest.fixture
 def expected_image_std_sde_random_init(ref_path: Path) -> Image.Image:
     return _img_open(ref_path / "expected_std_sde_random_init.png").convert("RGB")
@@ -637,6 +642,26 @@ def sd15_std_float16(
     return sd15
 
 
+@pytest.fixture
+def sd15_std_bfloat16(
+    text_encoder_weights: Path,
+    lda_weights: Path,
+    unet_weights_std: Path,
+    test_device: torch.device,
+) -> StableDiffusion_1:
+    if test_device.type == "cpu":
+        warn("not running on CPU, skipping")
+        pytest.skip()
+
+    sd15 = StableDiffusion_1(device=test_device, dtype=torch.bfloat16)
+
+    sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
+    sd15.lda.load_from_safetensors(lda_weights)
+    sd15.unet.load_from_safetensors(unet_weights_std)
+
+    return sd15
+
+
 @pytest.fixture
 def sd15_inpainting(
     text_encoder_weights: Path, lda_weights: Path, unet_weights_inpainting: Path, test_device: torch.device
@@ -891,6 +916,34 @@ def test_diffusion_std_random_init(
     ensure_similar_images(predicted_image, expected_image_std_random_init)
 
 
+@no_grad()
+def test_diffusion_std_random_init_bfloat16(
+    sd15_std_bfloat16: StableDiffusion_1,
+    expected_image_std_random_init_bfloat16: Image.Image,
+):
+    sd15 = sd15_std_bfloat16
+
+    prompt = "a cute cat, detailed high-quality professional image"
+    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
+    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+
+    sd15.set_inference_steps(30)
+
+    manual_seed(2)
+    x = torch.randn(1, 4, 64, 64, device=sd15.device, dtype=sd15.dtype)
+
+    for step in sd15.steps:
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=7.5,
+        )
+    predicted_image = sd15.lda.latents_to_image(x)
+
+    ensure_similar_images(predicted_image, expected_image_std_random_init_bfloat16)
+
+
 @no_grad()
 def test_diffusion_std_sde_random_init(
     sd15_std_sde: StableDiffusion_1, expected_image_std_sde_random_init: Image.Image, test_device: torch.device
diff --git a/tests/e2e/test_diffusion_ref/expected_std_random_init_bfloat16.png b/tests/e2e/test_diffusion_ref/expected_std_random_init_bfloat16.png
new file mode 100644
index 000000000..c46dd89f1
Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_std_random_init_bfloat16.png differ
diff --git a/tests/foundationals/clip/test_image_encoder.py b/tests/foundationals/clip/test_image_encoder.py
index ff990bda5..69988c208 100644
--- a/tests/foundationals/clip/test_image_encoder.py
+++ b/tests/foundationals/clip/test_image_encoder.py
@@ -10,12 +10,16 @@
 
 
 @pytest.fixture(scope="module")
-def our_encoder(test_weights_path: Path, test_device: torch.device) -> CLIPImageEncoderH:
+def our_encoder(
+    test_weights_path: Path,
+    test_device: torch.device,
+    test_dtype_fp32_bf16_fp16: torch.dtype,
+) -> CLIPImageEncoderH:
     weights = test_weights_path / "CLIPImageEncoderH.safetensors"
     if not weights.is_file():
         warn(f"could not find weights at {weights}, skipping")
         pytest.skip(allow_module_level=True)
-    encoder = CLIPImageEncoderH(device=test_device)
+    encoder = CLIPImageEncoderH(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
     tensors = load_from_safetensors(weights)
     encoder.load_state_dict(tensors)
     return encoder
@@ -31,24 +35,31 @@ def stabilityai_unclip_weights_path(test_weights_path: Path):
 
 
 @pytest.fixture(scope="module")
-def ref_encoder(stabilityai_unclip_weights_path: Path, test_device: torch.device) -> CLIPVisionModelWithProjection:
-    return CLIPVisionModelWithProjection.from_pretrained(stabilityai_unclip_weights_path, subfolder="image_encoder").to(  # type: ignore
-        test_device  # type: ignore
-    )
+def ref_encoder(
+    stabilityai_unclip_weights_path: Path,
+    test_device: torch.device,
+    test_dtype_fp32_bf16_fp16: torch.dtype,
+) -> CLIPVisionModelWithProjection:
+    return CLIPVisionModelWithProjection.from_pretrained(  # type: ignore
+        stabilityai_unclip_weights_path,
+        subfolder="image_encoder",
+    ).to(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
 
 
+@no_grad()
+@pytest.mark.flaky(reruns=3)
 def test_encoder(
     ref_encoder: CLIPVisionModelWithProjection,
     our_encoder: CLIPImageEncoderH,
-    test_device: torch.device,
 ):
-    x = torch.randn(1, 3, 224, 224).to(test_device)
+    assert ref_encoder.dtype == our_encoder.dtype
+    assert ref_encoder.device == our_encoder.device
+    x = torch.randn((1, 3, 224, 224), dtype=ref_encoder.dtype, device=ref_encoder.device)
 
-    with no_grad():
-        ref_embeddings = ref_encoder(x).image_embeds
-        our_embeddings = our_encoder(x)
+    ref_embeddings = ref_encoder(x).image_embeds
+    our_embeddings = our_encoder(x)
 
     assert ref_embeddings.shape == (1, 1024)
     assert our_embeddings.shape == (1, 1024)
-    assert (our_embeddings - ref_embeddings).abs().max() < 0.01
+    assert torch.allclose(our_embeddings, ref_embeddings, atol=0.05)
diff --git a/tests/foundationals/clip/test_text_encoder.py b/tests/foundationals/clip/test_text_encoder.py
index 60a873bb2..28eeada65 100644
--- a/tests/foundationals/clip/test_text_encoder.py
+++ b/tests/foundationals/clip/test_text_encoder.py
@@ -30,13 +30,17 @@
 
 
 @pytest.fixture(scope="module")
-def our_encoder(test_weights_path: Path, test_device: torch.device) -> CLIPTextEncoderL:
+def our_encoder(
+    test_weights_path: Path,
+    test_device: torch.device,
+    test_dtype_fp32_fp16: torch.dtype,
+) -> CLIPTextEncoderL:
     weights = test_weights_path / "CLIPTextEncoderL.safetensors"
     if not weights.is_file():
         warn(f"could not find weights at {weights}, skipping")
         pytest.skip(allow_module_level=True)
-    encoder = CLIPTextEncoderL(device=test_device)
     tensors = load_from_safetensors(weights)
+    encoder = CLIPTextEncoderL(device=test_device, dtype=test_dtype_fp32_fp16)
     encoder.load_state_dict(tensors)
     return encoder
 
@@ -56,8 +60,15 @@ def ref_tokenizer(runwayml_weights_path: Path) -> transformers.CLIPTokenizer:
 
 
 @pytest.fixture(scope="module")
-def ref_encoder(runwayml_weights_path: Path, test_device: torch.device) -> transformers.CLIPTextModel:
-    return transformers.CLIPTextModel.from_pretrained(runwayml_weights_path, subfolder="text_encoder").to(test_device)  # type: ignore
+def ref_encoder(
+    runwayml_weights_path: Path,
+    test_device: torch.device,
+    test_dtype_fp32_fp16: torch.dtype,
+) -> transformers.CLIPTextModel:
+    return transformers.CLIPTextModel.from_pretrained(  # type: ignore
+        runwayml_weights_path,
+        subfolder="text_encoder",
+    ).to(device=test_device, dtype=test_dtype_fp32_fp16)  # type: ignore
 
 
 def test_basics(ref_tokenizer: transformers.CLIPTokenizer, our_encoder: CLIPTextEncoderL):
@@ -70,12 +81,12 @@ def prompt(request: pytest.FixtureRequest):
     return long_prompt if request.param == "" else request.param
 
 
+@no_grad()
 def test_encoder(
     prompt: str,
     ref_tokenizer: transformers.CLIPTokenizer,
     ref_encoder: transformers.CLIPTextModel,
     our_encoder: CLIPTextEncoderL,
-    test_device: torch.device,
 ):
     ref_tokens = ref_tokenizer(  # type: ignore
         prompt,
@@ -89,18 +100,16 @@ def test_encoder(
     our_tokens = tokenizer(prompt)
     assert torch.equal(our_tokens, ref_tokens)
 
-    with no_grad():
-        ref_embeddings = ref_encoder(ref_tokens.to(test_device))[0]
-        our_embeddings = our_encoder(prompt)
+    ref_embeddings = ref_encoder(ref_tokens.to(device=ref_encoder.device))[0]
+    our_embeddings = our_encoder(prompt)
 
     assert ref_embeddings.shape == (1, 77, 768)
     assert our_embeddings.shape == (1, 77, 768)
 
     # FG-336 - Not strictly equal because we do not use the same implementation
     # of self-attention. We use `scaled_dot_product_attention` which can have
-    # numerical differences depending on the backend.
-    # Also we use FP16 weights.
-    assert (our_embeddings - ref_embeddings).abs().max() < 0.01
+    # numerical differences depending on the backend. Also we use FP16 weights.
+    torch.testing.assert_close(our_embeddings, ref_embeddings, atol=0.035, rtol=0.0)
 
 
 def test_list_string_tokenizer(
diff --git a/tests/foundationals/dinov2/test_dinov2.py b/tests/foundationals/dinov2/test_dinov2.py
index 47f7959b0..1e73b6dff 100644
--- a/tests/foundationals/dinov2/test_dinov2.py
+++ b/tests/foundationals/dinov2/test_dinov2.py
@@ -109,7 +109,7 @@ def test_dinov2_facebook_weights(
 ) -> None:
     manual_seed(2)
     input_data = torch.randn(
-        (1, 3, resolution, resolution),
+        size=(1, 3, resolution, resolution),
         device=test_device,
     )
 
@@ -129,27 +129,28 @@ def test_dinov2_facebook_weights(
 
 
 @no_grad()
-def test_dinov2_float16(
+def test_dinov2(
     resolution: int,
+    test_dtype_fp32_bf16_fp16: torch.dtype,
     test_device: torch.device,
 ) -> None:
     if test_device.type == "cpu":
         warn("not running on CPU, skipping")
         pytest.skip()
 
-    model = DINOv2_small(device=test_device, dtype=torch.float16)
+    model = DINOv2_small(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
 
     manual_seed(2)
     input_data = torch.randn(
-        (1, 3, resolution, resolution),
+        size=(1, 3, resolution, resolution),
         device=test_device,
-        dtype=torch.float16,
+        dtype=test_dtype_fp32_bf16_fp16,
    )
 
     output = model(input_data)
 
     sequence_length = (resolution // model.patch_size) ** 2 + 1
     assert output.shape == (1, sequence_length, model.embedding_dim)
-    assert output.dtype == torch.float16
+    assert output.dtype == test_dtype_fp32_bf16_fp16
 
 
 @no_grad()
@@ -162,7 +163,7 @@ def test_dinov2_batch_size(
     batch_size = 4
     manual_seed(2)
     input_data = torch.randn(
-        (batch_size, 3, resolution, resolution),
+        size=(batch_size, 3, resolution, resolution),
         device=test_device,
     )
 
diff --git a/tests/foundationals/latent_diffusion/test_auto_encoder.py b/tests/foundationals/latent_diffusion/test_auto_encoder.py
index dc6d77cc9..eedd25112 100644
--- a/tests/foundationals/latent_diffusion/test_auto_encoder.py
+++ b/tests/foundationals/latent_diffusion/test_auto_encoder.py
@@ -6,8 +6,8 @@
 from PIL import Image
 
 from tests.utils import ensure_similar_images
 
-from refiners.fluxion.utils import load_from_safetensors, no_grad
-from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
+from refiners.fluxion.utils import no_grad
+from refiners.foundationals.latent_diffusion import LatentDiffusionAutoencoder, SD1Autoencoder, SDXLAutoencoder
 
 
 @pytest.fixture(scope="module")
@@ -15,16 +15,37 @@ def ref_path() -> Path:
     return Path(__file__).parent / "test_auto_encoder_ref"
 
 
-@pytest.fixture(scope="module")
-def lda(test_weights_path: Path, test_device: torch.device) -> LatentDiffusionAutoencoder:
-    lda_weights = test_weights_path / "lda.safetensors"
-    if not lda_weights.is_file():
-        warn(f"could not find weights at {lda_weights}, skipping")
-        pytest.skip(allow_module_level=True)
-    encoder = LatentDiffusionAutoencoder(device=test_device)
-    tensors = load_from_safetensors(lda_weights)
-    encoder.load_state_dict(tensors)
-    return encoder
+@pytest.fixture(scope="module", params=["SD1.5", "SDXL"])
+def lda(
+    request: pytest.FixtureRequest,
+    test_weights_path: Path,
+    test_dtype_fp32_bf16_fp16: torch.dtype,
+    test_device: torch.device,
+) -> LatentDiffusionAutoencoder:
+    model_version = request.param
+    match (model_version, test_dtype_fp32_bf16_fp16):
+        case ("SD1.5", _):
+            weight_path = test_weights_path / "lda.safetensors"
+            if not weight_path.is_file():
+                warn(f"could not find weights at {weight_path}, skipping")
+                pytest.skip(allow_module_level=True)
+            model = SD1Autoencoder().load_from_safetensors(weight_path)
+        case ("SDXL", torch.float16):
+            weight_path = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
+            if not weight_path.is_file():
+                warn(f"could not find weights at {weight_path}, skipping")
+                pytest.skip(allow_module_level=True)
+            model = SDXLAutoencoder().load_from_safetensors(weight_path)
+        case ("SDXL", _):
+            weight_path = test_weights_path / "sdxl-lda.safetensors"
+            if not weight_path.is_file():
+                warn(f"could not find weights at {weight_path}, skipping")
+                pytest.skip(allow_module_level=True)
+            model = SDXLAutoencoder().load_from_safetensors(weight_path)
+        case _:
+            raise ValueError(f"Unknown model version: {model_version}")
+    model = model.to(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
+    return model
 
 
 @pytest.fixture(scope="module")
diff --git a/tests/foundationals/latent_diffusion/test_models.py b/tests/foundationals/latent_diffusion/test_models.py
new file mode 100644
index 000000000..5fd74c878
--- /dev/null
+++ b/tests/foundationals/latent_diffusion/test_models.py
@@ -0,0 +1,77 @@
+import torch
+from PIL import Image
+
+from refiners.fluxion.utils import manual_seed, no_grad
+from refiners.foundationals.latent_diffusion import StableDiffusion_1, StableDiffusion_1_Inpainting, StableDiffusion_XL
+from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
+
+
+@no_grad()
+def test_sample_noise_zero_offset(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
+    manual_seed(2)
+    latents_0 = LatentDiffusionModel.sample_noise(
+        size=(1, 4, 64, 64),
+        device=test_device,
+        dtype=test_dtype_fp32_bf16_fp16,
+    )
+    manual_seed(2)
+    latents_1 = LatentDiffusionModel.sample_noise(
+        size=(1, 4, 64, 64),
+        offset_noise=0.0,  # should be no-op
+        device=test_device,
+        dtype=test_dtype_fp32_bf16_fp16,
+    )
+
+    assert torch.allclose(latents_0, latents_1, atol=1e-6, rtol=0)
+
+
+@no_grad()
+def test_sd15_one_step(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
+    sd = StableDiffusion_1(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
+
+    # prepare inputs
+    latent_noise = torch.randn(1, 4, 64, 64, device=test_device, dtype=test_dtype_fp32_bf16_fp16)
+    text_embedding = sd.compute_clip_text_embedding("")
+
+    # run the pipeline of models, for a single step
+    output = sd(latent_noise, step=0, clip_text_embedding=text_embedding)
+
+    assert output.shape == (1, 4, 64, 64)
+
+
+@no_grad()
+def test_sd15_inpainting_one_step(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
+    sd = StableDiffusion_1_Inpainting(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
+
+    # prepare inputs
+    latent_noise = torch.randn(1, 4, 64, 64, device=test_device, dtype=test_dtype_fp32_bf16_fp16)
+    target_image = Image.new("RGB", (512, 512))
+    mask = Image.new("L", (512, 512))
+    sd.set_inpainting_conditions(target_image=target_image, mask=mask)
+    text_embedding = sd.compute_clip_text_embedding("")
+
+    # run the pipeline of models, for a single step
+    output = sd(latent_noise, step=0, clip_text_embedding=text_embedding)
+
+    assert output.shape == (1, 4, 64, 64)
+
+
+@no_grad()
+def test_sdxl_one_step(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
+    sd = StableDiffusion_XL(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
+
+    # prepare inputs
+    latent_noise = torch.randn(1, 4, 128, 128, device=test_device, dtype=test_dtype_fp32_bf16_fp16)
+    text_embedding, pooled_text_embedding = sd.compute_clip_text_embedding("")
+    time_ids = sd.default_time_ids
+
+    # run the pipeline of models, for a single step
+    output = sd(
+        latent_noise,
+        step=0,
+        clip_text_embedding=text_embedding,
+        pooled_text_embedding=pooled_text_embedding,
+        time_ids=time_ids,
+    )
+
+    assert output.shape == (1, 4, 128, 128)
diff --git a/tests/foundationals/latent_diffusion/test_sd15_unet.py b/tests/foundationals/latent_diffusion/test_sd15_unet.py
index 3ecf1ae4e..6c01a5465 100644
--- a/tests/foundationals/latent_diffusion/test_sd15_unet.py
+++ b/tests/foundationals/latent_diffusion/test_sd15_unet.py
@@ -7,9 +7,15 @@
 
 
 @pytest.fixture(scope="module")
-def refiners_sd15_unet(test_device: torch.device) -> SD1UNet:
-    unet = SD1UNet(in_channels=4, device=test_device)
-    return unet
+def refiners_sd15_unet(
+    test_device: torch.device,
+    test_dtype_fp32_bf16_fp16: torch.dtype,
+) -> SD1UNet:
+    return SD1UNet(
+        in_channels=4,
+        device=test_device,
+        dtype=test_dtype_fp32_bf16_fp16,
+    )
 
 
 def test_unet_context_flush(refiners_sd15_unet: SD1UNet):