Hello, thank you for your great work!
ml-cvnets/cvnets/layers/linear_attention.py
Lines 163 to 191 in 7771756
```python
def _forward_cross_attn(
    self, x: Tensor, x_prev: Optional[Tensor] = None, *args, **kwargs
) -> Tensor:
    # x --> [B, C, P, N]
    # x_prev = [B, C, P, M]
    batch_size, in_dim, kv_patch_area, kv_num_patches = x.shape
    q_patch_area, q_num_patches = x.shape[-2:]
    assert (
        kv_patch_area == q_patch_area
    ), "The number of pixels in a patch for query and key_value should be the same"

    # compute query, key, and value
    # [B, C, P, M] --> [B, 1 + d, P, M]
    qk = F.conv2d(
        x_prev,
        weight=self.qkv_proj.block.conv.weight[: self.embed_dim + 1, ...],
        bias=self.qkv_proj.block.conv.bias[: self.embed_dim + 1, ...],
    )
    # [B, 1 + d, P, M] --> [B, 1, P, M], [B, d, P, M]
    query, key = torch.split(qk, split_size_or_sections=[1, self.embed_dim], dim=1)
    # [B, C, P, N] --> [B, d, P, N]
    value = F.conv2d(
        x,
        weight=self.qkv_proj.block.conv.weight[self.embed_dim + 1 :, ...],
        bias=self.qkv_proj.block.conv.bias[self.embed_dim + 1 :, ...],
    )
```
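For reference, here is a minimal, self-contained sketch of how these projected tensors are typically combined further down in this function, following the linear-attention formulation described in the MobileViTv2 paper. The function name and exact details are my own assumptions, not the repository code:

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def linear_cross_attention(query: Tensor, key: Tensor, value: Tensor) -> Tensor:
    # query: [B, 1, P, M], key: [B, d, P, M] (both projected from x_prev)
    # value: [B, d, P, N] (projected from x)
    context_scores = F.softmax(query, dim=-1)  # scores over the M patches of x_prev
    # context vector: score-weighted sum of key over patches --> [B, d, P, 1]
    context_vector = (key * context_scores).sum(dim=-1, keepdim=True)
    # broadcast the context vector over the N patches of value
    return F.relu(value) * context_vector  # [B, d, P, N]
```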
Looking at the implementation of the MobileViTv2 linear attention, I see that the query and key are generated from x_prev to compute the context vector, while the value is generated from x.
In the paper, the context vector is analogous to the attention matrix. So when computing the context vector, shouldn't the query come from x_prev and the key come from x, so that the context vector captures the similarity between x_prev and x? A hypothetical sketch of what I mean is shown below.
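To make the question concrete, the variant I am asking about might look like the following. This is purely illustrative and not the repository code; qkv_weight and qkv_bias stand in for the parameters of self.qkv_proj.block.conv, and all sizes are made up:

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes for illustration only
B, C, P, M, N, d = 2, 64, 4, 16, 32, 64
qkv_weight = torch.randn(1 + 2 * d, C, 1, 1)  # stand-in for the qkv_proj conv weight
qkv_bias = torch.randn(1 + 2 * d)             # stand-in for the qkv_proj conv bias
x = torch.randn(B, C, P, N)       # x: [B, C, P, N]
x_prev = torch.randn(B, C, P, M)  # x_prev: [B, C, P, M]

# query from x_prev --> [B, 1, P, M]
query = F.conv2d(x_prev, weight=qkv_weight[:1, ...], bias=qkv_bias[:1])
# key from x --> [B, d, P, N] (instead of from x_prev, as in the current code)
key = F.conv2d(x, weight=qkv_weight[1 : d + 1, ...], bias=qkv_bias[1 : d + 1])
# value from x --> [B, d, P, N] (unchanged)
value = F.conv2d(x, weight=qkv_weight[d + 1 :, ...], bias=qkv_bias[d + 1 :])
```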