Skip to content

Commit 8bac797

Browse files
author
David Eigen
committed
move ensure_urls_downloaded into model class with enable flag
1 parent 00ebae0 commit 8bac797

File tree

4 files changed

+30
-21
lines changed

4 files changed

+30
-21
lines changed

clarifai/runners/models/base_typed_model.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
from clarifai_grpc.grpc.api.service_pb2 import PostModelOutputsRequest
77
from google.protobuf import json_format
88

9+
from clarifai.runners.utils.stream_utils import readahead
10+
from clarifai.runners.utils.url_fetcher import ensure_urls_downloaded
11+
912
from ..utils.data_handler import InputDataHandler, OutputDataHandler
1013
from .model_class import ModelClass
1114

@@ -46,13 +49,17 @@ def convert_output_to_proto(self, outputs: list):
4649

4750
def predict_wrapper(
4851
self, request: service_pb2.PostModelOutputsRequest) -> service_pb2.MultiOutputResponse:
52+
if self.download_request_urls:
53+
ensure_urls_downloaded(request)
4954
list_dict_input, inference_params = self.parse_input_request(request)
5055
outputs = self.predict(list_dict_input, inference_parameters=inference_params)
5156
return self.convert_output_to_proto(outputs)
5257

5358
def generate_wrapper(
5459
self, request: PostModelOutputsRequest) -> Iterator[service_pb2.MultiOutputResponse]:
5560
list_dict_input, inference_params = self.parse_input_request(request)
61+
if self.download_request_urls:
62+
ensure_urls_downloaded(request)
5663
outputs = self.generate(list_dict_input, inference_parameters=inference_params)
5764
for output in outputs:
5865
yield self.convert_output_to_proto(output)
@@ -69,6 +76,8 @@ def stream_wrapper(self, request: Iterator[PostModelOutputsRequest]
6976
first_request = next(request)
7077
_, inference_params = self.parse_input_request(first_request)
7178
request_iterator = itertools.chain([first_request], request)
79+
if self.download_request_urls:
80+
request_iterator = readahead(map(ensure_urls_downloaded, request_iterator))
7281
outputs = self.stream(self._preprocess_stream(request_iterator), inference_params)
7382
for output in outputs:
7483
yield self.convert_output_to_proto(output)

clarifai/runners/models/model_class.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,41 @@
33

44
from clarifai_grpc.grpc.api import service_pb2
55

6+
from clarifai.runners.utils.stream_utils import readahead
7+
from clarifai.runners.utils.url_fetcher import ensure_urls_downloaded
8+
69

710
class ModelClass(ABC):
811

12+
download_request_urls = True
13+
914
def predict_wrapper(
1015
self, request: service_pb2.PostModelOutputsRequest) -> service_pb2.MultiOutputResponse:
1116
"""This method is used for input/output proto data conversion"""
17+
# Download any urls that are not already bytes.
18+
if self.download_request_urls:
19+
ensure_urls_downloaded(request)
20+
1221
return self.predict(request)
1322

1423
def generate_wrapper(self, request: service_pb2.PostModelOutputsRequest
1524
) -> Iterator[service_pb2.MultiOutputResponse]:
1625
"""This method is used for input/output proto data conversion and yield outcome"""
26+
# Download any urls that are not already bytes.
27+
if self.download_request_urls:
28+
ensure_urls_downloaded(request)
29+
1730
return self.generate(request)
1831

19-
def stream_wrapper(self, request: service_pb2.PostModelOutputsRequest
32+
def stream_wrapper(self, request_stream: Iterator[service_pb2.PostModelOutputsRequest]
2033
) -> Iterator[service_pb2.MultiOutputResponse]:
2134
"""This method is used for input/output proto data conversion and yield outcome"""
22-
return self.stream(request)
35+
36+
# Download any urls that are not already bytes.
37+
if self.download_request_urls:
38+
request_stream = readahead(map(ensure_urls_downloaded, request_stream))
39+
40+
return self.stream(request_stream)
2341

2442
@abstractmethod
2543
def load_model(self):

clarifai/runners/models/model_runner.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from clarifai_protocol import BaseRunner
77
from clarifai_protocol.utils.health import HealthProbeRequestHandler
8-
from ..utils.url_fetcher import ensure_urls_downloaded
98

109
from .model_class import ModelClass
1110

@@ -79,7 +78,6 @@ def runner_item_predict(self,
7978
if not runner_item.HasField('post_model_outputs_request'):
8079
raise Exception("Unexpected work item type: {}".format(runner_item))
8180
request = runner_item.post_model_outputs_request
82-
ensure_urls_downloaded(request)
8381

8482
resp = self.model.predict_wrapper(request)
8583
successes = [o.status.code == status_code_pb2.SUCCESS for o in resp.outputs]
@@ -109,7 +107,6 @@ def runner_item_generate(
109107
if not runner_item.HasField('post_model_outputs_request'):
110108
raise Exception("Unexpected work item type: {}".format(runner_item))
111109
request = runner_item.post_model_outputs_request
112-
ensure_urls_downloaded(request)
113110

114111
for resp in self.model.generate_wrapper(request):
115112
successes = []
@@ -169,5 +166,4 @@ def pmo_iterator(runner_item_iterator):
169166
for runner_item in runner_item_iterator:
170167
if not runner_item.HasField('post_model_outputs_request'):
171168
raise Exception("Unexpected work item type: {}".format(runner_item))
172-
ensure_urls_downloaded(runner_item.post_model_outputs_request)
173169
yield runner_item.post_model_outputs_request

clarifai/runners/models/model_servicer.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,6 @@
33
from clarifai_grpc.grpc.api import service_pb2, service_pb2_grpc
44
from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
55

6-
from ..utils.stream_utils import readahead
7-
from ..utils.url_fetcher import ensure_urls_downloaded
8-
96

107
class ModelServicer(service_pb2_grpc.V2Servicer):
118
"""
@@ -27,9 +24,6 @@ def PostModelOutputs(self, request: service_pb2.PostModelOutputsRequest,
2724
returns an output.
2825
"""
2926

30-
# Download any urls that are not already bytes.
31-
ensure_urls_downloaded(request)
32-
3327
try:
3428
return self.model.predict_wrapper(request)
3529
except Exception as e:
@@ -46,9 +40,6 @@ def GenerateModelOutputs(self, request: service_pb2.PostModelOutputsRequest,
4640
This is the method that will be called when the servicer is run. It takes in an input and
4741
returns an output.
4842
"""
49-
# Download any urls that are not already bytes.
50-
ensure_urls_downloaded(request)
51-
5243
try:
5344
return self.model.generate_wrapper(request)
5445
except Exception as e:
@@ -66,13 +57,8 @@ def StreamModelOutputs(self,
6657
This is the method that will be called when the servicer is run. It takes in an input and
6758
returns an output.
6859
"""
69-
70-
# Download any urls that are not already bytes.
71-
def _download_urls_stream(requests):
72-
return readahead(map(ensure_urls_downloaded, requests))
73-
7460
try:
75-
return self.model_class.stream(_download_urls_stream(request))
61+
return self.model_class.stream(request)
7662
except Exception as e:
7763
yield service_pb2.MultiOutputResponse(status=status_pb2.Status(
7864
code=status_code_pb2.MODEL_PREDICTION_FAILED,

0 commit comments

Comments
 (0)