@@ -40,26 +40,33 @@ def body(self, features):
40
40
# Concat frames and down-stride.
41
41
cur_frame = tf .to_float (features ["inputs" ])
42
42
prev_frame = tf .to_float (features ["inputs_prev" ])
43
- frames = tf .concat ([cur_frame , prev_frame ], axis = - 1 )
44
- x = tf .layers .conv2d (frames , filters , kernel2 , activation = tf .nn .relu ,
45
- strides = (2 , 2 ), padding = "SAME" )
43
+ x = tf .concat ([cur_frame , prev_frame ], axis = - 1 )
44
+ for _ in xrange (hparams .num_compress_steps ):
45
+ x = tf .layers .conv2d (x , filters , kernel2 , activation = common_layers .belu ,
46
+ strides = (2 , 2 ), padding = "SAME" )
47
+ x = common_layers .layer_norm (x )
48
+ filters *= 2
46
49
# Add embedded action.
47
- action = tf .reshape (features ["action" ], [- 1 , 1 , 1 , filters ])
48
- x = tf .concat ([x , action + tf .zeros_like (x )], axis = - 1 )
50
+ action = tf .reshape (features ["action" ], [- 1 , 1 , 1 , hparams .hidden_size ])
51
+ zeros = tf .zeros (common_layers .shape_list (x )[:- 1 ] + [hparams .hidden_size ])
52
+ x = tf .concat ([x , action + zeros ], axis = - 1 )
49
53
50
54
# Run a stack of convolutions.
51
55
for i in xrange (hparams .num_hidden_layers ):
52
56
with tf .variable_scope ("layer%d" % i ):
53
- y = tf .layers .conv2d (x , 2 * filters , kernel1 , activation = tf . nn . relu ,
57
+ y = tf .layers .conv2d (x , filters , kernel1 , activation = common_layers . belu ,
54
58
strides = (1 , 1 ), padding = "SAME" )
55
59
if i == 0 :
56
60
x = y
57
61
else :
58
62
x = common_layers .layer_norm (x + y )
59
63
# Up-convolve.
60
- x = tf .layers .conv2d_transpose (
61
- x , filters , kernel2 , activation = tf .nn .relu ,
62
- strides = (2 , 2 ), padding = "SAME" )
64
+ for _ in xrange (hparams .num_compress_steps ):
65
+ filters //= 2
66
+ x = tf .layers .conv2d_transpose (
67
+ x , filters , kernel2 , activation = common_layers .belu ,
68
+ strides = (2 , 2 ), padding = "SAME" )
69
+ x = common_layers .layer_norm (x )
63
70
64
71
# Reward prediction.
65
72
reward_pred_h1 = tf .reduce_mean (x , axis = [1 , 2 ], keep_dims = True )
@@ -78,7 +85,7 @@ def basic_conv():
78
85
hparams = common_hparams .basic_params1 ()
79
86
hparams .hidden_size = 64
80
87
hparams .batch_size = 8
81
- hparams .num_hidden_layers = 2
88
+ hparams .num_hidden_layers = 3
82
89
hparams .optimizer = "Adam"
83
90
hparams .learning_rate_constant = 0.0002
84
91
hparams .learning_rate_warmup_steps = 500
@@ -87,6 +94,7 @@ def basic_conv():
87
94
hparams .initializer = "uniform_unit_scaling"
88
95
hparams .initializer_gain = 1.0
89
96
hparams .weight_decay = 0.0
97
+ hparams .add_hparam ("num_compress_steps" , 2 )
90
98
return hparams
91
99
92
100
0 commit comments