Commit 176148c

T2T Team authored and copybara-github committed
Fix attention rng mismatch between forward and reverse direction
PiperOrigin-RevId: 272707157
1 parent 9f29518 commit 176148c
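
In a reversible model the reverse pass must replay the forward computation exactly, so any rng-driven step inside attention has to see the same key in both directions. The forward pass hands each sublayer a split of the layer rng, while the reverse path was working from a different key; that is the mismatch this commit fixes. A minimal sketch of the underlying behaviour, using jax.random directly (the trax backend wraps the same primitives); the seed and split count are illustrative only:

# Sketch: a key obtained by splitting is not the parent key, so handing the
# unsplit rng to the reverse path cannot reproduce the forward randomness.
import jax

rng = jax.random.PRNGKey(0)
sub_rngs = jax.random.split(rng, 3)          # what the forward pass hands sublayers 0..2
assert not (sub_rngs[0] == rng).all()        # sublayer 0's key differs from the parent key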

4 files changed: +72, -7 lines


tensor2tensor/trax/layers/attention.py

Lines changed: 2 additions & 1 deletion
@@ -344,7 +344,8 @@ def new_params_and_state(self, input_shape, input_dtype, rng):
 class BaseCausalAttention(base.Layer):
   """Base class for variants of causal self-attention."""
 
-  def __init__(self):
+  def __init__(self, mode='train'):
+    del mode
     super(BaseCausalAttention, self).__init__(n_inputs=3)
 
   def forward(self, inputs, params=(), state=(), rng=None, **kwargs):
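
Accepting and discarding `mode` at the base class keeps the constructor signature uniform across causal-attention variants, so code that builds attention from a class reference (as the new test does via `attention_type=...`) can always pass `mode`. A small sketch under that reading; the subclass name is hypothetical:

# Sketch (hypothetical subclass): a variant that ignores the train/eval
# distinction can still be constructed with mode=..., since the base
# __init__ simply discards it.
from tensor2tensor.trax import layers as tl

class MyDeterministicAttention(tl.BaseCausalAttention):
  """A variant that behaves identically in train and eval."""

attention = MyDeterministicAttention(mode='eval')   # accepted, then discarded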

tensor2tensor/trax/layers/reversible.py

Lines changed: 4 additions & 4 deletions
@@ -101,8 +101,8 @@ def reverse(self, output, params=(), state=(), **kwargs):
     rngs = backend.random.split(rng, self._n_layers)
 
     layer_val = output
-    for layer, p, s, rng in reversed(zip(self.sublayers,
-                                         params, state, rngs)):
+    for layer, p, s, rng in reversed(list(zip(self.sublayers,
+                                              params, state, rngs))):
       layer_val = layer.reverse(layer_val, p, s, rng=rng, **kwargs)
 
     return layer_val
@@ -116,8 +116,8 @@ def reverse_and_grad(self, output, ct, params=(), state=(), **kwargs):
     layer_val = output
     layer_ct = ct
     params_ct = []
-    for layer, p, s, rng in reversed(zip(self.sublayers,
-                                         params, state, rngs)):
+    for layer, p, s, rng in reversed(list(zip(self.sublayers,
+                                              params, state, rngs))):
       layer_val, layer_ct = layer.reverse_and_grad(
           layer_val, layer_ct, p, s, rng=rng, **kwargs)
       layer_ct, p_ct = layer_ct
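
The `list(...)` wrapper matters because under Python 3 `zip()` returns an iterator, which `reversed()` refuses; materializing the tuples first lets the reverse pass walk sublayers, params, state, and rngs in reverse order. A minimal standalone sketch with placeholder values, not real trax objects:

# Sketch: reversed() needs a sequence, and zip() is only an iterator on Python 3.
layers, params, rngs = ['a', 'b'], [1, 2], ['r0', 'r1']
try:
  reversed(zip(layers, params, rngs))          # TypeError: zip object is not a sequence
except TypeError:
  pass
for layer, p, rng in reversed(list(zip(layers, params, rngs))):
  print(layer, p, rng)                         # visits ('b', 2, 'r1'), then ('a', 1, 'r0')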

tensor2tensor/trax/models/research/reformer.py

Lines changed: 6 additions & 2 deletions
@@ -254,14 +254,18 @@ def __init__(self, attention):
     super(ApplyAttentionWrapper, self).__init__(attention, [], [])
     self.attention = attention
 
-  def forward_and_backward(self, inputs, ct, **kwargs):
+  def forward_and_backward(self, inputs, ct, rng=None, **kwargs):
     # Simultaneous forward pass and backprop through the attention mechanism.
     qkv = inputs[:3]
     passthrough = inputs[3:]
     out_ct = ct[0]
     passthrough_ct = ct[1:]
+    if rng is not None:
+      # Adjust RNG to match the forward pass.
+      rng = backend.random.split(rng, self._n_layers)[0]
 
-    out, qkv_ct = self.attention.forward_and_backward(qkv, out_ct, **kwargs)
+    out, qkv_ct = self.attention.forward_and_backward(
+        qkv, out_ct, rng=rng, **kwargs)
     return (out,) + passthrough, qkv_ct + passthrough_ct
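
Splitting the wrapper rng and taking index 0 recomputes the key that the attention sublayer received on the forward pass, since `backend.random.split` is deterministic and the `[0]` index suggests attention is the first sublayer. A sketch of that property using jax.random directly; `n_layers = 3` is an illustrative value, not taken from the model:

# Sketch: rng splits are deterministic, so re-splitting the wrapper rng on the
# way back recovers exactly the key the forward pass consumed.
import jax

wrapper_rng = jax.random.PRNGKey(42)
n_layers = 3                                                # illustrative value
forward_key = jax.random.split(wrapper_rng, n_layers)[0]    # key used going forward
reverse_key = jax.random.split(wrapper_rng, n_layers)[0]    # key recomputed in reverse
assert (forward_key == reverse_key).all()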

tensor2tensor/trax/models/research/reformer_test.py

Lines changed: 60 additions & 0 deletions
@@ -21,10 +21,43 @@
 
 from absl.testing import absltest
 from absl.testing import parameterized
+import jax
+import numpy as onp
+
+from tensor2tensor.trax import backend
 from tensor2tensor.trax import layers as tl
+from tensor2tensor.trax.backend import numpy as np
 from tensor2tensor.trax.models.research import reformer
 
 
+class PoisonOnRNGMismatchAttention(tl.BaseCausalAttention):
+  """Fills gradients with NaNs if reverse rng does not match forward rng."""
+
+  # pylint: disable=protected-access
+  def forward_and_backward(self, inputs, ct, rng=None, **kwargs):
+    assert backend.get_name() == 'jax', (
+        'JAX backend is required to use forward_and_backward.')
+
+    if ct is not None and tl.Layer._STASH_OUT is not None:
+      recovered_rng = tl.Layer._STASH_OUT.pop(self)
+      is_same = (rng[0] == recovered_rng[0]) & (rng[1] == recovered_rng[1])
+      is_same = is_same.astype(np.float32)
+      # Divides by zero if rngs are not the same, which results in NaNs.
+      inputs = (inputs[0] / is_same, inputs[1] / is_same, inputs[2] / is_same)
+
+    def _do_forward(x):  # pylint: disable=invalid-name
+      res, _ = self.forward(x, rng=rng, **kwargs)
+      return res
+    output, vjpfun = jax.vjp(_do_forward, inputs)
+    return output, vjpfun(ct)[0]
+
+  def forward(self, inputs, params=(), state=(), rng=None, **kwargs):
+    if tl.Layer._STASH_IN is not None:
+      tl.Layer._STASH_IN[self] = rng
+    return inputs[2], state
+  # pylint: enable=protected-access
+
+
 class ReformerTest(parameterized.TestCase):
 
   def test_reformer_lm_forward_shape(self):
@@ -39,6 +72,33 @@ def test_reformer_lm_forward_shape(self):
         model, tuple(input_shape), integer_inputs=True)
     self.assertEqual(((1, 8, 16), (1, 8, 16)), final_shape)
 
+  def test_reformer_rng_consistency(self):
+    with backend.use_backend('jax'):
+      vocab_size = 16
+      batch_size = 1
+      input_shape = ((batch_size, 8), (batch_size, 8))
+      model = reformer.ReformerLM(
+          vocab_size, d_model=32, d_ff=64,
+          d_attention_key=16, d_attention_value=16, n_layers=1, n_heads=2,
+          max_len=16, n_chunks=2, n_attention_chunks=1, mode='train',
+          attention_type=PoisonOnRNGMismatchAttention)
+
+      rng = backend.random.get_prng(0)
+      params, state = model.initialize_once(
+          input_shape, (np.int32, np.int32), rng)
+
+      def dummy_loss_fn(params):
+        inputs = (np.zeros(input_shape[0], dtype=np.int32),) * 2
+        output = model(inputs, params=params, state=state, rng=rng)
+        dummy_loss = backend.numpy.sum(output[0])
+        return dummy_loss
+
+      grad_fn = backend.grad(dummy_loss_fn)
+      grads = grad_fn(params)
+      # PoisonOnRNGMismatchAttention uses NaNs to signal an rng mismatch.
+      for grad in jax.tree_util.tree_leaves(grads):
+        assert onp.all(onp.isfinite(grad))
+
 
 if __name__ == '__main__':
   absltest.main()
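
The test works by poisoning the attention inputs with a divide-by-zero whenever the stashed forward rng differs from the rng seen in the reverse direction, then asserting that every gradient stays finite. A standalone sketch of that poisoning trick in plain NumPy, with no trax involved:

# Sketch: dividing by a 0/1 "rngs match" indicator makes values non-finite on a
# mismatch, so checking that gradients are finite is enough to detect the bug.
import numpy as onp

with onp.errstate(divide='ignore'):
  is_same = onp.float32(0.0)                 # pretend the reverse rng did not match
  poisoned = onp.float32(1.0) / is_same      # inf (a 0/0 case would give NaN)
print(onp.isfinite(poisoned))                # False, which trips the test's assert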
