Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit a66cfaf

Browse files
harini-kannan and Ryan Sepassi
authored and committed
Adding RevNet-104 to the Tensor2Tensor library.
PiperOrigin-RevId: 179612703
1 parent a2f1ee9 commit a66cfaf

File tree

3 files changed

+412
-0
lines changed

3 files changed

+412
-0
lines changed

tensor2tensor/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from tensor2tensor.models import multimodel
3535
from tensor2tensor.models import neural_gpu
3636
from tensor2tensor.models import resnet
37+
from tensor2tensor.models import revnet
3738
from tensor2tensor.models import shake_shake
3839
from tensor2tensor.models import slicenet
3940
from tensor2tensor.models import super_lm

tensor2tensor/models/revnet.py

Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
# coding=utf-8
2+
# Copyright 2017 The Tensor2Tensor Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
"""Creates a RevNet with the bottleneck residual function.
18+
19+
Implements the following equations described in the RevNet paper:
20+
y1 = x1 + f(x2)
21+
y2 = x2 + g(y1)
22+
23+
However, in practice, the authors use the following equations to downsample
24+
tensors inside a RevNet block:
25+
26+
y1 = h(x1) + f(x2)
27+
y2 = h(x2) + g(y1)
28+
29+
In this case, h is the downsampling function used to change number of channels.
30+
31+
These modified equations are evident in the authors' code online:
32+
https://github.com/renmengye/revnet-public
33+
34+
For reference, the original paper can be found here:
35+
https://arxiv.org/pdf/1707.04585.pdf
36+
"""
37+
38+
# Dependency imports
39+
40+
from tensor2tensor.layers import common_hparams
41+
from tensor2tensor.layers import rev_block
42+
from tensor2tensor.utils import registry
43+
from tensor2tensor.utils import t2t_model
44+
45+
import tensorflow as tf
46+
47+
# Lookup table, keyed by dimensionality ('2d' or '3d'), of the TensorFlow
# layer functions and axis conventions used by the builders in this module:
#   conv / max_pool / avg_pool: tf.layers functions for that dimensionality.
#   split_axis: axis along which activations are split and concatenated.
#   reduction_dimensions: axes averaged by the global pooling step.
CONFIG = {
    '2d': {
        'conv': tf.layers.conv2d,
        'max_pool': tf.layers.max_pooling2d,
        'avg_pool': tf.layers.average_pooling2d,
        'split_axis': 3,
        'reduction_dimensions': [1, 2],
    },
    '3d': {
        'conv': tf.layers.conv3d,
        'max_pool': tf.layers.max_pooling3d,
        # The 3d configuration must use the 3d pooling op; it previously
        # pointed at tf.layers.average_pooling2d.
        'avg_pool': tf.layers.average_pooling3d,
        'split_axis': 4,
        'reduction_dimensions': [1, 2, 3],
    },
}
60+
61+
62+
def f(x, depth1, depth2, dim='2d', first_batch_norm=True, layer_stride=1,
      training=True, padding='SAME'):
  """Bottleneck residual function built from three convolution stages.

  Each stage is (batch norm, ReLU, convolution) with kernel sizes 1, 3 and 1.
  Only the first convolution is strided, and the batch norm + ReLU in front
  of it can be switched off.

  Args:
    x: input tensor.
    depth1: Number of output channels of the first and second convolutions.
    depth2: Number of output channels of the third convolution.
    dim: '2d' or '3d'; key into CONFIG selecting the convolution function.
    first_batch_norm: Whether batch norm and ReLU are applied before the
      first convolution.
    layer_stride: Stride of the first convolution. The second and third
      convolutions always use stride 1.
    training: True for train phase, False for eval phase.
    padding: Padding mode passed to every convolution.

  Returns:
    The output tensor of the third convolution.
  """
  conv_fn = CONFIG[dim]['conv']
  # One row per stage: (filters, kernel size, stride, apply BN + ReLU first?).
  stages = (
      (depth1, 1, layer_stride, first_batch_norm),
      (depth1, 3, 1, True),
      (depth2, 1, 1, True),
  )
  with tf.variable_scope('f'):
    out = x
    # Layers are created in the same order as the stages are listed, so the
    # automatically generated layer names stay stable.
    for filters, kernel_size, stride, pre_activate in stages:
      if pre_activate:
        out = tf.layers.batch_normalization(out, training=training)
        out = tf.nn.relu(out)
      out = conv_fn(out, filters, kernel_size, strides=stride,
                    padding=padding, activation=None)
    return out
102+
103+
104+
def h(x, output_channels, dim='2d', layer_stride=1, scope='h'):
  """Projects `x` with a strided 1x1 convolution and no activation.

  Used as the shortcut path when the number of channels and/or the spatial
  resolution has to change.

  Args:
    x: input tensor.
    output_channels: Number of channels of the returned tensor.
    dim: '2d' or '3d'; key into CONFIG selecting the convolution function.
    layer_stride: Stride of the convolution. Usually 1 or 2.
    scope: Variable scope in which the convolution is created.

  Returns:
    The convolved tensor with `output_channels` channels. The convolution
    uses 'SAME' padding, so the spatial size is reduced only by the stride.
  """
  project = CONFIG[dim]['conv']
  with tf.variable_scope(scope):
    return project(inputs=x,
                   filters=output_channels,
                   kernel_size=1,
                   strides=layer_stride,
                   padding='SAME',
                   activation=None)
127+
128+
129+
def init(images, num_channels, dim='2d', training=True, scope='init'):
  """Initial block: 7-wide strided conv, batch norm, ReLU, max pool, split.

  Args:
    images: tensor of input images to the model.
    num_channels: Number of output channels of the convolution. The result
      is split in two along the channel axis, so each returned tensor
      carries half of them.
    dim: '2d' or '3d'; key into CONFIG selecting the layer functions.
    training: True for train phase, False for eval phase.
    scope: Variable scope in which the block is created.

  Returns:
    A pair (x1, x2): the two halves of the pooled activations, split along
    CONFIG[dim]['split_axis'].
  """
  ops = CONFIG[dim]
  conv_fn, pool_fn = ops['conv'], ops['max_pool']
  with tf.variable_scope(scope):
    stem = conv_fn(images, num_channels, 7, strides=2,
                   padding='SAME', activation=None)
    stem = tf.nn.relu(tf.layers.batch_normalization(stem, training=training))
    stem = pool_fn(stem, pool_size=3, strides=2)
    first_half, second_half = tf.split(stem, 2, axis=ops['split_axis'])
    return first_half, second_half
152+
153+
154+
def unit(x1, x2, block_num, depth1, depth2, num_layers, dim='2d',
         first_batch_norm=True, stride=1, training=True):
  """Builds one RevNet unit: a downsampling step plus a reversible block.

  The downsampling step computes
    y1 = h(x1) + f(x2)
    y2 = h(x2) + f(y1)
  and (y1, y2) is then fed through `num_layers` reversible layers created by
  rev_block.rev_block.

  Args:
    x1: tensor of network activations (first stream).
    x2: tensor of network activations (second stream).
    block_num: integer ID of the block; used in the variable scope name.
    depth1: First depth of the bottleneck residual function.
    depth2: Second depth of the bottleneck residual function; also the
      number of output channels of the shortcut projections.
    num_layers: Number of layers in the reversible block.
    dim: '2d' or '3d'.
    first_batch_norm: Whether the first residual function of the unit starts
      with batch norm and ReLU.
    stride: Stride used by the shortcut projections and by the first
      residual function.
    training: True for train phase, False for eval phase.

  Returns:
    A pair of output activation tensors.
  """

  def residual_fn(x):
    # Unstrided bottleneck residual with the leading batch norm enabled.
    return f(x, depth1, depth2, dim=dim, training=training)

  with tf.variable_scope('unit_%d' % block_num):
    # Manual (non-reversible) downsampling step.
    with tf.variable_scope('downsampling'):
      with tf.variable_scope('x1'):
        shortcut1 = h(x1, depth2, dim=dim, layer_stride=stride)
        residual1 = f(x2, depth1, depth2, dim=dim, layer_stride=stride,
                      first_batch_norm=first_batch_norm, training=training)
        y1 = shortcut1 + residual1
      with tf.variable_scope('x2'):
        shortcut2 = h(x2, depth2, dim=dim, layer_stride=stride)
        # y1 is already downsampled, so this residual uses stride 1.
        residual2 = residual_fn(y1)
        y2 = shortcut2 + residual2

    # Remaining layers use the memory-efficient rev_block implementation.
    with tf.variable_scope('full_block'):
      out1, out2 = rev_block.rev_block(y1, y2,
                                       residual_fn,
                                       residual_fn,
                                       num_layers=num_layers)
    return out1, out2
196+
197+
198+
def final_block(x1, x2, dim='2d', training=True, scope='final_block'):
  """Merges the two activation streams and global-average-pools them.

  Args:
    x1: tensor of network activations (first stream).
    x2: tensor of network activations (second stream).
    dim: '2d' or '3d'; key into CONFIG selecting the axes to use.
    training: True for train phase, False for eval phase.
    scope: Variable scope in which the block is created.

  Returns:
    Pre-logits tensor: the concatenation of x1 and x2 after batch norm and
    ReLU, averaged over CONFIG[dim]['reduction_dimensions']. The averaged
    axes are kept with size 1 (keep_dims=True).
  """
  settings = CONFIG[dim]
  with tf.variable_scope(scope):
    # Rejoin the streams along the axis they were originally split on.
    merged = tf.concat([x1, x2], axis=settings['split_axis'])
    activated = tf.nn.relu(
        tf.layers.batch_normalization(merged, training=training))

    # Global average pooling.
    return tf.reduce_mean(activated, settings['reduction_dimensions'],
                          name='final_pool', keep_dims=True)
223+
224+
225+
def revnet104(inputs, hparams, reuse=None):
  """Uses Tensor2Tensor memory optimized RevNet block to build a RevNet.

  The network is the initial block, one unit per entry of
  hparams.num_layers_per_block, and the final block.

  Args:
    inputs: tensor of input images to the model.
    hparams: HParams object that provides `mode` and the following
      parameters:
      num_channels_first - A Python list; element i is passed to block i as
        `depth1`, the number of output channels of the first and second
        convolutions of the bottleneck residual function.
      num_channels_second - A Python list; element i is passed to block i as
        `depth2`, the number of output channels of the third convolution of
        the bottleneck residual function.
      num_layers_per_block - A Python list containing the number of RevNet
        layers for each block. Its length determines the number of blocks.
      first_batch_norm - A Python list of booleans; element i says whether
        block i starts with a batch norm layer.
      strides - A Python list of integers; element i is the stride of the
        residual function for block i.
      num_channels_init_block - An integer, the number of channels of the
        convolutional layer in the initial block.
      dim - A string (either '2d' or '3d') that decides if the RevNet is
        2-dimensional or 3-dimensional.
    reuse: Whether to reuse the 'RevNet104' variable scope.

  Returns:
    The pre-logits tensor produced by final_block.
  """
  training = hparams.mode == tf.estimator.ModeKeys.TRAIN
  with tf.variable_scope('RevNet104', reuse=reuse):
    x1, x2 = init(inputs,
                  num_channels=hparams.num_channels_init_block,
                  dim=hparams.dim,
                  training=training)
    # Iterate over every configured block, starting at index 0. Starting at 1
    # silently dropped the first block, so the first entry of each per-block
    # hparam list was never used.
    for block_num in range(len(hparams.num_layers_per_block)):
      block = {'depth1': hparams.num_channels_first[block_num],
               'depth2': hparams.num_channels_second[block_num],
               'num_layers': hparams.num_layers_per_block[block_num],
               'first_batch_norm': hparams.first_batch_norm[block_num],
               'stride': hparams.strides[block_num]}
      x1, x2 = unit(x1, x2, block_num, dim=hparams.dim, training=training,
                    **block)
    pre_logits = final_block(x1, x2, dim=hparams.dim, training=training)
    return pre_logits
270+
271+
272+
@registry.register_model
class Revnet104(t2t_model.T2TModel):
  """Registered T2TModel whose body is the revnet104 network."""

  def body(self, features):
    """Runs revnet104 on features['inputs'] with this model's hparams."""
    images = features['inputs']
    return revnet104(images, self.hparams)
277+
278+
279+
@registry.register_hparams
def revnet_base():
  """Base hyperparameters for the RevNet model defined in this module.

  Starts from common_hparams.basic_params1(), adds the RevNet-specific
  parameters read by revnet104, and overrides the optimizer settings.

  Returns:
    The populated HParams object.
  """
  hparams = common_hparams.basic_params1()

  # RevNet-specific parameters, registered in this order.
  revnet_params = (
      ('num_channels_first', [64, 128, 256, 416]),
      ('num_channels_second', [256, 512, 1024, 1664]),
      ('num_layers_per_block', [1, 1, 10, 1]),
      ('first_batch_norm', [False, True, True, True]),
      ('strides', [1, 2, 2, 2]),
      ('num_channels_init_block', 32),
      ('dim', '2d'),
  )
  for name, value in revnet_params:
    hparams.add_hparam(name, value)

  hparams.optimizer = 'Momentum'
  hparams.learning_rate = 0.01
  hparams.weight_decay = 1e-4
  # Can run with a batch size of 128 with Problem ImageImagenet224
  hparams.tpu_batch_size_per_shard = 128
  return hparams

0 commit comments

Comments
 (0)