@@ -70,16 +70,26 @@ def write(self, frame):
70
70
def release(self):
    """Finalize output: close the underlying cv2 VideoWriter sink."""
    self.cv_sink.release()
73
class ImshowOutputBackend(OutputBackend):
    """Output backend that renders frames to a local window via cv2.imshow."""

    def init(self, sink, fps, width, height):
        # Nothing to set up: the window is created lazily by the first
        # imshow call, and sink/fps/width/height are irrelevant for display.
        ...

    def write(self, frame):
        # Show the frame, then pump the GUI event loop briefly so the
        # window actually repaints.
        cv2.imshow("OVMS StreamClient", frame)
        cv2.waitKey(1)

    def release(self):
        # Tear down every HighGUI window opened by this process.
        cv2.destroyAllWindows()
73
82
class StreamClient :
74
83
class OutputBackends():
    """Ready-made backend instances selectable via the `output_backend` argument."""
    ffmpeg = FfmpegOutputBackend()
    cv2 = CvOutputBackend()
    imshow = ImshowOutputBackend()
    none = OutputBackend()
78
88
class Datatypes():
    """Supported wire datatypes for inference inputs."""
    fp32 = FP32()
    uint8 = UINT8()
81
91
82
- def __init__ (self , * , preprocess_callback = None , postprocess_callback , source , sink : str , ffmpeg_output_width = None , ffmpeg_output_height = None , output_backend : OutputBackend = OutputBackends .ffmpeg , verbose : bool = False , exact : bool = True , benchmark : bool = False ):
92
+ def __init__ (self , * , preprocess_callback = None , postprocess_callback , source , sink : str , ffmpeg_output_width = None , ffmpeg_output_height = None , output_backend : OutputBackend = OutputBackends .ffmpeg , verbose : bool = False , exact : bool = True , benchmark : bool = False , max_inflight_packets : int = 4 ):
83
93
"""
84
94
Parameters
85
95
----------
@@ -114,6 +124,7 @@ def __init__(self, *, preprocess_callback = None, postprocess_callback, source,
114
124
self .benchmark = benchmark
115
125
116
126
self .pq = queue .PriorityQueue ()
127
+ self .req_q = queue .Queue (max_inflight_packets )
117
128
118
129
def grab_frame (self ):
119
130
success , frame = self .cap .read ()
@@ -132,18 +143,24 @@ def grab_frame(self):
132
143
dropped_frames = 0
133
144
frames = 0
134
145
def callback(self, frame, i, timestamp, result, error):
    """Completion callback for async / streaming inference requests.

    Postprocesses the frame, enqueues it for ordered display on `self.pq`,
    and frees one slot in the bounded in-flight request queue `self.req_q`.

    Parameters
    ----------
    frame : frame object bound at send time (None when using the streaming API)
    i : int | None
        Ordering key; None when using the streaming API, in which case it is
        recovered from the response parameters.
    timestamp : float | int | None
        Send-side timestamp; likewise recovered from the response if missing.
    result : inference result exposing get_response()
    error : error reported by the client, or None on success
    """
    if error is not None:
        if self.benchmark:
            self.dropped_frames += 1
        if self.verbose:
            print(error)
    # Streaming-API responses carry their ordering key in a response
    # parameter instead of the arguments bound at send time.
    # (fix: compare to None with `is`, not `==` — PEP 8 E711)
    if i is None:
        i = result.get_response().parameters["OVMS_MP_TIMESTAMP"].int64_param
    if timestamp is None:
        # NOTE(review): the same OVMS_MP_TIMESTAMP parameter is used for both
        # the ordering key and the timestamp — confirm this is intended.
        timestamp = result.get_response().parameters["OVMS_MP_TIMESTAMP"].int64_param
    frame = self.postprocess_callback(frame, result)
    self.pq.put((i, frame, timestamp))
    # Release one slot so the producer loop may send another request.
    self.req_q.get()
139
158
140
159
def display (self ):
141
160
i = 0
142
161
while True :
143
- if self .pq .empty ():
144
- continue
145
162
entry = self .pq .get ()
146
- if (entry [0 ] == i and self .exact ) or (entry [0 ] > i and self .exact is not True ):
163
+ if (entry [0 ] == i and self .exact and self . streaming_api is not True ) or (entry [0 ] > i and ( self .exact is not True or self . streaming_api is True ) ):
147
164
if isinstance (entry [1 ], str ) and entry [1 ] == "EOS" :
148
165
break
149
166
frame = entry [1 ]
@@ -161,8 +178,10 @@ def display(self):
161
178
elif self .exact :
162
179
self .pq .put (entry )
163
180
181
def get_timestamp(self) -> int:
    """Return the current time in microseconds from OpenCV's tick counter."""
    ticks = cv2.getTickCount()
    seconds = ticks / cv2.getTickFrequency()
    return int(seconds * 1e6)
164
183
165
- def start (self , * , ovms_address : str , input_name : str , model_name : str , datatype : Datatype = FP32 (), batch = True , limit_stream_duration : int = 0 , limit_frames : int = 0 ):
184
+ def start (self , * , ovms_address : str , input_name : str , model_name : str , datatype : Datatype = FP32 (), batch = True , limit_stream_duration : int = 0 , limit_frames : int = 0 , streaming_api : bool = False ):
166
185
"""
167
186
Parameters
168
187
----------
@@ -180,12 +199,15 @@ def start(self, *, ovms_address : str, input_name : str, model_name : str, datat
180
199
Limits how long client could run
181
200
limit_frames : int
182
201
Limits how many frames should be processed
202
+ streaming_api : bool
203
+ Use experimental streaming endpoint
183
204
"""
184
205
185
- self .cap = cv2 .VideoCapture (self .source , cv2 .CAP_ANY )
206
+ self .cap = cv2 .VideoCapture (int ( self . source ) if len ( self . source ) == 1 and self . source [ 0 ]. isdigit () else self .source , cv2 .CAP_ANY )
186
207
self .cap .set (cv2 .CAP_PROP_BUFFERSIZE , 0 )
187
208
fps = self .cap .get (cv2 .CAP_PROP_FPS )
188
209
triton_client = grpcclient .InferenceServerClient (url = ovms_address , verbose = False )
210
+ self .streaming_api = streaming_api
189
211
190
212
display_th = threading .Thread (target = self .display )
191
213
display_th .start ()
@@ -199,26 +221,37 @@ def start(self, *, ovms_address : str, input_name : str, model_name : str, datat
199
221
if self .height is None :
200
222
self .height = np_test_frame .shape [0 ]
201
223
self .output_backend .init (self .sink , fps , self .width , self .height )
224
+
225
+ if streaming_api :
226
+ triton_client .start_stream (partial (self .callback , None , None , None ))
202
227
203
- i = 0
228
+ frame_number = 0
204
229
total_time_start = time .time ()
205
- while not self .force_exit :
206
- timestamp = time .time ()
207
- frame = self .grab_frame ()
208
- if frame is not None :
209
- np_frame = np .array ([frame ], dtype = datatype .dtype ()) if batch else np .array (frame , dtype = datatype .dtype ())
210
- inputs = [grpcclient .InferInput (input_name , np_frame .shape , datatype .string ())]
211
- inputs [0 ].set_data_from_numpy (np_frame )
212
- triton_client .async_infer (
213
- model_name = model_name ,
214
- callback = partial (self .callback , frame , i , timestamp ),
215
- inputs = inputs )
216
- i += 1
217
- if limit_stream_duration > 0 and time .time () - total_time_start > limit_stream_duration :
218
- break
219
- if limit_frames > 0 and i > limit_frames :
220
- break
221
- self .pq .put ((i , "EOS" ))
230
+ try :
231
+ while not self .force_exit :
232
+ self .req_q .put (frame_number )
233
+ timestamp = time .time ()
234
+ frame = self .grab_frame ()
235
+ if frame is not None :
236
+ np_frame = np .array ([frame ], dtype = datatype .dtype ()) if batch else np .array (frame , dtype = datatype .dtype ())
237
+ inputs = [grpcclient .InferInput (input_name , np_frame .shape , datatype .string ())]
238
+ inputs [0 ].set_data_from_numpy (np_frame )
239
+ if streaming_api :
240
+ triton_client .async_stream_infer (model_name = model_name , inputs = inputs , parameters = {"OVMS_MP_TIMESTAMP" :self .get_timestamp ()}, request_id = str (frame_number ))
241
+ else :
242
+ triton_client .async_infer (
243
+ model_name = model_name ,
244
+ callback = partial (self .callback , frame , frame_number , timestamp ),
245
+ inputs = inputs )
246
+ frame_number += 1
247
+ if limit_stream_duration > 0 and time .time () - total_time_start > limit_stream_duration :
248
+ break
249
+ if limit_frames > 0 and frame_number > limit_frames :
250
+ break
251
+ finally :
252
+ self .pq .put ((frame_number , "EOS" ))
253
+ if streaming_api :
254
+ triton_client .stop_stream ()
222
255
sent_all_frames = time .time () - total_time_start
223
256
224
257
@@ -227,4 +260,4 @@ def start(self, *, ovms_address : str, input_name : str, model_name : str, datat
227
260
self .output_backend .release ()
228
261
total_time = time .time () - total_time_start
229
262
if self .benchmark :
230
- print (f"{{\" inference_time\" : { sum (self .inference_time )/ i } , \" dropped_frames\" : { self .dropped_frames } , \" frames\" : { self .frames } , \" fps\" : { self .frames / total_time } , \" total_time\" : { total_time } , \" sent_all_frames\" : { sent_all_frames } }}" )
263
+ print (f"{{\" inference_time\" : { sum (self .inference_time )/ frame_number } , \" dropped_frames\" : { self .dropped_frames } , \" frames\" : { self .frames } , \" fps\" : { self .frames / total_time } , \" total_time\" : { total_time } , \" sent_all_frames\" : { sent_all_frames } }}" )
0 commit comments