diff --git a/export_to_onnx.py b/export_to_onnx.py new file mode 100644 index 000000000..2a4bbba50 --- /dev/null +++ b/export_to_onnx.py @@ -0,0 +1,43 @@ +import argparse +import io +import torch +from torch.autograd import Variable +import onnx + +from ssd import build_ssd + + +def assertONNXExpected(binary_pb): + model_def = onnx.ModelProto.FromString(binary_pb) + onnx.helper.strip_doc_string(model_def) + return model_def + + +def export_to_string(model, inputs, version=None): + f = io.BytesIO() + with torch.no_grad(): + torch.onnx.export(model, inputs, f, export_params=True, opset_version=version) + return f.getvalue() + + +def save_model(model, input, output): + onnx_model_pb = export_to_string(model, input) + model_def = assertONNXExpected(onnx_model_pb) + with open(output, 'wb') as file: + file.write(model_def.SerializeToString()) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser('Export trained model to ONNX format') + parser.add_argument('--model', required=True, help='Path to saved PyTorch network weights (*.pth)') + parser.add_argument('--output', default='ssd.onnx', help='Name of ouput file') + parser.add_argument('--size', default=300, help='Input resolution') + parser.add_argument('--num_classes', default=21, help='Number of trained classes + 1 for background') + args = parser.parse_args() + + net = build_ssd('export', args.size, args.num_classes) + net.load_state_dict(torch.load(args.model, map_location='cpu')) + net.eval() + + input = Variable(torch.randn(1, 3, args.size, args.size)) + save_model(net, input, args.output) diff --git a/layers/functions/detection.py b/layers/functions/detection.py index acdcef065..7332162c1 100644 --- a/layers/functions/detection.py +++ b/layers/functions/detection.py @@ -10,18 +10,20 @@ class Detect(Function): scores and threshold to a top_k number of output predictions for both confidence score and locations. """ - def __init__(self, num_classes, bkg_label, top_k, conf_thresh, nms_thresh): - self.num_classes = num_classes - self.background_label = bkg_label - self.top_k = top_k - # Parameters used in nms. - self.nms_thresh = nms_thresh - if nms_thresh <= 0: - raise ValueError('nms_threshold must be non negative.') - self.conf_thresh = conf_thresh - self.variance = cfg['variance'] + @staticmethod + def symbolic(g, loc_data, conf_data, prior_data, num_classes, top_k, variance, conf_thresh, nms_thresh, phase): + return g.op('DetectionOutput', loc_data, conf_data, prior_data, + num_classes_i=num_classes, + top_k_i=top_k, + keep_top_k_i=top_k, + confidence_threshold_f=conf_thresh, + nms_threshold_f=nms_thresh, + share_location_i=1, + variance_encoded_in_target_i=0, + code_type_s='CENTER_SIZE', + background_label_id_i=0) - def forward(self, loc_data, conf_data, prior_data): + def forward(self, loc_data, conf_data, prior_data, num_classes, top_k, variance, conf_thresh, nms_thresh, phase): """ Args: loc_data: (tensor) Loc preds from loc layers @@ -31,32 +33,39 @@ def forward(self, loc_data, conf_data, prior_data): prior_data: (tensor) Prior boxes and variances from priorbox layers Shape: [1,num_priors,4] """ + loc_data = loc_data.view(loc_data.shape[0], -1, 4) + + if phase == 'export': + prior_data = prior_data.view(-1, 4) + # Ignore variance from priors data + prior_data = prior_data[:prior_data.shape[0] // 2] + num = loc_data.size(0) # batch size num_priors = prior_data.size(0) - output = torch.zeros(num, self.num_classes, self.top_k, 5) + output = torch.zeros(num, num_classes, top_k, 5) conf_preds = conf_data.view(num, num_priors, - self.num_classes).transpose(2, 1) + num_classes).transpose(2, 1) # Decode predictions into bboxes. for i in range(num): - decoded_boxes = decode(loc_data[i], prior_data, self.variance) + decoded_boxes = decode(loc_data[i], prior_data, variance) # For each class, perform nms conf_scores = conf_preds[i].clone() - for cl in range(1, self.num_classes): - c_mask = conf_scores[cl].gt(self.conf_thresh) + for cl in range(1, num_classes): + c_mask = conf_scores[cl].gt(conf_thresh) scores = conf_scores[cl][c_mask] if scores.size(0) == 0: continue l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes) boxes = decoded_boxes[l_mask].view(-1, 4) # idx of highest scoring and non-overlapping boxes per class - ids, count = nms(boxes, scores, self.nms_thresh, self.top_k) + ids, count = nms(boxes, scores, nms_thresh, top_k) output[i, cl, :count] = \ torch.cat((scores[ids[:count]].unsqueeze(1), boxes[ids[:count]]), 1) flt = output.contiguous().view(num, -1, 5) _, idx = flt[:, :, 0].sort(1, descending=True) _, rank = idx.sort(1) - flt[(rank < self.top_k).unsqueeze(-1).expand_as(flt)].fill_(0) + flt[(rank < top_k).unsqueeze(-1).expand_as(flt)].fill_(0) return output diff --git a/layers/functions/prior_box.py b/layers/functions/prior_box.py index 7848a390d..2e06ef34c 100644 --- a/layers/functions/prior_box.py +++ b/layers/functions/prior_box.py @@ -8,12 +8,12 @@ class PriorBox(object): """Compute priorbox coordinates in center-offset form for each source feature map. """ - def __init__(self, cfg): + def __init__(self, cfg, phase): super(PriorBox, self).__init__() self.image_size = cfg['min_dim'] # number of priors for feature map location (either 4 or 6) self.num_priors = len(cfg['aspect_ratios']) - self.variance = cfg['variance'] or [0.1] + self.variance = cfg['variance'] or [0.1, 0.1] self.feature_maps = cfg['feature_maps'] self.min_sizes = cfg['min_sizes'] self.max_sizes = cfg['max_sizes'] @@ -21,6 +21,7 @@ def __init__(self, cfg): self.aspect_ratios = cfg['aspect_ratios'] self.clip = cfg['clip'] self.version = cfg['name'] + self.phase = phase for v in self.variance: if v <= 0: raise ValueError('Variances must be greater than 0') @@ -52,4 +53,16 @@ def forward(self): output = torch.Tensor(mean).view(-1, 4) if self.clip: output.clamp_(max=1, min=0) + if self.phase == 'export': + # CENTER based to CORNER based representaion + w, h = output[:,2], output[:,3] + output[:,0] -= w * 0.5 + output[:,1] -= h * 0.5 + output[:,2] = output[:,0] + w + output[:,3] = output[:,1] + h + + # Append variance after prior boxes like in Caffe + variance = torch.Tensor([self.variance[0], self.variance[0], self.variance[1], self.variance[1]]) \ + .repeat(output.shape[0]).view(-1, 4) + return torch.cat([output, variance], 0).view(1, 2, -1) return output diff --git a/layers/modules/l2norm.py b/layers/modules/l2norm.py index 1e1189d5a..80ff45ad7 100644 --- a/layers/modules/l2norm.py +++ b/layers/modules/l2norm.py @@ -20,5 +20,5 @@ def forward(self, x): norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()+self.eps #x /= norm x = torch.div(x,norm) - out = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(x) * x + out = self.weight.view(1, -1, 1, 1) * x return out diff --git a/ssd.py b/ssd.py index 80a23d638..9d89a98ac 100644 --- a/ssd.py +++ b/ssd.py @@ -30,7 +30,7 @@ def __init__(self, phase, size, base, extras, head, num_classes): self.phase = phase self.num_classes = num_classes self.cfg = (coco, voc)[num_classes == 21] - self.priorbox = PriorBox(self.cfg) + self.priorbox = PriorBox(self.cfg, phase) self.priors = Variable(self.priorbox.forward(), volatile=True) self.size = size @@ -43,9 +43,11 @@ def __init__(self, phase, size, base, extras, head, num_classes): self.loc = nn.ModuleList(head[0]) self.conf = nn.ModuleList(head[1]) - if phase == 'test': + if phase == 'test' or phase == 'export': self.softmax = nn.Softmax(dim=-1) - self.detect = Detect(num_classes, 0, 200, 0.01, 0.45) + self.top_k = 200 + self.conf_thresh = 0.01 + self.nms_thresh = 0.45 def forward(self, x): """Applies network layers and ops on input image(s) x. @@ -95,12 +97,18 @@ def forward(self, x): loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1) conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1) - if self.phase == "test": - output = self.detect( - loc.view(loc.size(0), -1, 4), # loc preds - self.softmax(conf.view(conf.size(0), -1, - self.num_classes)), # conf preds - self.priors.type(type(x.data)) # default boxes + if self.phase == "test" or self.phase == "export": + output = Detect.apply( + loc.view(loc.size(0), -1), # loc preds + self.softmax(conf.view(conf.size(0), -1, self.num_classes)) + .view(conf.size(0), -1), # conf preds + self.priors.type(type(x.data)), # default boxes + self.num_classes, + self.top_k, + self.cfg['variance'], + self.conf_thresh, + self.nms_thresh, + self.phase ) else: output = ( @@ -196,7 +204,7 @@ def multibox(vgg, extra_layers, cfg, num_classes): def build_ssd(phase, size=300, num_classes=21): - if phase != "test" and phase != "train": + if phase != "test" and phase != "train" and phase != "export": print("ERROR: Phase: " + phase + " not recognized") return if size != 300: