163 changes: 162 additions & 1 deletion tests/modules/layers/test_position_embedding.py
@@ -7,14 +7,20 @@
import pytest

import torch
from tests.test_utils import assert_expected
from tests.test_utils import assert_expected, set_rng_seed
from torch import nn
from torchmultimodal.modules.layers.position_embedding import (
AlibiPositionEmbeddings,
BroadcastedPositionEmbedding,
SinusoidalPositionEmbeddings,
)


@pytest.fixture(autouse=True)
def random():
set_rng_seed(2023)


class TestBroadcastedPositionEmbedding:
@pytest.fixture(scope="class")
def pos_emb(self):
@@ -112,3 +118,158 @@ def test_forward(self, data, emb):
actual = emb(data)
expected = torch.Size([3, 5])
assert_expected(actual.shape, expected)


class TestAlibiPositionEmbedding:
@pytest.fixture
def max_seq_len(self):
return 16

@pytest.fixture
def embedding_dim(self):
return 32

@pytest.fixture
def num_heads(self):
return 8

@pytest.fixture
def num_heads_non_power_2(self):
return 12

def test_alibi_mask_power_of_2(
self,
max_seq_len,
num_heads,
):
alibi_class = AlibiPositionEmbeddings(
max_seq_len=max_seq_len, num_heads=num_heads
)
base_mask = alibi_class.get_attention_mask(max_seq_len)

# verify mask shape
expected_shape = torch.Size((num_heads, max_seq_len, max_seq_len))
assert_expected(base_mask.shape, expected_shape)

# verify alibi mask components
expected_last_head_row = torch.tensor(
[
-0.0586,
-0.0547,
-0.0508,
-0.0469,
-0.0430,
-0.0391,
-0.0352,
-0.0312,
-0.0273,
-0.0234,
-0.0195,
-0.0156,
-0.0117,
-0.0078,
-0.0039,
0.0000,
]
)

expected_first_head_first_row_first_entry = torch.tensor(
0.0000,
)

assert_expected(
base_mask[0][0][0],
expected_first_head_first_row_first_entry,
rtol=0,
atol=1e-4,
)

assert_expected(
base_mask[num_heads - 1][max_seq_len - 1],
expected_last_head_row,
rtol=0,
atol=1e-4,
)

def test_alibi_mask_non_power_of_2(
self,
max_seq_len,
num_heads_non_power_2,
):
alibi_class = AlibiPositionEmbeddings(
max_seq_len=max_seq_len, num_heads=num_heads_non_power_2
)
base_mask = alibi_class.get_attention_mask(max_seq_len)

# verify mask shape
expected_shape = torch.Size((num_heads_non_power_2, max_seq_len, max_seq_len))
assert_expected(base_mask.shape, expected_shape)

# verify alibi mask components
expected_second_head_last_row = torch.tensor(
[
-7.5000,
-7.0000,
-6.5000,
-6.0000,
-5.5000,
-5.0000,
-4.5000,
-4.0000,
-3.5000,
-3.0000,
-2.5000,
-2.0000,
-1.5000,
-1.0000,
-0.5000,
0.0000,
]
)

expected_third_head_last_row = torch.tensor(
[
-5.3033,
-4.9497,
-4.5962,
-4.2426,
-3.8891,
-3.5355,
-3.1820,
-2.8284,
-2.4749,
-2.1213,
-1.7678,
-1.4142,
-1.0607,
-0.7071,
-0.3536,
0.0000,
]
)

expected_first_head_first_row_first_entry = torch.tensor(
0.0000,
)

assert_expected(
base_mask[0][0][0],
expected_first_head_first_row_first_entry,
rtol=0,
atol=1e-4,
)

# verify 2nd and 3rd head to confirm non power 2 symmetry of slopes
assert_expected(
base_mask[1][max_seq_len - 1],
expected_second_head_last_row,
rtol=0,
atol=1e-4,
)

assert_expected(
base_mask[2][max_seq_len - 1],
expected_third_head_last_row,
rtol=0,
atol=1e-4,
)
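For context, a minimal usage sketch (illustrative only, not part of this diff) of how the mask returned by get_attention_mask is meant to be consumed, following the docstring example in the implementation file below; the batch dimension is omitted and handled by broadcasting:

```python
import math

import torch
from torchmultimodal.modules.layers.position_embedding import AlibiPositionEmbeddings

num_heads, head_dim, seq_len = 8, 32, 16
alibi = AlibiPositionEmbeddings(max_seq_len=seq_len, num_heads=num_heads)

q = torch.randn(num_heads, seq_len, head_dim)
k = torch.randn(num_heads, seq_len, head_dim)

attn = q @ k.transpose(-2, -1)             # raw attention scores
attn *= 1.0 / math.sqrt(k.size(-1))        # scale first, per the docstring
attn += alibi.get_attention_mask(seq_len)  # add causal + ALiBi bias per head
weights = attn.softmax(dim=-1)             # (num_heads, seq_len, seq_len)
```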
108 changes: 107 additions & 1 deletion torchmultimodal/modules/layers/position_embedding.py
@@ -5,7 +5,8 @@
# LICENSE file in the root directory of this source tree.

import itertools
from typing import Tuple
import math
from typing import List, Tuple

import torch
from torch import nn, Tensor
@@ -169,3 +170,108 @@ def forward(self, t: Tensor) -> Tensor:
if self.embed_dim % 2 == 1:
embeddings = nn.functional.pad(embeddings, (0, 1))
return embeddings


class AlibiPositionEmbeddings(nn.Module):
Contributor: High-level question: if we're not using the module's forward and are mostly using class/static methods, why not just define this as a function? Offhand I don't see a reason why this needs to be stateful (it's very possible I'm missing something, though).
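To illustrate the question, a stateless variant might look roughly like this (hypothetical sketch, not part of the PR; it reuses the class/static methods added below in this diff, and the function name is made up):

```python
import torch
from torchmultimodal.modules.layers.position_embedding import AlibiPositionEmbeddings

def alibi_attention_mask(max_seq_len: int, num_heads: int) -> torch.Tensor:
    # causal mask is (seq, seq); the slope-scaled distance bias is (heads, seq, seq),
    # so broadcasting yields a (num_heads, max_seq_len, max_seq_len) additive mask
    causal = torch.triu(torch.ones(max_seq_len, max_seq_len) * float("-inf"), diagonal=1)
    distance = -torch.abs(torch.arange(max_seq_len) - torch.arange(max_seq_len).view(-1, 1))
    slopes = torch.tensor(AlibiPositionEmbeddings.get_slopes(num_heads)).view(-1, 1, 1)
    return causal + distance * slopes
```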

"""Attention with Linear Biases (ALiBi)

Softmax(q_i @ K^T + m * [-(i - 1), ..., -2, -1, 0]),
where m is a fixed slope specific to each head

as proposed in:
https://arxiv.org/abs/2108.12409
Train Short, Test Long: Attention with Linear Biases
Enables Input Length Extrapolation

derived from the codebase of Ofir Press (ALiBi author):
https://github.com/ofirpress/attention_with_linear_biases

"""

def __init__(
self,
max_seq_len: int,
num_heads: int,
) -> None:
"""recommended usage: create alibi mask before transformer block loop and integrate
Contributor: Yeah, this is a bit tricky. Kind of similar to RoPE embeddings: integrating this properly will necessitate rethinking some aspects of our transformer implementation. For instance, one assumption here seems to be that our transformer's mask should be float dtype and not bool.

Alibi should be applied after the sqrt scaling of the attention values

Example:
before Transformer block loop:
from alibi_embeddings import AlibiPE
self.alibi = AlibiPE(config.max_seq_len, config.num_heads)
pass a reference to the alibi class to each transformer layer
then in forward of transformer layer:
alibi_mask = self.alibi.get_attention_mask(N) # N = seq length of this batch
...
attn = q @ k.transpose(-2, -1)
attn *= 1.0 / math.sqrt(k.size(-1))
attn += alibi_mask

"""
super().__init__()

self.num_heads = num_heads
self.max_seq_len = max_seq_len

self.causal_mask = self.build_causal_attention_mask(
self.max_seq_len, self.num_heads
)
self.alibi_mask_base = self.build_alibi_mask(self.max_seq_len, self.num_heads)
self.decoder_mask = self.causal_mask + self.alibi_mask_base
self.register_buffer("alibi_mask", self.decoder_mask, persistent=False)

def get_attention_mask(self, curr_seq_len: int) -> torch.Tensor:
"""returns the alibi mask, clipped to the current batch seq len"""
return self.alibi_mask[..., :curr_seq_len, :curr_seq_len]

@classmethod
def build_causal_attention_mask(cls, seq_len: int, num_heads: int) -> torch.Tensor:
Contributor: FWIW, there is also the get_causal_attention_mask utility (you may even be able to use get_extended_attention_mask from the same file in lieu of the repeat, though it does broadcast to an extra dim for batch size).
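A small sketch of the kind of adaptation that might be needed if a shared utility returns a boolean keep/ignore mask (an assumption; the exact shape and dtype of those utilities are not shown in this diff), since this class builds an additive float mask directly:

```python
import torch

seq_len, num_heads = 4, 2
# assumed boolean causal mask: True where attention is allowed
bool_causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
# convert to the additive float form used here: 0 where allowed, -inf where masked
additive = torch.zeros(seq_len, seq_len).masked_fill(~bool_causal, float("-inf"))
additive = additive.repeat(num_heads, 1, 1)  # (num_heads, seq_len, seq_len)
```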

"""builds a generic causal attention mask"""
causal_mask = torch.triu(
torch.ones(seq_len, seq_len) * float("-inf"), diagonal=1
)
attn_mask = causal_mask.repeat(num_heads, 1, 1)
return attn_mask

@classmethod
def build_alibi_mask(cls, seq_len: int, num_heads: int) -> torch.Tensor:
"""generate the alibi mask by computing a distance bias matrix multiplied by each head's m (slope)"""
distance_bias_matrix = -torch.abs(
torch.arange(seq_len) - torch.arange(seq_len).view(-1, 1)
)
slope_per_head = Tensor(cls.get_slopes(num_heads)).view(-1, 1, 1)
alibi_mask = distance_bias_matrix * slope_per_head
return alibi_mask

@staticmethod
def get_slopes(num_heads: int) -> List[float]:
"""for n heads, a range from (0,1) and is the geometric sequence
that starts at 2^(-8/n) and uses this same value as its ratio
Contributor: Thank you for explaining/documenting the magic numbers 🙂

example: num_heads = 4
result: [0.25, 0.0625, 0.015625, 0.00390625]

"""

def get_slopes_power_of_2(n: int) -> List[float]:
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]

if math.log2(num_heads).is_integer():
return get_slopes_power_of_2(num_heads)

# The paper authors note that they only trained models that have 2^a heads for some a.
# This construction has beneficial properties that only hold when the number of heads is a power of 2.
Contributor: Do you know what these properties are? TBH I am confused by this, because even if n is a power of 2, some of the ratios will not be rational for n > 8.


# Using the closest power of 2 below num_heads is the workaround for when the number of heads is not a power of 2.
Contributor: Their method of interpolating is a bit unusual. Maybe explicitly explain that for $num\_heads = 2^N + k$ they are splicing the geometric series with ratio $2^{-\frac{8}{N}}$ with the first $2k$ elements of the geometric series with ratio $2^{-\frac{8}{N+1}}$ (assuming I am even understanding it correctly 😅).
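A quick numerical illustration of that splicing (not part of the diff; it assumes the module from this PR is importable) for num_heads = 12, where the closest power of 2 below is 8:

```python
from torchmultimodal.modules.layers.position_embedding import AlibiPositionEmbeddings

slopes = AlibiPositionEmbeddings.get_slopes(12)
print([round(s, 4) for s in slopes])
# [0.7071, 0.5, 0.3536, 0.25, 0.1768, 0.125, 0.0884, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039]
# The first eight entries interleave the ratio-2^(-0.5) series (from 16 heads) with the
# ratio-2^(-1) series (from 8 heads); the tail continues the 8-head series (2^-5 ... 2^-8).
# slopes[1] == 0.5 and slopes[2] ~= 0.3536 are the per-head factors behind
# expected_second_head_last_row and expected_third_head_last_row in the test diff above.
```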

# Slopes are interleaved so the result is a single descending sequence.

closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))

a = get_slopes_power_of_2(closest_power_of_2)
b = get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
: num_heads - closest_power_of_2
]
return [x for pair in zip(b, a) for x in pair] + a[len(b) :]
Contributor: IMO this is hard to parse. I agree with @daviswer's comment about returning values in order, but could we just do sorted(a + b)? (Maybe I'm missing a tricky case; if so, a comment explaining it would suffice instead.)
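A standalone check (illustrative only, reproducing the slope construction from this diff) suggesting the interleaved result equals sorted(a + b) in descending order for a few non-power-of-2 head counts:

```python
import math

def get_slopes_power_of_2(n: int):
    start = 2 ** (-(2 ** -(math.log2(n) - 3)))  # 2^(-8/n)
    return [start * start**i for i in range(n)]

for num_heads in (3, 5, 6, 12, 20):  # arbitrary non-power-of-2 head counts
    closest = 2 ** math.floor(math.log2(num_heads))
    a = get_slopes_power_of_2(closest)
    b = get_slopes_power_of_2(2 * closest)[0::2][: num_heads - closest]
    spliced = [x for pair in zip(b, a) for x in pair] + a[len(b):]
    assert spliced == sorted(a + b, reverse=True), num_heads
```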