@@ -181,7 +181,7 @@ def filter_box(self, result, threshold):
181181 filter_res = {'boxes' : boxes , 'boxes_num' : filter_num }
182182 return filter_res
183183
184- def predict (self , repeats = 1 ):
184+ def predict (self , repeats = 1 , run_benchmark = False ):
185185 '''
186186 Args:
187187 repeats (int): repeats number for prediction
@@ -193,6 +193,15 @@ def predict(self, repeats=1):
193193 '''
194194 # model prediction
195195 np_boxes_num , np_boxes , np_masks = np .array ([0 ]), None , None
196+
197+ if run_benchmark :
198+ for i in range (repeats ):
199+ self .predictor .run ()
200+ paddle .device .cuda .synchronize ()
201+ result = dict (
202+ boxes = np_boxes , masks = np_masks , boxes_num = np_boxes_num )
203+ return result
204+
196205 for i in range (repeats ):
197206 self .predictor .run ()
198207 output_names = self .predictor .get_output_names ()
@@ -272,9 +281,9 @@ def predict_image_slice(self,
272281 self .det_times .preprocess_time_s .end ()
273282
274283 # model prediction
275- result = self .predict (repeats = 50 ) # warmup
284+ result = self .predict (repeats = 50 , run_benchmark = True ) # warmup
276285 self .det_times .inference_time_s .start ()
277- result = self .predict (repeats = repeats )
286+ result = self .predict (repeats = repeats , run_benchmark = True )
278287 self .det_times .inference_time_s .end (repeats = repeats )
279288
280289 # postprocess
@@ -370,9 +379,9 @@ def predict_image(self,
370379 self .det_times .preprocess_time_s .end ()
371380
372381 # model prediction
373- result = self .predict (repeats = 50 ) # warmup
382+ result = self .predict (repeats = 50 , run_benchmark = True ) # warmup
374383 self .det_times .inference_time_s .start ()
375- result = self .predict (repeats = repeats )
384+ result = self .predict (repeats = repeats , run_benchmark = True )
376385 self .det_times .inference_time_s .end (repeats = repeats )
377386
378387 # postprocess
@@ -568,7 +577,7 @@ def __init__(
568577 output_dir = output_dir ,
569578 threshold = threshold , )
570579
571- def predict (self , repeats = 1 ):
580+ def predict (self , repeats = 1 , run_benchmark = False ):
572581 '''
573582 Args:
574583 repeats (int): repeat number for prediction
@@ -577,7 +586,20 @@ def predict(self, repeats=1):
577586 'cate_label': label of segm, shape:[N]
578587 'cate_score': confidence score of segm, shape:[N]
579588 '''
580- np_label , np_score , np_segms = None , None , None
589+ np_segms , np_label , np_score , np_boxes_num = None , None , None , np .array (
590+ [0 ])
591+
592+ if run_benchmark :
593+ for i in range (repeats ):
594+ self .predictor .run ()
595+ paddle .device .cuda .synchronize ()
596+ result = dict (
597+ segm = np_segms ,
598+ label = np_label ,
599+ score = np_score ,
600+ boxes_num = np_boxes_num )
601+ return result
602+
581603 for i in range (repeats ):
582604 self .predictor .run ()
583605 output_names = self .predictor .get_output_names ()
@@ -659,7 +681,7 @@ def postprocess(self, inputs, result):
659681 result = dict (boxes = np_boxes , boxes_num = np_boxes_num )
660682 return result
661683
662- def predict (self , repeats = 1 ):
684+ def predict (self , repeats = 1 , run_benchmark = False ):
663685 '''
664686 Args:
665687 repeats (int): repeat number for prediction
@@ -668,6 +690,14 @@ def predict(self, repeats=1):
668690 matix element:[class, score, x_min, y_min, x_max, y_max]
669691 '''
670692 np_score_list , np_boxes_list = [], []
693+
694+ if run_benchmark :
695+ for i in range (repeats ):
696+ self .predictor .run ()
697+ paddle .device .cuda .synchronize ()
698+ result = dict (boxes = np_score_list , boxes_num = np_boxes_list )
699+ return result
700+
671701 for i in range (repeats ):
672702 self .predictor .run ()
673703 np_score_list .clear ()
0 commit comments