@@ -2826,6 +2826,88 @@ def check_nn_module(node):
2826
2826
if node .name == "mul" :
2827
2827
check_nn_module (node )
2828
2828
2829
+ def test_quantize_in_place_ops (self ):
2830
+ class TestQuantizer (Quantizer ):
2831
+ example_inputs = None
2832
+
2833
+ def set_example_inputs (self , example_inputs ):
2834
+ self .example_inputs = example_inputs
2835
+
2836
+ def transform_for_annotation (
2837
+ self , model : torch .fx .GraphModule
2838
+ ) -> torch .fx .GraphModule :
2839
+ # Make a copy of the graph to ensure that we are using the
2840
+ # return value of this function.
2841
+ ep = torch .export .export (model , self .example_inputs )
2842
+ ep = ep .run_decompositions ({})
2843
+ return ep .module ()
2844
+
2845
+ def annotate (self , model : torch .fx .GraphModule ) -> torch .fx .GraphModule :
2846
+ act_qspec = QuantizationSpec (
2847
+ dtype = torch .uint8 ,
2848
+ quant_min = 0 ,
2849
+ quant_max = 255 ,
2850
+ qscheme = torch .per_tensor_affine ,
2851
+ is_dynamic = False ,
2852
+ observer_or_fake_quant_ctr = observer .default_observer
2853
+ )
2854
+ for node in model .graph .nodes :
2855
+ if (
2856
+ node .op == "call_function"
2857
+ and node .target == torch .ops .aten .add .Tensor
2858
+ ):
2859
+ input_act0 = node .args [0 ]
2860
+ assert isinstance (input_act0 , torch .fx .Node )
2861
+ input_act1 = node .args [1 ]
2862
+ assert isinstance (input_act1 , torch .fx .Node )
2863
+ print ("input_act1 is a node" )
2864
+ node .meta ["quantization_annotation" ] = QuantizationAnnotation (
2865
+ input_qspec_map = {
2866
+ input_act0 : act_qspec ,
2867
+ input_act1 : act_qspec ,
2868
+ },
2869
+ output_qspec = act_qspec ,
2870
+ _annotated = True ,
2871
+ )
2872
+
2873
+ def validate (self , model : torch .fx .GraphModule ) -> None :
2874
+ pass
2875
+
2876
+ class M (torch .nn .Module ):
2877
+ def __init__ (self ):
2878
+ super ().__init__ ()
2879
+ self .register_buffer ("buf" , torch .randn (1 , 2 , 3 , 3 ))
2880
+
2881
+ def forward (self , x ):
2882
+ self .buf .add_ (x )
2883
+ return self .buf
2884
+
2885
+ def has_inplace_ops (graph_module : torch .fx .GraphModule ) -> bool :
2886
+ return len ([
2887
+ n for n in graph_module .graph .nodes if n .op == "call_function" and n .name .endswith ("_" ) and n .name != "copy_"
2888
+ ]) > 0
2889
+
2890
+ m = M ().eval ()
2891
+ quantizer = TestQuantizer ()
2892
+ example_inputs = (torch .randn (1 , 2 , 3 , 3 ),)
2893
+ quantizer .set_example_inputs (example_inputs )
2894
+ m = export_for_training (m , example_inputs , strict = True ).module ()
2895
+ # Check that the model has in-place ops
2896
+ self .assertTrue (has_inplace_ops (m ))
2897
+ m = prepare_pt2e (m , quantizer )
2898
+ # Check that the model no longer has in-place ops because the graph is funtionalized during annotate_to_tranform
2899
+ self .assertFalse (has_inplace_ops (m ))
2900
+ m (* example_inputs )
2901
+ m = convert_pt2e (m , fold_quantize = True )
2902
+ for node in m .graph .nodes :
2903
+ if node .name == "quantize_per_tensor_default" :
2904
+ # Ensure the quant node is not fused with the mutable buffer
2905
+ self .assertTrue (node .op == "call_function" )
2906
+
2907
+ # Verify the quantized model works
2908
+ result = m (* example_inputs )
2909
+ self .assertIsNotNone (result )
2910
+
2829
2911
2830
2912
@skipIfNoQNNPACK
2831
2913
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_7 , "Requires torch 2.7+" )
0 commit comments