
Commit da0bc49

Lukasz Kaiser authored and Ryan Sepassi committed
internal merge
PiperOrigin-RevId: 184210759
1 parent f9c859a commit da0bc49

2 files changed: +52 −52 lines changed

tensor2tensor/models/cycle_gan.py

Lines changed: 51 additions & 51 deletions
@@ -29,29 +29,39 @@
 import tensorflow as tf
 
 
-def reconstruct_loss(x, gt, hparams, reuse=None):
-  pred = tf.layers.dense(x, hparams.vocab_size, name="softmax", reuse=reuse)
-  xent, w = common_layers.padded_cross_entropy(pred, gt, 0.0)
-  return xent / w
-
-
 def discriminator(x, compress, hparams, name, reuse=None):
   with tf.variable_scope(name, reuse=reuse):
     x = tf.stop_gradient(2 * x) - x  # Reverse gradient.
     if compress:
-      x = transformer_vae.compress(x, None, hparams, "compress")
+      x = transformer_vae.compress(x, None, False, hparams, "compress")
     else:
-      x = transformer_vae.residual_conv(x, 1, hparams, "compress_rc")
+      x = transformer_vae.residual_conv(x, 1, 3, hparams, "compress_rc")
     y = tf.reduce_mean(x, axis=1)
     return tf.tanh(tf.layers.dense(y, 1, name="reduce"))
 
 
-def discriminate_loss(x, y, compress, hparams, name):
+def generator(x, hparams, name, reuse=False):
+  with tf.variable_scope(name, reuse=reuse):
+    return transformer_vae.residual_conv(x, 1, 3, hparams, "generator")
+
+
+def lossfn(real_input, fake_input, compress, hparams, lsgan, name):
+  eps = 1e-12
   with tf.variable_scope(name):
-    d1 = discriminator(x, compress, hparams, "discriminator")
-    d2 = discriminator(y, compress, hparams, "discriminator", reuse=True)
-    dloss = tf.reduce_mean(tf.abs(d1 - d2))
-    return - dloss
+    d1 = discriminator(real_input, compress, hparams, "discriminator")
+    d2 = discriminator(fake_input, compress, hparams, "discriminator",
+                       reuse=True)
+    if lsgan:
+      dloss = tf.reduce_mean(
+          tf.squared_difference(d1, 0.9)) + tf.reduce_mean(tf.square(d2))
+      gloss = tf.reduce_mean(tf.squared_difference(d2, 0.9))
+      loss = (dloss + gloss)/2
+    else:  # cross_entropy
+      dloss = -tf.reduce_mean(
+          tf.log(d1 + eps)) - tf.reduce_mean(tf.log(1 - d2 + eps))
+      gloss = -tf.reduce_mean(tf.log(d2 + eps))
+      loss = (dloss + gloss)/2
+    return loss
 
 
 def split_on_batch(x):
@@ -71,48 +81,37 @@ def cycle_gan_internal(inputs, targets, _, hparams):
         targets_orig, hparams.vocab_size, hparams.hidden_size,
         "embed", reuse=True)
 
-    # Split the batch into input-input and target-target parts.
-    inputs1, _ = split_on_batch(inputs)
-    _, targets2 = split_on_batch(targets)
-
-    # Define F and G, called inp2tgt and tgt2inp here.
-    def inp2tgt(x, reuse=False):
-      return transformer_vae.residual_conv(x, 1, hparams, "inp2tgt", reuse)
-    def tgt2inp(x, reuse=False):
-      return transformer_vae.residual_conv(x, 1, hparams, "tgt2inp", reuse)
-
-    # Input-input part.
-    inp1_tgt = inp2tgt(inputs1)
-    inp1_back = tgt2inp(inp1_tgt)
+    x, _ = split_on_batch(inputs)
+    _, y = split_on_batch(targets)
 
-    # Target-target part.
-    tgt2_inp = tgt2inp(targets2, reuse=True)
-    tgt2_back = inp2tgt(tgt2_inp, reuse=True)
+    # Y --> X
+    y_fake = generator(y, hparams, "Fy", reuse=False)
+    y_to_x_loss = lossfn(y, y_fake, True, hparams, True, "YtoX")
 
-    # Reconstruction losses.
-    inp1_orig, _ = split_on_batch(inputs_orig)
-    _, tgt2_orig = split_on_batch(targets_orig)
-    inp1_loss = reconstruct_loss(
-        inp1_back, tf.squeeze(inp1_orig, axis=3), hparams)
-    tgt2_loss = reconstruct_loss(
-        tgt2_back, tf.squeeze(tgt2_orig, axis=3), hparams, reuse=True)
+    # X --> Y
+    x_fake = generator(x, hparams, "Gx", reuse=False)
+    x_to_y_loss = lossfn(y, x_fake, True, hparams, True, "XtoY")
 
-    # Discriminator losses.
-    dloss1 = discriminate_loss(inputs1, tgt2_inp, True, hparams, "inp_disc")
-    dloss2 = discriminate_loss(targets2, inp1_tgt, True, hparams, "tgt_disc")
+    # Cycle-Consistency
+    y_fake_ = generator(y_fake, hparams, "Gx", reuse=True)
+    x_fake_ = generator(x_fake, hparams, "Fy", reuse=True)
+    x_to_x_loss = hparams.cycle_loss_multiplier1 * tf.reduce_mean(
+        tf.abs(x_fake_ - x))
+    y_to_y_loss = hparams.cycle_loss_multiplier2 * tf.reduce_mean(
+        tf.abs(y_fake_ - y))
+    cycloss = x_to_x_loss + y_to_y_loss
 
-    # Reconstruct targets from inputs.
-    tgt = inp2tgt(inputs, reuse=True)
-    tgt = tf.layers.dense(tgt, hparams.vocab_size, name="softmax", reuse=True)
+    sample_generated = generator(inputs, hparams, "Gx", reuse=True)
+    sample_generated = tf.layers.dense(
+        sample_generated, hparams.vocab_size, name="softmax", reuse=None)
+    sample_generated = tf.stop_gradient(
+        tf.expand_dims(sample_generated, axis=2))
 
-    # We use the reconstruction only for tracking progress, no gradients here!
-    tgt = tf.stop_gradient(tf.expand_dims(tgt, axis=2))
+    losses = {"cycloss": cycloss,
+              "y_to_x_loss": y_to_x_loss,
+              "x_to_y_loss": x_to_y_loss}
 
-    losses = {"input_input": hparams.cycle_loss_multiplier * inp1_loss,
-              "target_target": hparams.cycle_loss_multiplier * tgt2_loss,
-              "input_disc": dloss1,
-              "target_disc": dloss2}
-    return tgt, losses
+    return sample_generated, losses
 
 
 @registry.register_model
@@ -135,6 +134,7 @@ def cycle_gan_small():
   hparams.learning_rate = 0.05
   hparams.kl_warmup_steps = 5000
   hparams.learning_rate_warmup_steps = 3000
-  hparams.add_hparam("vocab_size", 32)  # Vocabulary size, need to set here.
-  hparams.add_hparam("cycle_loss_multiplier", 2.0)
+  hparams.add_hparam("vocab_size", 66)  # Vocabulary size, need to set here.
+  hparams.add_hparam("cycle_loss_multiplier1", 10.0)
+  hparams.add_hparam("cycle_loss_multiplier2", 10.0)
   return hparams
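For reference, the loss arithmetic this commit introduces is: the new lossfn scores the discriminator output on real inputs against a smoothed target of 0.9, adds a generator term, and returns the average of the two, with a least-squares (LSGAN) branch and a log-based cross-entropy branch; cycle consistency is an L1 reconstruction penalty scaled by cycle_loss_multiplier1/2 (both 10.0 in cycle_gan_small). Below is a minimal NumPy sketch of the same arithmetic on toy values; the helper names (gan_losses, cycle_consistency) and the sample data are illustrative only and are not part of this commit.

# Illustrative sketch only (not part of the commit): the same loss arithmetic
# as the new lossfn and cycle-consistency terms, written with NumPy on toy data.
import numpy as np

EPS = 1e-12
REAL_TARGET = 0.9  # the diff scores "real" against 0.9, not 1.0


def gan_losses(d_real, d_fake, lsgan=True):
  """Return (discriminator_loss + generator_loss) / 2, as in the new lossfn."""
  if lsgan:  # least-squares GAN branch
    dloss = np.mean((d_real - REAL_TARGET) ** 2) + np.mean(d_fake ** 2)
    gloss = np.mean((d_fake - REAL_TARGET) ** 2)
  else:  # cross-entropy branch; assumes scores already lie in (0, 1)
    dloss = -np.mean(np.log(d_real + EPS)) - np.mean(np.log(1 - d_fake + EPS))
    gloss = -np.mean(np.log(d_fake + EPS))
  return (dloss + gloss) / 2


def cycle_consistency(x, x_cycled, y, y_cycled,
                      multiplier1=10.0, multiplier2=10.0):
  """L1 cycle losses weighted like cycle_loss_multiplier1/2 in the diff."""
  x_to_x = multiplier1 * np.mean(np.abs(x_cycled - x))
  y_to_y = multiplier2 * np.mean(np.abs(y_cycled - y))
  return x_to_x + y_to_y


if __name__ == "__main__":
  rng = np.random.RandomState(0)
  d_real, d_fake = rng.uniform(0.4, 0.9, 8), rng.uniform(0.1, 0.6, 8)
  x, y = rng.randn(8, 16), rng.randn(8, 16)
  print("lsgan loss:", gan_losses(d_real, d_fake, lsgan=True))
  print("xent loss: ", gan_losses(d_real, d_fake, lsgan=False))
  print("cycloss:   ", cycle_consistency(x, x + 0.1 * rng.randn(8, 16),
                                          y, y + 0.1 * rng.randn(8, 16)))

Smoothing the real label to 0.9 and averaging the discriminator and generator terms are common GAN stabilization choices; the commit applies the same scheme in both translation directions ("YtoX" and "XtoY").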

tensor2tensor/utils/devices.py

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ def data_parallelism(daisy_chain_variables=True,
                      worker_job="/job:localhost",
                      no_data_parallelism=False):
   """See data_parallelism_from_flags."""
-  tf.logging.info("schuedule=%s" % schedule)
+  tf.logging.info("schedule=%s" % schedule)
   tf.logging.info("worker_gpu=%s" % worker_gpu)
   tf.logging.info("sync=%s" % sync)
   def _ps_replicas(all_workers=False):
