This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 254bdfc

Niki Parmar authored and Ryan Sepassi committed

Move attention util functions

PiperOrigin-RevId: 179499133

1 parent ae62ed6

File tree

1 file changed: +20 -1 lines changed

tensor2tensor/visualization/attention.py

Lines changed: 20 additions & 1 deletion
@@ -45,8 +45,10 @@
 
 
 def show(inp_text, out_text, enc_atts, dec_atts, encdec_atts):
+  enc_att, dec_att, encdec_att = (resize(enc_atts),
+                                  resize(dec_atts), resize(encdec_atts))
   attention = _get_attention(
-      inp_text, out_text, enc_atts, dec_atts, encdec_atts)
+      inp_text, out_text, enc_att, dec_att, encdec_att)
   att_json = json.dumps(attention)
   _show_attention(att_json)
 
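For context, a minimal sketch of how the updated show() might be driven from a notebook, assuming each *_atts argument is a list of per-layer attention arrays of shape (num_heads, seq_len, seq_len); the token lists and random tensors below are illustrative stand-ins, not part of the commit:

import numpy as np

# Hypothetical inputs: 2 layers, 4 heads, sequences of length 4.
inp_text = ['the', 'cat', 'sat', '<EOS>']
out_text = ['le', 'chat', 'assis', '<EOS>']
enc_atts = [np.random.rand(4, 4, 4) for _ in range(2)]
dec_atts = [np.random.rand(4, 4, 4) for _ in range(2)]
encdec_atts = [np.random.rand(4, 4, 4) for _ in range(2)]

# show() now normalizes each matrix via resize() before building
# the JSON handed to the d3 visualization.
show(inp_text, out_text, enc_atts, dec_atts, encdec_atts)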

@@ -57,6 +59,23 @@ def _show_attention(att_json):
   display.display(display.Javascript(vis_js))
 
 
+def resize(att_mat, max_length=30):
+  """Normalize attention matrices and reshape as necessary."""
+  layer_mats = []
+  for att in att_mat:
+    # Sum across different heads.
+    att = att[:, :max_length, :max_length]
+    row_sums = np.sum(att, axis=0)
+    # Normalize
+    layer_mat = att / row_sums[np.newaxis, :]
+    lsh = layer_mat.shape
+    # Add extra batch dim for viz code to work.
+    if len(np.shape(lsh)) == 3:
+      layer_mat = np.reshape(layer_mat, (1, lsh[0], lsh[1], lsh[2]))
+    layer_mats.append(layer_mat)
+  return layer_mats
+
+
 def _get_attention(inp_text, out_text, enc_atts, dec_atts, encdec_atts):
   """Compute representation of the attention ready for the d3 visualization.
 
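One detail worth flagging in resize(): lsh = layer_mat.shape is already a tuple, so len(np.shape(lsh)) always evaluates to 1 and the batch-dimension branch never fires; the test was presumably meant to be len(lsh) == 3. Below is a self-contained sketch of the apparent intent, with the reshape expressed via np.newaxis as a stylistic choice (this is a reading of the code, not a fix shipped in the commit):

import numpy as np

def resize(att_mat, max_length=30):
  """Normalize per-layer attention matrices and add a batch dim."""
  layer_mats = []
  for att in att_mat:
    # Truncate each matrix to at most max_length x max_length positions.
    att = att[:, :max_length, :max_length]
    # Sum across the head axis, then divide so each head is normalized
    # against the combined attention mass of all heads.
    head_sums = np.sum(att, axis=0)
    layer_mat = att / head_sums[np.newaxis, :]
    # Add an extra batch dim so the viz code sees a 4-D array.
    if layer_mat.ndim == 3:  # i.e. len(lsh) == 3 in the commit's terms
      layer_mat = layer_mat[np.newaxis, ...]
    layer_mats.append(layer_mat)
  return layer_mats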
