Skip to content

Commit 37b4a25

Browse files
committed
play gn
1 parent dd925ba commit 37b4a25

File tree

3 files changed

+192
-3
lines changed

3 files changed

+192
-3
lines changed

example/mpii.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ def main(args):
4141

4242
# create model
4343
print("==> creating model '{}', stacks={}, blocks={}".format(args.arch, args.stacks, args.blocks))
44-
model = models.__dict__[args.arch](num_stacks=args.stacks, num_blocks=args.blocks, num_classes=args.num_classes)
44+
model = models.__dict__[args.arch](num_stacks=args.stacks, num_blocks=args.blocks,
45+
num_classes=args.num_classes)
4546

4647
model = torch.nn.DataParallel(model).to(device)
4748

@@ -297,7 +298,7 @@ def validate(val_loader, model, criterion, num_classes, debug=False, flip=True):
297298
choices=model_names,
298299
help='model architecture: ' +
299300
' | '.join(model_names) +
300-
' (default: resnet18)')
301+
' (default: hg)')
301302
parser.add_argument('-s', '--stacks', default=8, type=int, metavar='N',
302303
help='Number of hourglasses to stack')
303304
parser.add_argument('--features', default=256, type=int, metavar='N',

pose/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .hourglass import *
2-
from .preresnet import *
2+
from .hourglass_gn import *
3+
from .preresnet import *

pose/models/hourglass_gn.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
'''
2+
Hourglass network inserted in the pre-activated Resnet
3+
Use lr=0.01 for current version
4+
(c) YANG, Wei
5+
'''
6+
import torch.nn as nn
7+
import torch.nn.functional as F
8+
9+
# from .preresnet import BasicBlock, Bottleneck
10+
11+
12+
__all__ = ['hg_gn']
13+
14+
# hardcode group number
15+
gn = 32
16+
17+
class Bottleneck(nn.Module):
18+
expansion = 2
19+
20+
def __init__(self, inplanes, planes, stride=1, downsample=None):
21+
super(Bottleneck, self).__init__()
22+
23+
self.bn1 = nn.GroupNorm(gn, inplanes)
24+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=True)
25+
self.bn2 = nn.GroupNorm(gn, planes)
26+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
27+
padding=1, bias=True)
28+
self.bn3 = nn.GroupNorm(gn, planes)
29+
self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=True)
30+
self.relu = nn.ReLU(inplace=True)
31+
self.downsample = downsample
32+
self.stride = stride
33+
34+
def forward(self, x):
35+
residual = x
36+
37+
out = self.bn1(x)
38+
out = self.relu(out)
39+
out = self.conv1(out)
40+
41+
out = self.bn2(out)
42+
out = self.relu(out)
43+
out = self.conv2(out)
44+
45+
out = self.bn3(out)
46+
out = self.relu(out)
47+
out = self.conv3(out)
48+
49+
if self.downsample is not None:
50+
residual = self.downsample(x)
51+
52+
out += residual
53+
54+
return out
55+
56+
57+
class Hourglass(nn.Module):
58+
def __init__(self, block, num_blocks, planes, depth):
59+
super(Hourglass, self).__init__()
60+
self.depth = depth
61+
self.block = block
62+
self.hg = self._make_hour_glass(block, num_blocks, planes, depth)
63+
64+
def _make_residual(self, block, num_blocks, planes):
65+
layers = []
66+
for i in range(0, num_blocks):
67+
layers.append(block(planes*block.expansion, planes))
68+
return nn.Sequential(*layers)
69+
70+
def _make_hour_glass(self, block, num_blocks, planes, depth):
71+
hg = []
72+
for i in range(depth):
73+
res = []
74+
for j in range(3):
75+
res.append(self._make_residual(block, num_blocks, planes))
76+
if i == 0:
77+
res.append(self._make_residual(block, num_blocks, planes))
78+
hg.append(nn.ModuleList(res))
79+
return nn.ModuleList(hg)
80+
81+
def _hour_glass_forward(self, n, x):
82+
up1 = self.hg[n-1][0](x)
83+
low1 = F.max_pool2d(x, 2, stride=2)
84+
low1 = self.hg[n-1][1](low1)
85+
86+
if n > 1:
87+
low2 = self._hour_glass_forward(n-1, low1)
88+
else:
89+
low2 = self.hg[n-1][3](low1)
90+
low3 = self.hg[n-1][2](low2)
91+
up2 = F.interpolate(low3, scale_factor=2)
92+
out = up1 + up2
93+
return out
94+
95+
def forward(self, x):
96+
return self._hour_glass_forward(self.depth, x)
97+
98+
99+
class HourglassNet(nn.Module):
100+
'''Hourglass model from Newell et al ECCV 2016'''
101+
def __init__(self, block, num_stacks=2, num_blocks=4, num_classes=16):
102+
super(HourglassNet, self).__init__()
103+
104+
self.inplanes = 64
105+
self.num_feats = 128
106+
self.num_stacks = num_stacks
107+
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
108+
bias=True)
109+
self.bn1 = nn.GroupNorm(gn, self.inplanes)
110+
self.relu = nn.ReLU(inplace=True)
111+
self.layer1 = self._make_residual(block, self.inplanes, 1)
112+
self.layer2 = self._make_residual(block, self.inplanes, 1)
113+
self.layer3 = self._make_residual(block, self.num_feats, 1)
114+
self.maxpool = nn.MaxPool2d(2, stride=2)
115+
116+
# build hourglass modules
117+
ch = self.num_feats*block.expansion
118+
hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
119+
for i in range(num_stacks):
120+
hg.append(Hourglass(block, num_blocks, self.num_feats, 4))
121+
res.append(self._make_residual(block, self.num_feats, num_blocks))
122+
fc.append(self._make_fc(ch, ch))
123+
score.append(nn.Conv2d(ch, num_classes, kernel_size=1, bias=True))
124+
if i < num_stacks-1:
125+
fc_.append(nn.Conv2d(ch, ch, kernel_size=1, bias=True))
126+
score_.append(nn.Conv2d(num_classes, ch, kernel_size=1, bias=True))
127+
self.hg = nn.ModuleList(hg)
128+
self.res = nn.ModuleList(res)
129+
self.fc = nn.ModuleList(fc)
130+
self.score = nn.ModuleList(score)
131+
self.fc_ = nn.ModuleList(fc_)
132+
self.score_ = nn.ModuleList(score_)
133+
134+
def _make_residual(self, block, planes, blocks, stride=1):
135+
downsample = None
136+
if stride != 1 or self.inplanes != planes * block.expansion:
137+
downsample = nn.Sequential(
138+
nn.Conv2d(self.inplanes, planes * block.expansion,
139+
kernel_size=1, stride=stride, bias=True),
140+
)
141+
142+
layers = []
143+
layers.append(block(self.inplanes, planes, stride, downsample))
144+
self.inplanes = planes * block.expansion
145+
for i in range(1, blocks):
146+
layers.append(block(self.inplanes, planes))
147+
148+
return nn.Sequential(*layers)
149+
150+
def _make_fc(self, inplanes, outplanes):
151+
bn = nn.GroupNorm(gn, inplanes)
152+
conv = nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=True)
153+
return nn.Sequential(
154+
conv,
155+
bn,
156+
self.relu,
157+
)
158+
159+
def forward(self, x):
160+
out = []
161+
x = self.conv1(x)
162+
x = self.bn1(x)
163+
x = self.relu(x)
164+
165+
x = self.layer1(x)
166+
x = self.maxpool(x)
167+
x = self.layer2(x)
168+
x = self.layer3(x)
169+
170+
for i in range(self.num_stacks):
171+
y = self.hg[i](x)
172+
y = self.res[i](y)
173+
y = self.fc[i](y)
174+
score = self.score[i](y)
175+
out.append(score)
176+
if i < self.num_stacks-1:
177+
fc_ = self.fc_[i](y)
178+
score_ = self.score_[i](score)
179+
x = x + fc_ + score_
180+
181+
return out
182+
183+
184+
def hg_gn(**kwargs):
185+
model = HourglassNet(Bottleneck, num_stacks=kwargs['num_stacks'], num_blocks=kwargs['num_blocks'],
186+
num_classes=kwargs['num_classes'])
187+
return model

0 commit comments

Comments
 (0)