@@ -30,13 +30,17 @@
 
 
 @pytest.fixture(scope="module")
-def our_encoder(test_weights_path: Path, test_device: torch.device) -> CLIPTextEncoderL:
+def our_encoder(
+    test_weights_path: Path,
+    test_device: torch.device,
+    test_dtype_fp32_fp16: torch.dtype,
+) -> CLIPTextEncoderL:
     weights = test_weights_path / "CLIPTextEncoderL.safetensors"
     if not weights.is_file():
         warn(f"could not find weights at {weights}, skipping")
         pytest.skip(allow_module_level=True)
-    encoder = CLIPTextEncoderL(device=test_device)
     tensors = load_from_safetensors(weights)
+    encoder = CLIPTextEncoderL(device=test_device, dtype=test_dtype_fp32_fp16)
     encoder.load_state_dict(tensors)
     return encoder
 
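The `test_dtype_fp32_fp16` fixture requested above is defined outside this file (presumably in the shared conftest) and is not part of this diff. As a rough sketch only, assuming it simply parametrizes the module over float32 and float16, it could look like this:

# Hypothetical sketch of the dtype fixture assumed by the fixtures above;
# the real definition lives outside this diff and may differ.
import pytest
import torch

@pytest.fixture(scope="module", params=[torch.float32, torch.float16])
def test_dtype_fp32_fp16(request: pytest.FixtureRequest) -> torch.dtype:
    # Runs the module's encoder tests once with FP32 and once with FP16 weights.
    return request.param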
@@ -56,8 +60,15 @@ def ref_tokenizer(runwayml_weights_path: Path) -> transformers.CLIPTokenizer:
 
 
 @pytest.fixture(scope="module")
-def ref_encoder(runwayml_weights_path: Path, test_device: torch.device) -> transformers.CLIPTextModel:
-    return transformers.CLIPTextModel.from_pretrained(runwayml_weights_path, subfolder="text_encoder").to(test_device)  # type: ignore
+def ref_encoder(
+    runwayml_weights_path: Path,
+    test_device: torch.device,
+    test_dtype_fp32_fp16: torch.dtype,
+) -> transformers.CLIPTextModel:
+    return transformers.CLIPTextModel.from_pretrained(  # type: ignore
+        runwayml_weights_path,
+        subfolder="text_encoder",
+    ).to(device=test_device, dtype=test_dtype_fp32_fp16)  # type: ignore
 
 
 def test_basics(ref_tokenizer: transformers.CLIPTokenizer, our_encoder: CLIPTextEncoderL):
@@ -70,12 +81,12 @@ def prompt(request: pytest.FixtureRequest):
     return long_prompt if request.param == "<long prompt>" else request.param
 
 
+@no_grad()
 def test_encoder(
     prompt: str,
     ref_tokenizer: transformers.CLIPTokenizer,
     ref_encoder: transformers.CLIPTextModel,
     our_encoder: CLIPTextEncoderL,
-    test_device: torch.device,
 ):
     ref_tokens = ref_tokenizer(  # type: ignore
         prompt,
@@ -89,18 +100,16 @@ def test_encoder(
     our_tokens = tokenizer(prompt)
     assert torch.equal(our_tokens, ref_tokens)
 
-    with no_grad():
-        ref_embeddings = ref_encoder(ref_tokens.to(test_device))[0]
-        our_embeddings = our_encoder(prompt)
+    ref_embeddings = ref_encoder(ref_tokens.to(device=ref_encoder.device))[0]
+    our_embeddings = our_encoder(prompt)
 
     assert ref_embeddings.shape == (1, 77, 768)
     assert our_embeddings.shape == (1, 77, 768)
 
     # FG-336 - Not strictly equal because we do not use the same implementation
     # of self-attention. We use `scaled_dot_product_attention` which can have
-    # numerical differences depending on the backend.
-    # Also we use FP16 weights.
-    assert (our_embeddings - ref_embeddings).abs().max() < 0.01
+    # numerical differences depending on the backend. Also we use FP16 weights.
+    torch.testing.assert_close(our_embeddings, ref_embeddings, atol=0.035, rtol=0.0)
 
 
 def test_list_string_tokenizer(
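On the tolerance change in the last hunk: `torch.testing.assert_close(actual, expected, atol=..., rtol=...)` checks that `|actual - expected| <= atol + rtol * |expected|` element-wise, so with `rtol=0.0` the new assertion is a pure absolute bound of 0.035, a looser version of the old `abs().max() < 0.01` check, presumably to leave headroom for the FP16 runs. A self-contained illustration with made-up values:

# Standalone illustration of the absolute-tolerance check; the tensors here
# are made up for demonstration and are not taken from the encoders.
import torch

expected = torch.tensor([0.100, 0.200, 0.300])
actual = expected + torch.tensor([0.010, -0.030, 0.020])  # max abs error 0.03

# Passes: every element is within atol=0.035 of the reference (rtol=0.0).
torch.testing.assert_close(actual, expected, atol=0.035, rtol=0.0)

# Would raise AssertionError: a 0.05 offset exceeds the 0.035 absolute tolerance.
# torch.testing.assert_close(expected + 0.05, expected, atol=0.035, rtol=0.0)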