File tree Expand file tree Collapse file tree 2 files changed +1
-15
lines changed
Expand file tree Collapse file tree 2 files changed +1
-15
lines changed Original file line number Diff line number Diff line change @@ -1233,25 +1233,11 @@ def register_embedding():
12331233
12341234@update_features (exir_ops .edge .aten ._native_batch_norm_legit_no_training .default )
12351235def register_native_batch_norm_legit_no_training ():
1236- def check_batch_norm_node (node : torch .fx .Node ) -> bool :
1237- x = node .args [0 ]
1238- if not isinstance (x , torch .fx .Node ):
1239- return False
1240- x_val = x .meta .get ("val" , None )
1241- if x_val is None :
1242- return False
1243- x_shape = x_val .size ()
1244- # Only support 4-D input tensors since this is a restriction enforced by the
1245- # operator implementation.
1246- # TODO(ssjia): Add shape agnostic support for batch norm
1247- return len (x_shape ) == 4
1248-
12491236 return OpFeatures (
12501237 inputs_storage = utils .CHANNELS_PACKED_TEXTURE ,
12511238 inputs_dtypes = utils .FP_T ,
12521239 supports_prepacking = True ,
12531240 supports_resize = True ,
1254- are_node_inputs_supported_fn = check_batch_norm_node ,
12551241 )
12561242
12571243
Original file line number Diff line number Diff line change @@ -162,10 +162,10 @@ def preprocess( # noqa: C901
162162 program = apply_passes (
163163 program ,
164164 [
165+ AddmmToLinearTransform (),
165166 FuseBatchNormPass (program ),
166167 FusePatternsPass (),
167168 FuseClampPass (),
168- AddmmToLinearTransform (),
169169 RemoveRedundantOpsTransform (),
170170 FuseQuantizedOpsTransform (),
171171 FoldQDQPass (),
You can’t perform that action at this time.
0 commit comments