Commit 3f7edee

Author: ssjia
[ET-VK][ez] Always partition batch norm as it will be fused
The batch norm operator registration had a check_batch_norm_node guard that restricted partitioning to 4D input tensors only. Since batch norm is always fused with adjacent operations during graph compilation, this restriction is unnecessary and prevents valid models from being partitioned to the Vulkan backend. Remove the guard so batch norm is always eligible for Vulkan partitioning, regardless of input dimensionality.

Differential Revision: [D93511630](https://our.internmc.facebook.com/intern/diff/D93511630/)
ghstack-source-id: 341964638
Pull Request resolved: #17508
1 parent: 15f604e
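For context, `are_node_inputs_supported_fn` is the `OpFeatures` hook the removed guard was registered through. Below is a minimal sketch of how such a per-node check gates partitioning eligibility; the stripped-down `OpFeatures` class and its `node_is_supported` method are illustrative assumptions, not the actual ExecuTorch partitioner code:

```python
from typing import Callable, Optional

import torch


class OpFeatures:
    """Illustrative stand-in for the real OpFeatures class; only the
    per-node input check relevant to this commit is modeled."""

    def __init__(
        self,
        are_node_inputs_supported_fn: Optional[
            Callable[[torch.fx.Node], bool]
        ] = None,
    ) -> None:
        self.are_node_inputs_supported_fn = are_node_inputs_supported_fn

    def node_is_supported(self, node: torch.fx.Node) -> bool:
        # With no per-node check registered (the state after this commit),
        # every call site of the operator is eligible for partitioning.
        if self.are_node_inputs_supported_fn is None:
            return True
        # Previously, check_batch_norm_node rejected any input whose rank
        # was not exactly 4, so those nodes fell back off the Vulkan backend.
        return self.are_node_inputs_supported_fn(node)
```

With the guard removed, the hook is `None`, so every `_native_batch_norm_legit_no_training` node passes this check and the fusion pass handles it during graph compilation, as the commit message describes.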

File tree: 1 file changed (+0, -14 lines)


backends/vulkan/op_registry.py

Lines changed: 0 additions & 14 deletions
```diff
@@ -1233,25 +1233,11 @@ def register_embedding():
 
 @update_features(exir_ops.edge.aten._native_batch_norm_legit_no_training.default)
 def register_native_batch_norm_legit_no_training():
-    def check_batch_norm_node(node: torch.fx.Node) -> bool:
-        x = node.args[0]
-        if not isinstance(x, torch.fx.Node):
-            return False
-        x_val = x.meta.get("val", None)
-        if x_val is None:
-            return False
-        x_shape = x_val.size()
-        # Only support 4-D input tensors since this is a restriction enforced by the
-        # operator implementation.
-        # TODO(ssjia): Add shape agnostic support for batch norm
-        return len(x_shape) == 4
-
     return OpFeatures(
         inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
         inputs_dtypes=utils.FP_T,
         supports_prepacking=True,
         supports_resize=True,
-        are_node_inputs_supported_fn=check_batch_norm_node,
     )
 
 
```
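The net effect on `backends/vulkan/op_registry.py`: the registration now consists only of the diff's unchanged lines, reproduced below for readability (`update_features`, `exir_ops`, `utils`, and `OpFeatures` come from the surrounding module, as the diff context shows):

```python
@update_features(exir_ops.edge.aten._native_batch_norm_legit_no_training.default)
def register_native_batch_norm_legit_no_training():
    return OpFeatures(
        inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
        inputs_dtypes=utils.FP_T,
        supports_prepacking=True,
        supports_resize=True,
    )
```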