Commit bbbb91d

Disable weight computation in self-attention for TransformerDecoderLayer (#4398)
Disable weight computation in self-attention for TransformerDecoderLayer in dfine_decoder.py and rtdetr_decoder.py
1 parent d0803ab · commit bbbb91d

File tree

2 files changed: +2, -2 lines
src/otx/algo/detection/heads/dfine_decoder.py (1 addition, 1 deletion)

@@ -117,7 +117,7 @@ def forward(
        # self attention
        q = k = self.with_pos_embed(target, query_pos_embed)

-       target2, _ = self.self_attn(q, k, value=target, attn_mask=attn_mask)
+       target2, _ = self.self_attn(q, k, value=target, attn_mask=attn_mask, need_weights=False)
        target = target + self.dropout1(target2)
        target = self.norm1(target)

src/otx/algo/detection/heads/rtdetr_decoder.py (1 addition, 1 deletion)

@@ -196,7 +196,7 @@ def forward(
        # self attention
        q = k = self.with_pos_embed(tgt, query_pos_embed)

-       tgt2, _ = self.self_attn(q, k, value=tgt, attn_mask=attn_mask)
+       tgt2, _ = self.self_attn(q, k, value=tgt, attn_mask=attn_mask, need_weights=False)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
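
Both call sites discard the second return value of self.self_attn, so the attention-weight matrix was previously computed only to be thrown away. The standalone sketch below illustrates what need_weights=False changes in torch.nn.MultiheadAttention; the dimensions (embed_dim=256, 8 heads, 300 queries) are illustrative and not taken from the OTX configs.

# Standalone sketch (illustrative shapes, not the OTX code): the effect of
# need_weights=False on torch.nn.MultiheadAttention.
import torch
import torch.nn as nn

embed_dim, num_heads = 256, 8
self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

# Dummy decoder queries: (batch, num_queries, embed_dim).
target = torch.randn(2, 300, embed_dim)
q = k = target  # positional-embedding addition omitted for brevity

# Default: the averaged attention weights are materialized and returned.
out_a, weights = self_attn(q, k, value=target)
print(weights.shape)  # torch.Size([2, 300, 300])

# With need_weights=False the weight matrix is skipped and None is returned,
# matching how the decoder layers use the call: target2, _ = self.self_attn(...)
out_b, no_weights = self_attn(q, k, value=target, need_weights=False)
print(no_weights)  # None

# The attention output itself is unchanged (up to floating-point tolerance).
print(torch.allclose(out_a, out_b, atol=1e-5))  # True

Besides skipping the weight-averaging work, need_weights=False is one of the conditions PyTorch checks before taking its fused fast path for MultiheadAttention, so the change can also enable faster attention kernels at inference time.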
