
Using Cross Attention in linear attention of MobileViT V2 #109

@jhkwag970

Description

Hello, thank you for your great work!

```python
def _forward_cross_attn(
    self, x: Tensor, x_prev: Optional[Tensor] = None, *args, **kwargs
) -> Tensor:
    # x --> [B, C, P, N]
    # x_prev --> [B, C, P, M]
    batch_size, in_dim, kv_patch_area, kv_num_patches = x.shape
    q_patch_area, q_num_patches = x_prev.shape[-2:]
    assert (
        kv_patch_area == q_patch_area
    ), "The number of pixels in a patch for query and key_value should be the same"

    # compute query, key, and value
    # [B, C, P, M] --> [B, 1 + d, P, M]
    qk = F.conv2d(
        x_prev,
        weight=self.qkv_proj.block.conv.weight[: self.embed_dim + 1, ...],
        bias=self.qkv_proj.block.conv.bias[: self.embed_dim + 1, ...],
    )
    # [B, 1 + d, P, M] --> [B, 1, P, M], [B, d, P, M]
    query, key = torch.split(qk, split_size_or_sections=[1, self.embed_dim], dim=1)

    # [B, C, P, N] --> [B, d, P, N]
    value = F.conv2d(
        x,
        weight=self.qkv_proj.block.conv.weight[self.embed_dim + 1 :, ...],
        bias=self.qkv_proj.block.conv.bias[self.embed_dim + 1 :, ...],
    )
```

Looking at the implementation of MobileViTV2's linear attention, I see that the query and key are generated from x_prev (and used to compute the context vector), while the value is generated from x.
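For context, here is a minimal standalone sketch of how I understand these tensors are then combined. The function name and signature are mine, and the steps paraphrase the context-score/context-vector computation used in the self-attention path of the repo, so treat this as an illustration rather than the exact code:

```python
import torch
import torch.nn.functional as F
from torch import Tensor

def linear_cross_attention_tail(query: Tensor, key: Tensor, value: Tensor) -> Tensor:
    """Combine query/key (derived from x_prev) with value (derived from x).

    query: [B, 1, P, M], key: [B, d, P, M], value: [B, d, P, N]
    """
    # Normalize the scalar query scores over the M tokens of x_prev.
    context_scores = F.softmax(query, dim=-1)                          # [B, 1, P, M]
    # Context vector: score-weighted sum of keys, one d-dim vector per patch position.
    context_vector = (key * context_scores).sum(dim=-1, keepdim=True)  # [B, d, P, 1]
    # Broadcast the context vector over the N tokens of x and gate the values.
    return F.relu(value) * context_vector.expand_as(value)             # [B, d, P, N]

# Example shapes (hypothetical sizes):
B, d, P, M, N = 2, 64, 4, 49, 49
out = linear_cross_attention_tail(
    torch.randn(B, 1, P, M), torch.randn(B, d, P, M), torch.randn(B, d, P, N)
)
print(out.shape)  # torch.Size([2, 64, 4, 49])
```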

In the paper, the context vector is described as analogous to the attention matrix. If that is the case, then when computing the context vector, shouldn't the query come from x_prev and the key from x, so that the context vector captures the similarity between x_prev and x?
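For reference, my reading of the separable self-attention in the paper (notation mine; $W_I$, $W_K$, $W_V$ are the input/key/value projections and $k$ is the number of tokens):

$$
c_s = \operatorname{softmax}(\mathbf{x} W_I), \qquad
c_v = \sum_{i=1}^{k} c_s(i)\,(\mathbf{x} W_K)(i), \qquad
\mathbf{y} = \operatorname{ReLU}(\mathbf{x} W_V) * \operatorname{broadcast}(c_v),
$$

where, in the cross-attention code above, the $\mathbf{x}$ feeding $W_I$ and $W_K$ is x_prev while the $\mathbf{x}$ feeding $W_V$ is x.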
