|
| 1 | +# coding=utf-8 |
| 2 | +# Copyright 2017 The Tensor2Tensor Authors. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | + |
| 17 | +"""Creates a RevNet with the bottleneck residual function. |
| 18 | +
|
| 19 | +Implements the following equations described in the RevNet paper: |
| 20 | +y1 = x1 + f(x2) |
| 21 | +y2 = x2 + g(y1) |
| 22 | +
|
| 23 | +However, in practice, the authors use the following equations to downsample |
| 24 | +tensors inside a RevNet block: |
| 25 | +
|
| 26 | +y1 = h(x1) + f(x2) |
| 27 | +y2 = h(x2) + g(y1) |
| 28 | +
|
| 29 | +In this case, h is the downsampling function used to change number of channels. |
| 30 | +
|
| 31 | +These modified equations are evident in the authors' code online: |
| 32 | +https://github.com/renmengye/revnet-public |
| 33 | +
|
| 34 | +For reference, the original paper can be found here: |
| 35 | +https://arxiv.org/pdf/1707.04585.pdf |
| 36 | +""" |
| 37 | + |
| 38 | +# Dependency imports |
| 39 | + |
| 40 | +from tensor2tensor.layers import common_hparams |
| 41 | +from tensor2tensor.layers import rev_block |
| 42 | +from tensor2tensor.utils import registry |
| 43 | +from tensor2tensor.utils import t2t_model |
| 44 | + |
| 45 | +import tensorflow as tf |
| 46 | + |
| 47 | +CONFIG = {'2d': {'conv': tf.layers.conv2d, |
| 48 | + 'max_pool': tf.layers.max_pooling2d, |
| 49 | + 'avg_pool': tf.layers.average_pooling2d, |
| 50 | + 'split_axis': 3, |
| 51 | + 'reduction_dimensions': [1, 2] |
| 52 | + }, |
| 53 | + '3d': {'conv': tf.layers.conv3d, |
| 54 | + 'max_pool': tf.layers.max_pooling3d, |
| 55 | + 'avg_pool': tf.layers.average_pooling2d, |
| 56 | + 'split_axis': 4, |
| 57 | + 'reduction_dimensions': [1, 2, 3] |
| 58 | + } |
| 59 | + } |
| 60 | + |
| 61 | + |
| 62 | +def f(x, depth1, depth2, dim='2d', first_batch_norm=True, layer_stride=1, |
| 63 | + training=True, padding='SAME'): |
| 64 | + """Applies bottleneck residual function for 104-layer RevNet. |
| 65 | +
|
| 66 | + Args: |
| 67 | + x: input tensor |
| 68 | + depth1: Number of output channels for the first and second conv layers. |
| 69 | + depth2: Number of output channels for the third conv layer. |
| 70 | + dim: '2d' if 2-dimensional, '3d' if 3-dimensional. |
| 71 | + first_batch_norm: Whether to keep the first batch norm layer or not. |
| 72 | + Typically used in the first RevNet block. |
| 73 | + layer_stride: Stride for the first conv filter. Note that this particular |
| 74 | + 104-layer RevNet architecture only varies the stride for the first conv |
| 75 | + filter. The stride for the second conv filter is always set to 1. |
| 76 | + training: True for train phase, False for eval phase. |
| 77 | + padding: Padding for each conv layer. |
| 78 | +
|
| 79 | + Returns: |
| 80 | + Output tensor after applying residual function for 104-layer RevNet. |
| 81 | + """ |
| 82 | + conv = CONFIG[dim]['conv'] |
| 83 | + with tf.variable_scope('f'): |
| 84 | + if first_batch_norm: |
| 85 | + net = tf.layers.batch_normalization(x, training=training) |
| 86 | + net = tf.nn.relu(net) |
| 87 | + else: |
| 88 | + net = x |
| 89 | + net = conv(net, depth1, 1, strides=layer_stride, |
| 90 | + padding=padding, activation=None) |
| 91 | + |
| 92 | + net = tf.layers.batch_normalization(net, training=training) |
| 93 | + net = tf.nn.relu(net) |
| 94 | + net = conv(net, depth1, 3, strides=1, |
| 95 | + padding=padding, activation=None) |
| 96 | + |
| 97 | + net = tf.layers.batch_normalization(net, training=training) |
| 98 | + net = tf.nn.relu(net) |
| 99 | + net = conv(net, depth2, 1, strides=1, |
| 100 | + padding=padding, activation=None) |
| 101 | + return net |
| 102 | + |
| 103 | + |
| 104 | +def h(x, output_channels, dim='2d', layer_stride=1, scope='h'): |
| 105 | + """Downsamples 'x' using a 1x1 convolution filter and a chosen stride. |
| 106 | +
|
| 107 | + Args: |
| 108 | + x: input tensor of size [N, H, W, C] |
| 109 | + output_channels: Desired number of output channels. |
| 110 | + dim: '2d' if 2-dimensional, '3d' if 3-dimensional. |
| 111 | + layer_stride: What stride to use. Usually 1 or 2. |
| 112 | + scope: Optional variable scope for the h function. |
| 113 | +
|
| 114 | + This function uses a 1x1 convolution filter and a chosen stride to downsample |
| 115 | + the input tensor x. |
| 116 | +
|
| 117 | + Returns: |
| 118 | + A downsampled tensor of size [N, H/2, W/2, output_channels] if layer_stride |
| 119 | + is 2, else returns a tensor of size [N, H, W, output_channels] if |
| 120 | + layer_stride is 1. |
| 121 | + """ |
| 122 | + conv = CONFIG[dim]['conv'] |
| 123 | + with tf.variable_scope(scope): |
| 124 | + x = conv(x, output_channels, 1, strides=layer_stride, padding='SAME', |
| 125 | + activation=None) |
| 126 | + return x |
| 127 | + |
| 128 | + |
| 129 | +def init(images, num_channels, dim='2d', training=True, scope='init'): |
| 130 | + """Standard ResNet initial block used as first RevNet block. |
| 131 | +
|
| 132 | + Args: |
| 133 | + images: [N, H, W, 3] tensor of input images to the model. |
| 134 | + num_channels: Output depth of convolutional layer in initial block. |
| 135 | + dim: '2d' if 2-dimensional, '3d' if 3-dimensional. |
| 136 | + training: True for train phase, False for eval phase. |
| 137 | + scope: Optional scope for the init block. |
| 138 | +
|
| 139 | + Returns: |
| 140 | + Two [N, H, W, C] output activations from input images. |
| 141 | + """ |
| 142 | + conv = CONFIG[dim]['conv'] |
| 143 | + pool = CONFIG[dim]['max_pool'] |
| 144 | + with tf.variable_scope(scope): |
| 145 | + net = conv(images, num_channels, 7, strides=2, |
| 146 | + padding='SAME', activation=None) |
| 147 | + net = tf.layers.batch_normalization(net, training=training) |
| 148 | + net = tf.nn.relu(net) |
| 149 | + net = pool(net, pool_size=3, strides=2) |
| 150 | + x1, x2 = tf.split(net, 2, axis=CONFIG[dim]['split_axis']) |
| 151 | + return x1, x2 |
| 152 | + |
| 153 | + |
| 154 | +def unit(x1, x2, block_num, depth1, depth2, num_layers, dim='2d', |
| 155 | + first_batch_norm=True, stride=1, training=True): |
| 156 | + """Implements bottleneck RevNet unit from authors' RevNet-104 architecture. |
| 157 | +
|
| 158 | + Args: |
| 159 | + x1: [N, H, W, C] tensor of network activations. |
| 160 | + x2: [N, H, W, C] tensor of network activations. |
| 161 | + block_num: integer ID of block |
| 162 | + depth1: First depth in bottleneck residual unit. |
| 163 | + depth2: Second depth in bottleneck residual unit. |
| 164 | + num_layers: Number of layers in the RevNet block. |
| 165 | + dim: '2d' if 2-dimensional, '3d' if 3-dimensional. |
| 166 | + first_batch_norm: Whether to keep the first batch norm layer or not. |
| 167 | + Typically used in the first RevNet block. |
| 168 | + stride: Stride for the residual function. |
| 169 | + training: True for train phase, False for eval phase. |
| 170 | +
|
| 171 | + Returns: |
| 172 | + Two [N, H, W, C] output activation tensors. |
| 173 | + """ |
| 174 | + scope_name = 'unit_%d' % block_num |
| 175 | + with tf.variable_scope(scope_name): |
| 176 | + # Manual implementation of downsampling |
| 177 | + with tf.variable_scope('downsampling'): |
| 178 | + with tf.variable_scope('x1'): |
| 179 | + hx1 = h(x1, depth2, dim=dim, layer_stride=stride) |
| 180 | + fx2 = f(x2, depth1, depth2, dim=dim, layer_stride=stride, |
| 181 | + first_batch_norm=first_batch_norm, training=training) |
| 182 | + x1 = hx1 + fx2 |
| 183 | + with tf.variable_scope('x2'): |
| 184 | + hx2 = h(x2, depth2, dim=dim, layer_stride=stride) |
| 185 | + fx1 = f(x1, depth1, depth2, dim=dim, training=training) |
| 186 | + x2 = hx2 + fx1 |
| 187 | + |
| 188 | + # Full block using memory-efficient rev_block implementation. |
| 189 | + with tf.variable_scope('full_block'): |
| 190 | + residual_func = lambda x: f(x, depth1, depth2, dim=dim, training=training) |
| 191 | + x1, x2 = rev_block.rev_block(x1, x2, |
| 192 | + residual_func, |
| 193 | + residual_func, |
| 194 | + num_layers=num_layers) |
| 195 | + return x1, x2 |
| 196 | + |
| 197 | + |
| 198 | +def final_block(x1, x2, dim='2d', training=True, scope='final_block'): |
| 199 | + """Converts activations from last RevNet block to pre-logits. |
| 200 | +
|
| 201 | + Args: |
| 202 | + x1: [NxHxWxC] tensor of network activations. |
| 203 | + x2: [NxHxWxC] tensor of network activations. |
| 204 | + dim: '2d' if 2-dimensional, '3d' if 3-dimensional. |
| 205 | + training: True for train phase, False for eval phase. |
| 206 | + scope: Optional variable scope for the final block. |
| 207 | +
|
| 208 | + Returns: |
| 209 | + [N, hidden_dim] pre-logits tensor from activations x1 and x2. |
| 210 | + """ |
| 211 | + |
| 212 | + # Final batch norm and relu |
| 213 | + with tf.variable_scope(scope): |
| 214 | + y = tf.concat([x1, x2], axis=CONFIG[dim]['split_axis']) |
| 215 | + y = tf.layers.batch_normalization(y, training=training) |
| 216 | + y = tf.nn.relu(y) |
| 217 | + |
| 218 | + # Global average pooling |
| 219 | + net = tf.reduce_mean(y, CONFIG[dim]['reduction_dimensions'], |
| 220 | + name='final_pool', keep_dims=True) |
| 221 | + |
| 222 | + return net |
| 223 | + |
| 224 | + |
| 225 | +def revnet104(inputs, hparams, reuse=None): |
| 226 | + """Uses Tensor2Tensor memory optimized RevNet block to build a RevNet. |
| 227 | +
|
| 228 | + Args: |
| 229 | + inputs: [NxHxWx3] tensor of input images to the model. |
| 230 | + hparams: HParams object that contains the following parameters, |
| 231 | + in addition to the parameters contained in the basic_params1() object in |
| 232 | + the common_hparams module: |
| 233 | + num_channels_first - A Python list where each element represents the |
| 234 | + depth of the first and third convolutional layers in the bottleneck |
| 235 | + residual unit for a given block. |
| 236 | + num_channels_second - A Python list where each element represents the |
| 237 | + depth of the second convolutional layer in the bottleneck residual |
| 238 | + unit for a given block. |
| 239 | + num_layers_per_block - A Python list containing the number of RevNet |
| 240 | + layers for each block. |
| 241 | + first_batch_norm - A Python list containing booleans representing the |
| 242 | + presence of a batch norm layer at the beginning of a given block. |
| 243 | + strides - A Python list containing integers representing the stride of |
| 244 | + the residual function for each block. |
| 245 | + num_channels_init_block - An integer representing the number of channels |
| 246 | + for the convolutional layer in the initial block. |
| 247 | + dimension - A string (either "2d" or "3d") that decides if the RevNet is |
| 248 | + 2-dimensional or 3-dimensional. |
| 249 | + reuse: Whether to reuse the default variable scope. |
| 250 | +
|
| 251 | + Returns: |
| 252 | + [batch_size, hidden_dim] pre-logits tensor from the bottleneck RevNet. |
| 253 | + """ |
| 254 | + training = hparams.mode == tf.estimator.ModeKeys.TRAIN |
| 255 | + with tf.variable_scope('RevNet104', reuse=reuse): |
| 256 | + x1, x2 = init(inputs, |
| 257 | + num_channels=hparams.num_channels_init_block, |
| 258 | + dim=hparams.dim, |
| 259 | + training=training) |
| 260 | + for block_num in range(1, len(hparams.num_layers_per_block)): |
| 261 | + block = {'depth1': hparams.num_channels_first[block_num], |
| 262 | + 'depth2': hparams.num_channels_second[block_num], |
| 263 | + 'num_layers': hparams.num_layers_per_block[block_num], |
| 264 | + 'first_batch_norm': hparams.first_batch_norm[block_num], |
| 265 | + 'stride': hparams.strides[block_num]} |
| 266 | + x1, x2 = unit(x1, x2, block_num, dim=hparams.dim, training=training, |
| 267 | + **block) |
| 268 | + pre_logits = final_block(x1, x2, dim=hparams.dim, training=training) |
| 269 | + return pre_logits |
| 270 | + |
| 271 | + |
| 272 | +@registry.register_model |
| 273 | +class Revnet104(t2t_model.T2TModel): |
| 274 | + |
| 275 | + def body(self, features): |
| 276 | + return revnet104(features['inputs'], self.hparams) |
| 277 | + |
| 278 | + |
| 279 | +@registry.register_hparams |
| 280 | +def revnet_base(): |
| 281 | + """Set of hyperparameters.""" |
| 282 | + hparams = common_hparams.basic_params1() |
| 283 | + hparams.add_hparam('num_channels_first', [64, 128, 256, 416]) |
| 284 | + hparams.add_hparam('num_channels_second', [256, 512, 1024, 1664]) |
| 285 | + hparams.add_hparam('num_layers_per_block', [1, 1, 10, 1]) |
| 286 | + hparams.add_hparam('first_batch_norm', [False, True, True, True]) |
| 287 | + hparams.add_hparam('strides', [1, 2, 2, 2]) |
| 288 | + hparams.add_hparam('num_channels_init_block', 32) |
| 289 | + hparams.add_hparam('dim', '2d') |
| 290 | + |
| 291 | + hparams.optimizer = 'Momentum' |
| 292 | + hparams.learning_rate = 0.01 |
| 293 | + hparams.weight_decay = 1e-4 |
| 294 | + # Can run with a batch size of 128 with Problem ImageImagenet224 |
| 295 | + hparams.tpu_batch_size_per_shard = 128 |
| 296 | + return hparams |
0 commit comments