Skip to content

Commit 8ef5417

Browse files
authored
[ET-VK][ez] Always partition batch norm as it will be fused
Differential Revision: D93511630 Pull Request resolved: #17508
1 parent 7e1ead7 commit 8ef5417

File tree

2 files changed

+1
-15
lines changed

2 files changed

+1
-15
lines changed

backends/vulkan/op_registry.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,25 +1233,11 @@ def register_embedding():
12331233

12341234
@update_features(exir_ops.edge.aten._native_batch_norm_legit_no_training.default)
12351235
def register_native_batch_norm_legit_no_training():
1236-
def check_batch_norm_node(node: torch.fx.Node) -> bool:
1237-
x = node.args[0]
1238-
if not isinstance(x, torch.fx.Node):
1239-
return False
1240-
x_val = x.meta.get("val", None)
1241-
if x_val is None:
1242-
return False
1243-
x_shape = x_val.size()
1244-
# Only support 4-D input tensors since this is a restriction enforced by the
1245-
# operator implementation.
1246-
# TODO(ssjia): Add shape agnostic support for batch norm
1247-
return len(x_shape) == 4
1248-
12491236
return OpFeatures(
12501237
inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
12511238
inputs_dtypes=utils.FP_T,
12521239
supports_prepacking=True,
12531240
supports_resize=True,
1254-
are_node_inputs_supported_fn=check_batch_norm_node,
12551241
)
12561242

12571243

backends/vulkan/vulkan_preprocess.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,10 @@ def preprocess( # noqa: C901
162162
program = apply_passes(
163163
program,
164164
[
165+
AddmmToLinearTransform(),
165166
FuseBatchNormPass(program),
166167
FusePatternsPass(),
167168
FuseClampPass(),
168-
AddmmToLinearTransform(),
169169
RemoveRedundantOpsTransform(),
170170
FuseQuantizedOpsTransform(),
171171
FoldQDQPass(),

0 commit comments

Comments
 (0)