Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions backends/xnnpack/_passes/fuse_batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
get_param_tensor,
get_tensor_name,
is_param_node,
sanitize_node_name,
)
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
Expand Down Expand Up @@ -208,13 +209,13 @@ def _fuse_ops(
# Otherwise, this is a linear node.
fused_weight, fused_bias = fuse_linear_bn_weights(*fuse_args)

fused_weight_name = (input_node_weight_name + "_fused_bn").replace(".", "_")
fused_weight_name = sanitize_node_name(input_node_weight_name + "_fused_bn")
if input_node_bias_name == "":
fused_bias_name = (input_node_weight_name + "_bias_fused_bn").replace(
".", "_"
fused_bias_name = sanitize_node_name(
input_node_weight_name + "_bias_fused_bn"
)
else:
fused_bias_name = (input_node_bias_name + "_fused_bn").replace(".", "_")
fused_bias_name = sanitize_node_name(input_node_bias_name + "_fused_bn")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like we are sensitive to names, we should put assert somewhere downstream where we check not INVALID_CHARS sneaks in..


# Modify the graph by updating the weight and bias of the conv or linear op
# with the fused weight and bias params, and replacing all the users
Expand Down
29 changes: 29 additions & 0 deletions backends/xnnpack/test/models/regnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
import torchvision
from executorch.backends.xnnpack.test.tester import Tester


class TestInceptionV4(unittest.TestCase):
def setUp(self):
torch._dynamo.reset()

regnet = torchvision.models.regnet_y_32gf()
model_inputs = (torch.randn(3, 299, 299).unsqueeze(0),)

def test_fp32_regnet(self):
(
Tester(self.regnet, self.model_inputs)
.export()
.to_edge_transform_and_lower()
.to_executorch()
.serialize()
.run_method_and_compare_outputs()
)
12 changes: 12 additions & 0 deletions backends/xnnpack/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,15 @@ def is_depthwise_conv(
return (
group_input_channels == 1 and group_output_channels % group_input_channels == 0
)


def sanitize_node_name(name: str) -> str:
"""
Modify a (generated) node name to replace invalid characters (. and -) with underscores.
"""
INVALID_CHARS = [".", "-"]

sanitized = name
for c in INVALID_CHARS:
sanitized = sanitized.replace(c, "_")
return sanitized
Comment on lines +228 to +233
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is more performant

return name.translate(str.maketrans('.-', '__'))

Loading