
Commit dc8fc7f

SS-JIA authored and committed
[ET-VK][ez] Always partition batch norm as it will be fused
Pull Request resolved: #17508

The batch norm operator registration had a check_batch_norm_node guard that restricted partitioning to 4D input tensors only. Since batch norm is always fused with adjacent operations during graph compilation, this restriction is unnecessary and prevents valid models from being partitioned to the Vulkan backend.

Remove the guard so batch norm is always eligible for Vulkan partitioning regardless of input dimensionality.

ghstack-source-id: 342806074
@exported-using-ghexport

Differential Revision: [D93511630](https://our.internmc.facebook.com/intern/diff/D93511630/)
1 parent 4ca676e
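The commit message relies on the fact that batch norm is always fused with an adjacent operation during graph compilation. As background, the standard fold of batch norm statistics into a preceding convolution's weight and bias can be sketched as below. This is a hypothetical numpy helper illustrating the kind of rewrite a pass such as FuseBatchNormPass performs, not the ExecuTorch implementation:

```python
import numpy as np

def fold_batch_norm(conv_w, conv_b, gamma, beta, mean, var, eps=1e-5):
    """Fold batch norm statistics into the preceding conv's weight and bias.

    conv_w: (out_ch, in_ch, kh, kw); conv_b and all BN params: (out_ch,).
    Hypothetical helper -- a sketch of conv+BN fusion, not ExecuTorch code.
    """
    scale = gamma / np.sqrt(var + eps)
    folded_w = conv_w * scale[:, None, None, None]  # per-output-channel scale
    folded_b = (conv_b - mean) * scale + beta
    return folded_w, folded_b

# Sanity check on a 1x1 conv at a single spatial position:
# conv followed by batch norm must equal the folded conv alone.
rng = np.random.default_rng(0)
w = rng.standard_normal((3, 2, 1, 1))
b = rng.standard_normal(3)
gamma, beta = rng.standard_normal(3), rng.standard_normal(3)
mean, var = rng.standard_normal(3), np.abs(rng.standard_normal(3)) + 0.1

x = rng.standard_normal(2)  # 2 input channels
conv_out = w[:, :, 0, 0] @ x + b
bn_out = (conv_out - mean) / np.sqrt(var + 1e-5) * gamma + beta

fw, fb = fold_batch_norm(w, b, gamma, beta, mean, var)
fused_out = fw[:, :, 0, 0] @ x + fb
assert np.allclose(bn_out, fused_out)
```

Because the fold is exact, the fused graph needs no batch norm operator at inference time, which is why the operator's own input-rank restriction no longer needs to gate partitioning.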

File tree: 2 files changed (+1, -15 lines)


backends/vulkan/op_registry.py

Lines changed: 0 additions & 14 deletions

@@ -1233,25 +1233,11 @@ def register_embedding():


 @update_features(exir_ops.edge.aten._native_batch_norm_legit_no_training.default)
 def register_native_batch_norm_legit_no_training():
-    def check_batch_norm_node(node: torch.fx.Node) -> bool:
-        x = node.args[0]
-        if not isinstance(x, torch.fx.Node):
-            return False
-        x_val = x.meta.get("val", None)
-        if x_val is None:
-            return False
-        x_shape = x_val.size()
-        # Only support 4-D input tensors since this is a restriction enforced by the
-        # operator implementation.
-        # TODO(ssjia): Add shape agnostic support for batch norm
-        return len(x_shape) == 4
-
     return OpFeatures(
         inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
         inputs_dtypes=utils.FP_T,
         supports_prepacking=True,
         supports_resize=True,
-        are_node_inputs_supported_fn=check_batch_norm_node,
     )

backends/vulkan/vulkan_preprocess.py

Lines changed: 1 addition & 1 deletion
@@ -162,10 +162,10 @@ def preprocess( # noqa: C901
         program = apply_passes(
             program,
             [
+                AddmmToLinearTransform(),
                 FuseBatchNormPass(program),
                 FusePatternsPass(),
                 FuseClampPass(),
-                AddmmToLinearTransform(),
                 RemoveRedundantOpsTransform(),
                 FuseQuantizedOpsTransform(),
                 FoldQDQPass(),
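The diff above moves AddmmToLinearTransform ahead of the fusion passes, which matters because each pass consumes the output of the previous one: normalizing addmm into linear first lets later pattern-matching passes (such as batch norm fusion) recognize the node. A toy sketch of this ordering effect, with hypothetical stand-in passes over a list-of-ops IR rather than the actual ExecuTorch pass infrastructure:

```python
from typing import Callable, List

Program = List[str]  # toy IR: a flat list of op names
Pass = Callable[[Program], Program]

def apply_passes(program: Program, passes: List[Pass]) -> Program:
    # Order matters: each pass sees the output of the previous one.
    for p in passes:
        program = p(program)
    return program

def addmm_to_linear(prog: Program) -> Program:
    # Normalize addmm into linear so later patterns can match it.
    return ["linear" if op == "addmm" else op for op in prog]

def fuse_linear_bn(prog: Program) -> Program:
    # Fuse a linear immediately followed by batch_norm into one fused op.
    out: Program = []
    for op in prog:
        if op == "batch_norm" and out and out[-1] == "linear":
            out[-1] = "linear_bn"
        else:
            out.append(op)
    return out

prog = ["addmm", "batch_norm", "relu"]
# Normalization first: the batch norm fuses.
print(apply_passes(prog, [addmm_to_linear, fuse_linear_bn]))
# -> ['linear_bn', 'relu']
# Reversed order: the fusion pass sees addmm, misses the pattern.
print(apply_passes(prog, [fuse_linear_bn, addmm_to_linear]))
# -> ['linear', 'batch_norm', 'relu']
```

The same reasoning explains the reorder in vulkan_preprocess.py: running AddmmToLinearTransform before FuseBatchNormPass gives the fusion passes a normalized graph to match against.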
