
Commit 2912c8c

Add an option where the boundary 'words' of a sentence are used as the sinks in the constituency relative attention module
Add a unit test which checks that it still works with these flags, although we have not yet checked that it is doing anything useful
Parent commit: 5cb52f3

File tree

stanza/models/common/relative_attn.py
stanza/models/constituency/lstm_model.py
stanza/models/constituency_parser.py
stanza/tests/constituency/test_lstm_model.py

4 files changed: +24 -4 lines changed

stanza/models/common/relative_attn.py

Lines changed: 5 additions & 2 deletions
@@ -47,7 +47,7 @@ def __init__(self, d_model, num_heads, window=8, dropout=0.2, reverse=False, d_o
 
         self.reverse = reverse
 
-    def forward(self, x):
+    def forward(self, x, sink=None):
         # x.shape == (batch_size, seq_len, d_model)
         batch_size, seq_len, d_model = x.shape
         if d_model != self.d_model:
@@ -66,7 +66,10 @@ def forward(self, x):
         # could keep a parameter to train sinks, but as it turns out,
         # the position vectors just overlap that parameter space anyway
         # generally the model trains the sinks to zero if we do that
-        sink = torch.zeros((batch_size, self.num_sinks, d_model), dtype=x.dtype, device=x.device)
+        if sink is None:
+            sink = torch.zeros((batch_size, self.num_sinks, d_model), dtype=x.dtype, device=x.device)
+        else:
+            sink = sink.expand(batch_size, self.num_sinks, d_model)
         x = torch.cat((sink, x), axis=1)
 
         # k.shape = (batch_size, num_heads, d_head, seq_len + num_sinks)
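A minimal sketch of the new sink handling, with illustrative sizes (none of the variable names below come from the commit): when a word vector is supplied as sink, it is broadcast to (batch_size, num_sinks, d_model) with expand and prepended to the sequence, in the same place the zero-valued sinks went before.

import torch

batch_size, seq_len, d_model, num_sinks = 1, 6, 8, 1
x = torch.randn(batch_size, seq_len, d_model)

# old behaviour (still the default when sink is None): zero-valued sinks
zero_sink = torch.zeros((batch_size, num_sinks, d_model), dtype=x.dtype, device=x.device)

# new behaviour: a single word vector, e.g. the first word of the sentence,
# broadcast up to the sink shape
endpoint = x[0][0]                                       # shape (d_model,)
word_sink = endpoint.expand(batch_size, num_sinks, d_model)

# either way the sinks are prepended, so the keys/values inside the module
# cover seq_len + num_sinks positions
x_with_sinks = torch.cat((word_sink, x), axis=1)
print(x_with_sinks.shape)                                # torch.Size([1, 7, 8])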

stanza/models/constituency/lstm_model.py

Lines changed: 8 additions & 2 deletions
@@ -824,9 +824,15 @@ def map_word(word):
         rattn_inputs = [[x] for x in all_word_inputs]
 
         if self.rel_attn_forward is not None:
-            rattn_inputs = [x + [self.rel_attn_forward(x[0].unsqueeze(0)).squeeze(0)] for x in rattn_inputs]
+            if self.args['rattn_use_endpoint_sinks']:
+                rattn_inputs = [x + [self.rel_attn_forward(x[0].unsqueeze(0), x[0][0]).squeeze(0)] for x in rattn_inputs]
+            else:
+                rattn_inputs = [x + [self.rel_attn_forward(x[0].unsqueeze(0)).squeeze(0)] for x in rattn_inputs]
         if self.rel_attn_reverse is not None:
-            rattn_inputs = [x + [self.rel_attn_reverse(x[0].unsqueeze(0)).squeeze(0)] for x in rattn_inputs]
+            if self.args['rattn_use_endpoint_sinks']:
+                rattn_inputs = [x + [self.rel_attn_reverse(x[0].unsqueeze(0), x[0][-1]).squeeze(0)] for x in rattn_inputs]
+            else:
+                rattn_inputs = [x + [self.rel_attn_reverse(x[0].unsqueeze(0)).squeeze(0)] for x in rattn_inputs]
 
         if self.args['rattn_cat']:
             all_word_inputs = [torch.cat(x, axis=1) for x in rattn_inputs]
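In this hunk each element of rattn_inputs is a list whose first entry x[0] is the word-representation tensor for one sentence, so x[0][0] is the first word's vector and x[0][-1] is the last word's: the forward-direction attention receives the sentence-initial word as its sink and the reverse-direction attention receives the sentence-final word. A small illustration with made-up shapes (nothing below is copied from the model):

import torch

# hypothetical sentence of 5 words with 8-dimensional representations;
# in the model this tensor plays the role of x[0] for one sentence
words = torch.randn(5, 8)

forward_sink = words[0]     # first word, used as the sink for rel_attn_forward
reverse_sink = words[-1]    # last word, used as the sink for rel_attn_reverse

# the model's calls then look roughly like
#   self.rel_attn_forward(words.unsqueeze(0), forward_sink)
#   self.rel_attn_reverse(words.unsqueeze(0), reverse_sink)
# where unsqueeze(0) adds the batch dimension expected by the module
print(forward_sink.shape, reverse_sink.shape)   # torch.Size([8]) torch.Size([8])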

stanza/models/constituency_parser.py

Lines changed: 1 addition & 0 deletions
@@ -737,6 +737,7 @@ def build_argparse():
     parser.add_argument('--rattn_cat', default=True, action='store_true', help='Stack the rattn layers instead of adding them')
     parser.add_argument('--rattn_dim', default=200, type=int, help='Dimension of the rattn output when cat')
     parser.add_argument('--rattn_sinks', default=0, type=int, help='Number of attention sink tokens to learn')
+    parser.add_argument('--rattn_use_endpoint_sinks', default=False, action='store_true', help='Use the endpoints of the sentences as sinks')
 
     parser.add_argument('--log_norms', default=False, action='store_true', help='Log the parameters norms while training. A very noisy option')
     parser.add_argument('--log_shapes', default=False, action='store_true', help='Log the parameters shapes at the beginning')
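The flag is off by default; the new tests combine it with --rattn_sinks 1 or 2. A stand-alone sketch of just the two relevant options (a minimal stand-in for build_argparse, not the real parser):

import argparse

# reproduces only the pre-existing --rattn_sinks option and the new flag
# it interacts with, so it can be run without a stanza checkout
parser = argparse.ArgumentParser()
parser.add_argument('--rattn_sinks', default=0, type=int,
                    help='Number of attention sink tokens to learn')
parser.add_argument('--rattn_use_endpoint_sinks', default=False, action='store_true',
                    help='Use the endpoints of the sentences as sinks')

args = parser.parse_args(['--rattn_sinks', '1', '--rattn_use_endpoint_sinks'])
print(args.rattn_sinks, args.rattn_use_endpoint_sinks)   # 1 True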

stanza/tests/constituency/test_lstm_model.py

Lines changed: 10 additions & 0 deletions
@@ -455,6 +455,16 @@ def test_relative_attention_cat_sinks(pretrain_file):
     model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_cat', '--rattn_sinks', '2')
     run_forward_checks(model)
 
+def test_relative_attention_endpoint_sinks(pretrain_file):
+    model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_use_endpoint_sinks', '--rattn_window', '2', '--rattn_sinks', '1')
+    run_forward_checks(model)
+    model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_use_endpoint_sinks', '--rattn_sinks', '1')
+    run_forward_checks(model)
+    model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_use_endpoint_sinks', '--rattn_window', '2', '--rattn_sinks', '2')
+    run_forward_checks(model)
+    model = build_model(pretrain_file, '--no_use_lattn', '--use_rattn', '--rattn_heads', '10', '--rattn_use_endpoint_sinks', '--rattn_sinks', '2')
+    run_forward_checks(model)
+
 def test_lstm_tree_forward(pretrain_file):
     """
     Test the LSTM_TREE forward pass
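The new test runs the forward pass under four configurations: one and two sinks, each with and without an explicit --rattn_window of 2, so both the windowed and default-window settings see the endpoint sinks. As the commit message notes, this only checks that the option runs, not that it improves anything.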
