@@ -2826,6 +2826,78 @@ def check_nn_module(node):
2826
2826
if node .name == "mul" :
2827
2827
check_nn_module (node )
2828
2828
2829
def test_quantize_in_place_ops(self):
    """Run the PT2E prepare -> calibrate -> convert flow with a custom quantizer.

    The quantizer's ``transform_for_annotation`` re-exports the model and
    returns the new module, and ``annotate`` tags every ``aten.add.Tensor``
    call with a uint8 per-tensor-affine spec. The test then checks that the
    converted model can be called and produces a result.

    NOTE(review): the test name mentions in-place ops, but the module under
    test computes ``x + 3`` (out-of-place) -- confirm this is the intended
    coverage.
    """

    class TestQuantizer(Quantizer):
        # Inputs used to re-export the model in transform_for_annotation;
        # populated through set_example_inputs() before preparing.
        example_inputs = None

        def set_example_inputs(self, example_inputs):
            """Store the inputs used when re-exporting the model."""
            self.example_inputs = example_inputs

        def transform_for_annotation(
            self, model: torch.fx.GraphModule
        ) -> torch.fx.GraphModule:
            """Re-export ``model`` and return the freshly built module."""
            # Make a copy of the graph to ensure that we are using the
            # return value of this function.
            ep = torch.export.export(model, self.example_inputs)
            ep = ep.run_decompositions({})
            return ep.module()

        def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
            """Annotate each ``aten.add.Tensor`` node and return ``model``."""
            act_qspec = QuantizationSpec(
                dtype=torch.uint8,
                quant_min=0,
                quant_max=255,
                qscheme=torch.per_tensor_affine,
                is_dynamic=False,
                observer_or_fake_quant_ctr=observer.default_observer,
            )
            for node in model.graph.nodes:
                if (
                    node.op == "call_function"
                    and node.target == torch.ops.aten.add.Tensor
                ):
                    input_act0 = node.args[0]
                    assert isinstance(input_act0, torch.fx.Node)
                    input_act1 = node.args[1]
                    # The first operand is always annotated; the second one
                    # only when it is a graph node (it may be a constant,
                    # as in ``x + 3``).
                    input_qspec_map = {input_act0: act_qspec}
                    if isinstance(input_act1, torch.fx.Node):
                        input_qspec_map[input_act1] = act_qspec
                    node.meta["quantization_annotation"] = QuantizationAnnotation(
                        input_qspec_map=input_qspec_map,
                        output_qspec=act_qspec,
                        _annotated=True,
                    )
            # Honour the declared return type instead of implicitly
            # returning None.
            return model

        def validate(self, model: torch.fx.GraphModule) -> None:
            """No validation is performed for this test quantizer."""
            pass

    class M(torch.nn.Module):
        def forward(self, x):
            return x + 3

    m = M().eval()
    quantizer = TestQuantizer()
    example_inputs = (torch.randn(1, 2, 3, 3),)
    quantizer.set_example_inputs(example_inputs)
    m = export_for_training(m, example_inputs, strict=True).module()
    m = prepare_pt2e(m, quantizer)
    # Calibration pass: run the prepared model once on the example inputs.
    m(*example_inputs)
    m = convert_pt2e(m)

    # Verify the quantized model works
    result = m(*example_inputs)
    self.assertIsNotNone(result)
2900
+
2829
2901
2830
2902
@skipIfNoQNNPACK
2831
2903
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_7 , "Requires torch 2.7+" )
0 commit comments