Minimum working example of training hyperprior; Weights not updating #138
I am trying to create a dummy example to train the hyperprior of an entropy model. I used bls2017.py as my reference. The issue seems to be that the dummy model doesn't see the trainable variables in the prior. Any thoughts on what I am missing? My environment:
My dummy example:

```python
import tensorflow as tf
import tensorflow_compression as tfc


class DummyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.prior = tfc.NoisyDeepFactorized()
        self.build((None, 10))

    def call(self, inputs):
        # Rebuild the entropy model on each call, as in bls2017.py.
        entropy_model = tfc.ContinuousBatchedEntropyModel(self.prior, coding_rank=1, compression=False)
        _, bits = entropy_model(inputs, training=True)
        return tf.reduce_mean(bits)


model = DummyModel()
model.compile(tf.keras.optimizers.Adam(0.1), tf.keras.losses.MeanAbsoluteError())
x_train = tf.random.normal([10**6, 10], mean=5.0, stddev=0.5)
y_train = tf.zeros(shape=(x_train.shape[0], 1))  # target of 0 bits
init_vars = [v.numpy().mean() for v in model.prior.trainable_variables]
# Compare how many variables the prior owns vs. how many Keras sees.
print('Prior weights:', len(model.prior.trainable_variables))
print('Model weights:', len(model.trainable_weights))
history = model.fit(x_train, y_train, batch_size=1024, epochs=2)
print(history.history)
unchanged = init_vars == [v.numpy().mean() for v in model.prior.trainable_variables]
print('Prior weights unchanged?', unchanged)
```

Output:
I am aiming to see the training step take the gradient of the average bits estimate (i.e. the MAE loss relative to a 0 target) w.r.t. the prior weights and then apply some update to those weights.
Hi, you are experiencing a bug in older TF releases. `tf.keras.Model` classes didn't collect trainable variables from all nested objects that inherit from `tf.Module`, only from ones that inherit from `tf.keras.layers.Layer`. `Distribution` objects, such as the prior here, are `tf.Module`s but not Keras layers, so their variables were missed. This was fixed in a later TF version; I think it was fixed in 2.5. I'd recommend using the latest version (2.8; 2.9 should probably be released by the end of this week). If that's not possible, there is a workaround: check out this commit.