@@ -57,9 +57,9 @@ def predict_wrapper(
57
57
58
58
def generate_wrapper(
    self, request: PostModelOutputsRequest) -> Iterator[service_pb2.MultiOutputResponse]:
  """Serve one generate request: resolve URL inputs if enabled, parse the
  request, run ``generate``, and stream each output back as a proto.

  Args:
    request: the incoming PostModelOutputsRequest to serve.

  Yields:
    One MultiOutputResponse proto per generated output.
  """
  # Download any URL-referenced inputs before parsing, so parse_input_request
  # sees fully materialized data.
  if self.download_request_urls:
    ensure_urls_downloaded(request)
  parsed_inputs, params = self.parse_input_request(request)
  # Lazily convert each generated output; map keeps the stream lazy, matching
  # the generator contract of this wrapper.
  yield from map(self.convert_output_to_proto,
                 self.generate(parsed_inputs, inference_parameters=params))
@@ -71,13 +71,13 @@ def _preprocess_stream(
71
71
input_data , _ = self .parse_input_request (req )
72
72
yield input_data
73
73
74
def stream_wrapper(self, request_iterator: Iterator[PostModelOutputsRequest]
                  ) -> Iterator[service_pb2.MultiOutputResponse]:
  """Serve a streaming request: wrap the request stream with URL downloading
  (when enabled), pull inference parameters from the first request, and yield
  one output proto per streamed output.

  Args:
    request_iterator: iterator of incoming PostModelOutputsRequest messages.

  Yields:
    One MultiOutputResponse proto per output produced by ``stream``.
  """
  # Wrap the stream BEFORE consuming its head, so the first request also has
  # its URL inputs downloaded; readahead prefetches downloads eagerly.
  if self.download_request_urls:
    request_iterator = readahead(
        ensure_urls_downloaded(req) for req in request_iterator)
  # Inference parameters are taken from the first request only; chain it back
  # on so downstream processing still sees the complete stream.
  head = next(request_iterator)
  _, inference_params = self.parse_input_request(head)
  rebuilt_stream = itertools.chain([head], request_iterator)
  for result in self.stream(self._preprocess_stream(rebuilt_stream), inference_params):
    yield self.convert_output_to_proto(result)
0 commit comments