Skip to content

Commit 2a2cc73

Browse files
ankitade authored and facebook-github-bot committed
Separate out text and image encoders (facebookresearch#115)
Summary: Pull Request resolved: facebookresearch#115 Pull Request resolved: facebookresearch#102 Separate out the encoders into their own module without any logic changes (except fixing 2 minor bugs, see annotations by me) and add tests Test Plan: pytest Reviewed By: ebsmothers Differential Revision: D37407717 Pulled By: ankitade fbshipit-source-id: cd9e120eea4890bb813cb8bbe77577f9e2c77c40
1 parent d16ae39 commit 2a2cc73

File tree

8 files changed

+840
-490
lines changed

8 files changed

+840
-490
lines changed

mypy.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ namespace_packages = True
1414
install_types = True
1515

1616
# TODO (T116951827): Remove after fixing FLAVA type check errors
17-
exclude = models/flava/flava_model.py|modules/losses/flava.py
17+
exclude = models/flava/flava_model.py|models/flava/flava_text_encoder.py|modules/losses/flava.py
1818

1919
[mypy-PIL.*]
2020
ignore_missing_imports = True

test/models/flava/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
File renamed without changes.
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from test.test_utils import assert_expected, set_rng_seed
11+
from torch import nn
12+
from torchmultimodal.models.flava.flava_image_encoder import (
13+
ImageEmbeddings,
14+
ImageTransformer,
15+
)
16+
from torchmultimodal.modules.layers.transformer import FLAVATransformerEncoder
17+
18+
19+
class TestFlavaImageEncoder(unittest.TestCase):
    """Unit tests for the FLAVA image embeddings and image transformer."""

    def setUp(self):
        # Fix the RNG so randomly initialized module weights are reproducible
        # and the hard-coded expected tensors below stay valid.
        set_rng_seed(0)
        self.image_embedding = ImageEmbeddings(
            image_size=2, patch_size=1, hidden_size=2
        )

        transformer_encoder = FLAVATransformerEncoder(
            hidden_size=2,
            num_attention_heads=1,
            num_hidden_layers=1,
            hidden_dropout_prob=0.0,
            intermediate_size=1,
            attention_probs_dropout_prob=0.0,
        )
        self.image_encoder = ImageTransformer(
            embeddings=self.image_embedding,
            encoder=transformer_encoder,
            layernorm=nn.LayerNorm(2),
            pooler=nn.Identity(),
        )

    def test_embedding(self):
        # 2x2 image with patch size 1 -> 4 patches, plus one extra leading
        # position (presumably a class token; its row is all zeros here).
        images = torch.ones(2, 3, 2, 2)
        actual = self.image_embedding(images)

        # Every patch row is identical because the input is constant.
        sample_rows = [[0.0000, 0.0000]] + [[0.0224, 0.0573]] * 4
        expected = torch.Tensor([sample_rows, sample_rows])
        assert_expected(actual, expected, atol=1e-4, rtol=0)

    def test_image_encoder(self):
        images = torch.ones(2, 3, 2, 2)
        out = self.image_encoder(images)

        final_rows = [[-0.0040, 0.0040]] + [[-0.9840, 0.9840]] * 4
        assert_expected(
            out.last_hidden_state,
            torch.Tensor([final_rows, final_rows]),
            atol=1e-4,
            rtol=0,
        )
        # The pooler is nn.Identity, so pooled output equals the hidden state.
        assert_expected(out.pooler_output, out.last_hidden_state)

        # hidden_states holds the embedding output plus one per encoder layer.
        embedding_rows = [[0.0000, 0.0000]] + [[0.0224, 0.0573]] * 4
        layer_one_rows = [[0.0008, 0.0008]] + [[0.0232, 0.0581]] * 4
        assert_expected(
            out.hidden_states,
            (
                torch.Tensor([embedding_rows, embedding_rows]),
                torch.Tensor([layer_one_rows, layer_one_rows]),
            ),
            atol=1e-4,
            rtol=0,
        )

        # Single attention head over 5 near-identical tokens: weights are
        # essentially uniform (~1/5) in every row.
        attn_rows = [[0.2000] * 5] + [
            [0.1999, 0.2000, 0.2000, 0.2000, 0.2000]
        ] * 4
        assert_expected(
            out.attentions,
            (torch.Tensor([[attn_rows], [attn_rows]]),),
            atol=1e-4,
            rtol=0,
        )
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from test.test_utils import assert_expected, set_rng_seed
11+
from torch import nn
12+
from torchmultimodal.models.flava.flava_text_encoder import (
13+
TextEmbeddings,
14+
TextTransformer,
15+
)
16+
from torchmultimodal.modules.layers.transformer import FLAVATransformerEncoder
17+
18+
19+
class TestFlavaTextEncoder(unittest.TestCase):
    """Unit tests for the FLAVA text embeddings and text transformer."""

    def setUp(self):
        # Fix the RNG so randomly initialized module weights are reproducible.
        set_rng_seed(0)
        self.text_embedding = TextEmbeddings(
            hidden_size=2,
            vocab_size=3,
            max_position_embeddings=2,
            hidden_dropout_prob=0.0,
        )
        # Overwrite all three embedding tables with the same tiny fixed
        # weights so the expected outputs below can be hand-computed.
        emb_weights = torch.Tensor([[0, 1], [1, 0], [1, 1]])
        for table_name in (
            "word_embeddings",
            "position_embeddings",
            "token_type_embeddings",
        ):
            setattr(
                self.text_embedding,
                table_name,
                nn.Embedding.from_pretrained(emb_weights),
            )

        self.text_encoder = TextTransformer(
            embeddings=self.text_embedding,
            encoder=FLAVATransformerEncoder(
                hidden_size=2,
                num_attention_heads=1,
                num_hidden_layers=1,
                hidden_dropout_prob=0.0,
                intermediate_size=1,
                attention_probs_dropout_prob=0.0,
            ),
            layernorm=nn.LayerNorm(2),
            pooler=nn.Identity(),
        )

    def test_embedding(self):
        token_ids = torch.IntTensor([[0, 1]])
        actual = self.text_embedding(token_ids)
        assert_expected(actual, torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]]))

    def test_text_transformer(self):
        out = self.text_encoder(torch.IntTensor([[0, 1]]))

        assert_expected(
            out.last_hidden_state, torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]])
        )

        # hidden_states holds the embedding output plus one per encoder layer.
        assert_expected(
            out.hidden_states,
            (
                torch.Tensor([[[1.0000, -1.0000], [-1.0000, 1.0000]]]),
                torch.Tensor([[[1.0008, -0.9994], [-0.9997, 1.0012]]]),
            ),
            atol=1e-4,
            rtol=0.0,
        )

        assert_expected(out.attentions, (torch.Tensor([[[[0, 1.0], [0.0, 1.0]]]]),))

    def test_text_transformer_attn_mask(self):
        token_ids = torch.IntTensor([[0, 1]])
        # Mask out the second token: every query should attend only to
        # the first token (see the attentions assertion below).
        mask = torch.IntTensor([[1, 0]])
        out = self.text_encoder(token_ids, attention_mask=mask)

        assert_expected(
            out.last_hidden_state, torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]])
        )

        assert_expected(
            out.hidden_states,
            (
                torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]]),
                torch.Tensor([[[0.9997, -1.0012], [-1.0008, 0.9994]]]),
            ),
            atol=1e-4,
            rtol=0.0,
        )

        # Identity pooler: pooled output equals the final hidden state.
        assert_expected(out.pooler_output, torch.Tensor([[[1.0, -1.0], [-1.0, 1.0]]]))
        assert_expected(out.attentions, (torch.Tensor([[[[1.0, 0], [1.0, 0]]]]),))

0 commit comments

Comments
 (0)