
Commit 9a30c8d

bstrzele authored and dtrawins committed
Added streaming endpoint support for streaming client (#2139)
Co-authored-by: Dariusz Trawinski <Dariusz.Trawinski@intel.com>
1 parent e8125c5 · commit 9a30c8d

File tree: 3 files changed, +62 −47 lines changed


demos/common/stream_client/stream_client.py

Lines changed: 60 additions & 27 deletions
```diff
@@ -70,16 +70,26 @@ def write(self, frame):
     def release(self):
         self.cv_sink.release()
 
+class ImshowOutputBackend(OutputBackend):
+    def init(self, sink, fps, width, height):
+        ...
+    def write(self, frame):
+        cv2.imshow("OVMS StreamClient", frame)
+        cv2.waitKey(1)
+    def release(self):
+        cv2.destroyAllWindows()
+
 class StreamClient:
     class OutputBackends():
         ffmpeg = FfmpegOutputBackend()
         cv2 = CvOutputBackend()
+        imshow = ImshowOutputBackend()
         none = OutputBackend()
     class Datatypes():
         fp32 = FP32()
         uint8 = UINT8()
 
-    def __init__(self, *, preprocess_callback = None, postprocess_callback, source, sink : str, ffmpeg_output_width = None, ffmpeg_output_height = None, output_backend :OutputBackend = OutputBackends.ffmpeg, verbose : bool = False, exact : bool = True, benchmark : bool = False):
+    def __init__(self, *, preprocess_callback = None, postprocess_callback, source, sink: str, ffmpeg_output_width = None, ffmpeg_output_height = None, output_backend: OutputBackend = OutputBackends.ffmpeg, verbose: bool = False, exact: bool = True, benchmark: bool = False, max_inflight_packets: int = 4):
         """
         Parameters
         ----------
```
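The new `ImshowOutputBackend` previews processed frames in a local OpenCV window instead of piping them to ffmpeg or a `cv2` video writer. A minimal usage sketch, assuming a no-op postprocess callback and a webcam source (both illustrative):

```python
# Hypothetical sketch: preview the stream in an OpenCV window.
# The imshow backend ignores the sink, so an empty string is passed.
from stream_client import StreamClient

client = StreamClient(
    postprocess_callback=lambda frame, result: frame,  # illustrative no-op
    source="0",   # a single-digit string selects a local camera
    sink="",      # unused by the imshow backend
    output_backend=StreamClient.OutputBackends.imshow,
)
```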
```diff
@@ -114,6 +124,7 @@ def __init__(self, *, preprocess_callback = None, postprocess_callback, source,
         self.benchmark = benchmark
 
         self.pq = queue.PriorityQueue()
+        self.req_q = queue.Queue(max_inflight_packets)
 
     def grab_frame(self):
         success, frame = self.cap.read()
```
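The bounded `req_q` introduced here is a backpressure valve: the send loop blocks on `put()` once `max_inflight_packets` requests are outstanding, and each response callback releases a slot via `get()`. A self-contained sketch of the same pattern, with simulated inference latency:

```python
import queue
import threading
import time

inflight = queue.Queue(maxsize=4)  # at most 4 requests in flight

def fake_inference(i):
    time.sleep(0.05)   # simulated server latency
    inflight.get()     # completion frees one slot

for i in range(20):
    inflight.put(i)    # blocks while 4 requests are already pending
    threading.Thread(target=fake_inference, args=(i,)).start()
```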
```diff
@@ -132,18 +143,24 @@ def grab_frame(self):
     dropped_frames = 0
     frames = 0
     def callback(self, frame, i, timestamp, result, error):
+        if error is not None:
+            if self.benchmark:
+                self.dropped_frames += 1
+            if self.verbose:
+                print(error)
+        if i == None:
+            i = result.get_response().parameters["OVMS_MP_TIMESTAMP"].int64_param
+        if timestamp == None:
+            timestamp = result.get_response().parameters["OVMS_MP_TIMESTAMP"].int64_param
         frame = self.postprocess_callback(frame, result)
         self.pq.put((i, frame, timestamp))
-        if error is not None and self.verbose == True:
-            print(error)
+        self.req_q.get()
 
     def display(self):
         i = 0
         while True:
-            if self.pq.empty():
-                continue
             entry = self.pq.get()
-            if (entry[0] == i and self.exact) or (entry[0] > i and self.exact is not True):
+            if (entry[0] == i and self.exact and self.streaming_api is not True) or (entry[0] > i and (self.exact is not True or self.streaming_api is True)):
                 if isinstance(entry[1], str) and entry[1] == "EOS":
                     break
                 frame = entry[1]
```
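Responses may complete out of order, so `callback` pushes each frame into the `PriorityQueue` keyed by its sequence number (or, on the streaming endpoint, by the `OVMS_MP_TIMESTAMP` parameter echoed back by the server), and the display thread drains them in sorted order. A toy illustration of that reordering:

```python
import queue

pq = queue.PriorityQueue()
for i, frame in [(2, "frame-2"), (0, "frame-0"), (1, "frame-1")]:
    pq.put((i, frame))   # arrival order: 2, 0, 1

while not pq.empty():
    print(pq.get())      # drained in sorted order: 0, 1, 2
```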
```diff
@@ -161,8 +178,10 @@ def display(self):
             elif self.exact:
                 self.pq.put(entry)
 
+    def get_timestamp(self) -> int:
+        return int(cv2.getTickCount() / cv2.getTickFrequency() * 1e6)
 
-    def start(self, *, ovms_address : str, input_name : str, model_name : str, datatype : Datatype = FP32(), batch = True, limit_stream_duration : int = 0, limit_frames : int = 0):
+    def start(self, *, ovms_address : str, input_name : str, model_name : str, datatype : Datatype = FP32(), batch = True, limit_stream_duration : int = 0, limit_frames : int = 0, streaming_api: bool = False):
         """
         Parameters
         ----------
```
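`get_timestamp` converts OpenCV's monotonic tick counter into an integer microsecond value, the unit used for the `OVMS_MP_TIMESTAMP` request parameter. The same computation in isolation:

```python
import cv2

# ticks / (ticks per second) gives seconds; * 1e6 converts to microseconds
timestamp_us = int(cv2.getTickCount() / cv2.getTickFrequency() * 1e6)
print(timestamp_us)
```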
```diff
@@ -180,12 +199,15 @@ def start(self, *, ovms_address : str, input_name : str, model_name : str, datat
             Limits how long client could run
         limit_frames : int
             Limits how many frames should be processed
+        streaming_api : bool
+            Use experimental streaming endpoint
         """
 
-        self.cap = cv2.VideoCapture(self.source, cv2.CAP_ANY)
+        self.cap = cv2.VideoCapture(int(self.source) if len(self.source) == 1 and self.source[0].isdigit() else self.source, cv2.CAP_ANY)
         self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 0)
         fps = self.cap.get(cv2.CAP_PROP_FPS)
         triton_client = grpcclient.InferenceServerClient(url=ovms_address, verbose=False)
+        self.streaming_api = streaming_api
 
         display_th = threading.Thread(target=self.display)
         display_th.start()
```
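The reworked `VideoCapture` call lets a single-digit `source` string select a local camera by index, while anything longer is still treated as a file path or stream URL. The same dispatch extracted into a helper (the function name is illustrative):

```python
import cv2

def open_capture(source: str) -> cv2.VideoCapture:
    # "0", "1", ... select a camera device; everything else is a
    # file path or an RTSP/HTTP stream URL.
    device = int(source) if len(source) == 1 and source.isdigit() else source
    return cv2.VideoCapture(device, cv2.CAP_ANY)
```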
```diff
@@ -199,26 +221,37 @@ def start(self, *, ovms_address : str, input_name : str, model_name : str, datat
         if self.height is None:
             self.height = np_test_frame.shape[0]
         self.output_backend.init(self.sink, fps, self.width, self.height)
+
+        if streaming_api:
+            triton_client.start_stream(partial(self.callback, None, None, None))
 
-        i = 0
+        frame_number = 0
         total_time_start = time.time()
-        while not self.force_exit:
-            timestamp = time.time()
-            frame = self.grab_frame()
-            if frame is not None:
-                np_frame = np.array([frame], dtype=datatype.dtype()) if batch else np.array(frame, dtype=datatype.dtype())
-                inputs=[grpcclient.InferInput(input_name, np_frame.shape, datatype.string())]
-                inputs[0].set_data_from_numpy(np_frame)
-                triton_client.async_infer(
-                    model_name=model_name,
-                    callback=partial(self.callback, frame, i, timestamp),
-                    inputs=inputs)
-            i += 1
-            if limit_stream_duration > 0 and time.time() - total_time_start > limit_stream_duration:
-                break
-            if limit_frames > 0 and i > limit_frames:
-                break
-        self.pq.put((i, "EOS"))
+        try:
+            while not self.force_exit:
+                self.req_q.put(frame_number)
+                timestamp = time.time()
+                frame = self.grab_frame()
+                if frame is not None:
+                    np_frame = np.array([frame], dtype=datatype.dtype()) if batch else np.array(frame, dtype=datatype.dtype())
+                    inputs=[grpcclient.InferInput(input_name, np_frame.shape, datatype.string())]
+                    inputs[0].set_data_from_numpy(np_frame)
+                    if streaming_api:
+                        triton_client.async_stream_infer(model_name=model_name, inputs=inputs, parameters={"OVMS_MP_TIMESTAMP":self.get_timestamp()}, request_id=str(frame_number))
+                    else:
+                        triton_client.async_infer(
+                            model_name=model_name,
+                            callback=partial(self.callback, frame, frame_number, timestamp),
+                            inputs=inputs)
+                frame_number += 1
+                if limit_stream_duration > 0 and time.time() - total_time_start > limit_stream_duration:
+                    break
+                if limit_frames > 0 and frame_number > limit_frames:
+                    break
+        finally:
+            self.pq.put((frame_number, "EOS"))
+            if streaming_api:
+                triton_client.stop_stream()
         sent_all_frames = time.time() - total_time_start
 
 
```
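The send loop now has two paths: the unary `async_infer`, where each request carries its own callback, and the new streaming path, where a single callback registered with `start_stream` receives every response and requests are correlated by `request_id`. A condensed sketch of the streaming pattern with the tritonclient gRPC API (the address, model, and input names are placeholders):

```python
import numpy as np
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient(url="localhost:9000")  # placeholder

def on_response(result, error):
    # One callback handles every response on the stream.
    if error is None:
        print("response for request", result.get_response().id)

client.start_stream(on_response)
for frame_number in range(3):
    data = np.zeros((1, 224, 224, 3), dtype=np.float32)      # dummy frame
    inp = grpcclient.InferInput("input", data.shape, "FP32")  # placeholder name
    inp.set_data_from_numpy(data)
    client.async_stream_infer(model_name="model", inputs=[inp],
                              request_id=str(frame_number))
client.stop_stream()  # flushes pending responses and closes the stream
```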
```diff
@@ -227,4 +260,4 @@ def start(self, *, ovms_address : str, input_name : str, model_name : str, datat
         self.output_backend.release()
         total_time = time.time() - total_time_start
         if self.benchmark:
-            print(f"{{\"inference_time\": {sum(self.inference_time)/i}, \"dropped_frames\": {self.dropped_frames}, \"frames\": {self.frames}, \"fps\": {self.frames/total_time}, \"total_time\": {total_time}, \"sent_all_frames\": {sent_all_frames}}}")
+            print(f"{{\"inference_time\": {sum(self.inference_time)/frame_number}, \"dropped_frames\": {self.dropped_frames}, \"frames\": {self.frames}, \"fps\": {self.frames/total_time}, \"total_time\": {total_time}, \"sent_all_frames\": {sent_all_frames}}}")
```

demos/mediapipe/holistic_tracking/README.md

Lines changed: 1 addition & 19 deletions
````diff
@@ -4,8 +4,7 @@ This guide shows how to implement [MediaPipe](../../../docs/mediapipe.md) graph
 
 Example usage of graph that accepts Mediapipe::ImageFrame as a input:
 
-The demo is based on the [upstream Mediapipe holistic demo](https://github.com/google/mediapipe/blob/master/docs/solutions/holistic.md)
-and [Mediapipe Iris demo](https://github.com/google/mediapipe/blob/master/docs/solutions/iris.md)
+The demo is based on the [upstream Mediapipe holistic demo](https://github.com/google/mediapipe/blob/master/docs/solutions/holistic.md).
 
 ## Prepare the server deployment
 
@@ -82,23 +81,6 @@ Results saved to :image_0.jpg
 ## Output image
 ![output](output_image.jpg)
 
-## Run client application for iris tracking
-In a similar way can be executed the iris image analysis:
-
-```bash
-python mediapipe_holistic_tracking.py --graph_name irisTracking --images_list input_images.txt --grpc_port 9000
-Running demo application.
-Start processing:
-Graph name: irisTracking
-(640, 960, 3)
-Iteration 0; Processing time: 77.03 ms; speed 12.98 fps
-Results saved to :image_0.jpg
-```
-
-## Output image
-![output](output_image1.jpg)
-
-
 
 ## RTSP Client
 Mediapipe graph can be used for remote analysis of individual images but the client can use it for a complete video stream processing.
````

demos/mediapipe/holistic_tracking/rtsp_client.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -51,5 +51,5 @@ def postprocess(frame, result):
 exact = True
 
 client = StreamClient(postprocess_callback = postprocess, preprocess_callback=preprocess, output_backend=backend, source=args.input_stream, sink=args.output_stream, exact=exact, benchmark=args.benchmark, verbose=args.verbose)
-client.start(ovms_address=args.grpc_address, input_name=args.input_name, model_name=args.model_name, datatype = StreamClient.Datatypes.uint8, batch = False, limit_stream_duration = args.limit_stream_duration, limit_frames = args.limit_frames)
+client.start(ovms_address=args.grpc_address, input_name=args.input_name, model_name=args.model_name, datatype = StreamClient.Datatypes.uint8, batch = False, limit_stream_duration = args.limit_stream_duration, limit_frames = args.limit_frames, streaming_api=True)
 
```
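With `streaming_api=True`, the RTSP demo drives the MediaPipe graph through the streaming endpoint end to end. A trimmed, hypothetical invocation mirroring `rtsp_client.py` (the address, input name, and model name are example values):

```python
from stream_client import StreamClient

client = StreamClient(
    postprocess_callback=lambda frame, result: frame,  # illustrative no-op
    source="0",                                        # local camera
    sink="",                                           # unused with imshow
    output_backend=StreamClient.OutputBackends.imshow,
)
client.start(ovms_address="localhost:9000",            # example address
             input_name="input", model_name="holisticTracking",
             datatype=StreamClient.Datatypes.uint8,
             batch=False, streaming_api=True)
```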