Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 95aeb11

Browse files
author
Ryan Sepassi
committed
Make SRU code Py3 compatible
PiperOrigin-RevId: 192819555
1 parent 6f1152c commit 95aeb11

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tensor2tensor/layers/common_layers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1951,7 +1951,8 @@ def sru(x, num_layers=2,
19511951
x = tf.transpose(x, [1, 0, 2]) # Scan assumes time on axis 0.
19521952
initial_state = initial_state or tf.zeros([x_shape[0], x_shape[-1]])
19531953
# SRU state manipulation function.
1954-
def next_state(cur_state, (cur_x_times_one_minus_f, cur_f)):
1954+
def next_state(cur_state, args_tup):
1955+
cur_x_times_one_minus_f, cur_f = args_tup
19551956
return cur_f * cur_state + cur_x_times_one_minus_f
19561957
# Calculate SRU on each layer.
19571958
for i in xrange(num_layers):

0 commit comments

Comments
 (0)