From 7c83bc4fb269457aeb40af69ff9678f7aa48eb52 Mon Sep 17 00:00:00 2001 From: Iden Craven Date: Fri, 6 Mar 2020 16:31:11 -0700 Subject: [PATCH] Fix node names ending with _fwd not being able to find weights --- mmdnn/conversion/mxnet/mxnet_parser.py | 36 +++++++++++++++++++------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/mmdnn/conversion/mxnet/mxnet_parser.py b/mmdnn/conversion/mxnet/mxnet_parser.py index be9e28c5..f9d6e94a 100644 --- a/mmdnn/conversion/mxnet/mxnet_parser.py +++ b/mmdnn/conversion/mxnet/mxnet_parser.py @@ -402,10 +402,14 @@ def rename_FullyConnected(self, source_node): # weights if self.weight_loaded: + if source_node.name.endswith("_fwd"): + node_prefix = source_node.name[:-4] + else: + node_prefix = source_node.name if self.data_format == 'NM': - self.set_weight(source_node.name, "weights", self.weight_data.get(source_node.name + "_weight").asnumpy().transpose((1, 0))) + self.set_weight(source_node.name, "weights", self.weight_data.get(node_prefix + "_weight").asnumpy().transpose((1, 0))) else: - weight = self.weight_data.get(source_node.name + "_weight").asnumpy().transpose((1, 0)) + weight = self.weight_data.get(node_prefix + "_weight").asnumpy().transpose((1, 0)) original_shape = weight.shape channel_first_list = self.trace_shape(source_node, IR_node) @@ -417,7 +421,7 @@ def rename_FullyConnected(self, source_node): self.set_weight(source_node.name, "weights", weight) if IR_node.attr["use_bias"].b: - self.set_weight(source_node.name, "bias", self.weight_data.get(source_node.name + "_bias").asnumpy()) + self.set_weight(source_node.name, "bias", self.weight_data.get(node_prefix + "_bias").asnumpy()) if not self.data_format == 'NM': # print("Warning: Layer [{}] has changed model data format from [{}] to [NM]".format(source_node.name, self.data_format)) @@ -502,7 +506,11 @@ def rename_Convolution(self, source_node): # weights if self.weight_loaded: - weight = self.weight_data.get(source_node.name + "_weight").asnumpy() + if source_node.name.endswith("_fwd"): + node_prefix = source_node.name[:-4] + else: + node_prefix = source_node.name + weight = self.weight_data.get(node_prefix + "_weight").asnumpy() if not layout in MXNetParser.channels_last: weight = MXNetParser.transpose(weight, dim) if IR_node.op == "DepthwiseConv": @@ -510,7 +518,7 @@ def rename_Convolution(self, source_node): self.set_weight(source_node.name, "weights", weight) if IR_node.attr["use_bias"].b: - self.set_weight(source_node.name, "bias", self.weight_data.get(source_node.name + "_bias").asnumpy()) + self.set_weight(source_node.name, "bias", self.weight_data.get(node_prefix + "_bias").asnumpy()) def rename_Activation(self, source_node): @@ -537,19 +545,29 @@ def rename_BatchNorm(self, source_node): # weights if self.weight_loaded: + if source_node.name.endswith("_fwd"): + node_prefix = source_node.name[:-4] + else: + node_prefix = source_node.name # gamma if IR_node.attr["scale"].b: - self.set_weight(source_node.name, "scale", self.weight_data.get(source_node.name + "_gamma").asnumpy()) + self.set_weight(source_node.name, "scale", self.weight_data.get(node_prefix + "_gamma").asnumpy()) # beta if IR_node.attr["bias"].b: - self.set_weight(source_node.name, "bias", self.weight_data.get(source_node.name + "_beta").asnumpy()) + self.set_weight(source_node.name, "bias", self.weight_data.get(node_prefix + "_beta").asnumpy()) # mean - self.set_weight(source_node.name, "mean", self.weight_data.get(source_node.name + "_moving_mean").asnumpy()) + try: + self.set_weight(source_node.name, "mean", self.weight_data.get(node_prefix + "_moving_mean").asnumpy()) + except AttributeError: + self.set_weight(source_node.name, "mean", self.weight_data.get(node_prefix + "_running_mean").asnumpy()) # var - self.set_weight(source_node.name, "var", self.weight_data.get(source_node.name + "_moving_var").asnumpy()) + try: + self.set_weight(source_node.name, "var", self.weight_data.get(node_prefix + "_moving_var").asnumpy()) + except AttributeError: + self.set_weight(source_node.name, "var", self.weight_data.get(node_prefix + "_running_var").asnumpy()) def rename_Pooling(self, source_node):