@@ -513,12 +513,26 @@ def lda_ft_mse_weights(test_weights_path: Path) -> Path:
513
513
514
514
515
515
@pytest .fixture (scope = "module" )
516
- def ella_adapter_weights (test_weights_path : Path ) -> Path :
516
+ def ella_weights (test_weights_path : Path ) -> tuple [ Path , Path ] :
517
517
ella_adapter_weights = test_weights_path / "ELLA-Adapter/ella-sd1.5-tsc-t5xl.safetensors"
518
518
if not ella_adapter_weights .is_file ():
519
519
warn (f"could not find weights at { ella_adapter_weights } , skipping" )
520
520
pytest .skip (allow_module_level = True )
521
- return ella_adapter_weights
521
+ t5xl_weights = test_weights_path / "QQGYLab/T5XLFP16"
522
+ t5xl_files = [
523
+ "config.json" ,
524
+ "model.safetensors" ,
525
+ "special_tokens_map.json" ,
526
+ "spiece.model" ,
527
+ "tokenizer_config.json" ,
528
+ "tokenizer.json" ,
529
+ ]
530
+ for file in t5xl_files :
531
+ if not (t5xl_weights / file ).is_file ():
532
+ warn (f"could not find weights at { t5xl_weights / file } , skipping" )
533
+ pytest .skip (allow_module_level = True )
534
+
535
+ return (ella_adapter_weights , t5xl_weights )
522
536
523
537
524
538
@pytest .fixture (scope = "module" )
@@ -605,10 +619,6 @@ def sd15_std_sde(
605
619
def sd15_std_float16 (
606
620
text_encoder_weights : Path , lda_weights : Path , unet_weights_std : Path , test_device : torch .device
607
621
) -> StableDiffusion_1 :
608
- if test_device .type == "cpu" :
609
- warn ("not running on CPU, skipping" )
610
- pytest .skip ()
611
-
612
622
sd15 = StableDiffusion_1 (device = test_device , dtype = torch .float16 )
613
623
614
624
sd15 .clip_text_encoder .load_from_safetensors (text_encoder_weights )
@@ -1817,15 +1827,13 @@ def test_diffusion_textual_inversion_random_init(
1817
1827
@no_grad ()
1818
1828
def test_diffusion_ella_adapter (
1819
1829
sd15_std_float16 : StableDiffusion_1 ,
1820
- ella_adapter_weights : Path ,
1821
- test_weights_path : Path ,
1830
+ ella_weights : tuple [Path , Path ],
1822
1831
expected_image_ella_adapter : Image .Image ,
1823
1832
test_device : torch .device ,
1824
1833
):
1825
1834
sd15 = sd15_std_float16
1826
- t5_encoder = T5TextEmbedder (pretrained_path = test_weights_path / "QQGYLab/T5XLFP16" , max_length = 128 ).to (
1827
- test_device , torch .float16
1828
- )
1835
+ ella_adapter_weights , t5xl_weights = ella_weights
1836
+ t5_encoder = T5TextEmbedder (pretrained_path = t5xl_weights , max_length = 128 ).to (test_device , torch .float16 )
1829
1837
1830
1838
prompt = "a chinese man wearing a white shirt and a checkered headscarf, holds a large falcon near his shoulder. the falcon has dark feathers with a distinctive beak. the background consists of a clear sky and a fence, suggesting an outdoor setting, possibly a desert or arid region"
1831
1839
negative_prompt = ""
0 commit comments