Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions clarifai/runners/models/model_runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
from typing import Iterator

from clarifai_grpc.grpc.api import service_pb2
Expand All @@ -6,6 +7,7 @@
from clarifai_protocol.utils.health import HealthProbeRequestHandler

from clarifai.client.auth.helper import ClarifaiAuthHelper
from clarifai.utils.logging import get_req_id_from_context, logger

from ..utils.url_fetcher import ensure_urls_downloaded
from .model_class import ModelClass
Expand Down Expand Up @@ -106,13 +108,21 @@ def runner_item_predict(
raise Exception("Unexpected work item type: {}".format(runner_item))
request = runner_item.post_model_outputs_request
ensure_urls_downloaded(request, auth_helper=self._auth_helper)
start_time = time.time()
req_id = get_req_id_from_context()
status_str = "UNKNOWN"
# Endpoint is always POST /v2/.../outputs for this runner
endpoint = "POST /v2/.../outputs "
Copy link
Preview

Copilot AI Jul 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The endpoint string contains trailing whitespace, apparently added by hand to align the log columns. Consider using an unpadded string and applying any alignment when the log line is formatted (for example with a width in the f-string format spec), or use a constant if fixed-width alignment is needed for display purposes.

Suggested change
endpoint = "POST /v2/.../outputs "
endpoint = "POST /v2/.../outputs"

Copilot uses AI. Check for mistakes.


resp = self.model.predict_wrapper(request)
# if we have any non-successful code already it's an error we can return.
if (
resp.status.code != status_code_pb2.SUCCESS
and resp.status.code != status_code_pb2.ZERO
):
status_str = f"{resp.status.code} ERROR"
duration_ms = (time.time() - start_time) * 1000
logger.info(f"{endpoint} | {status_str} | {duration_ms:.2f}ms | req_id={req_id}")
return service_pb2.RunnerItemOutput(multi_output_response=resp)
successes = []
for output in resp.outputs:
Expand All @@ -126,18 +136,23 @@ def runner_item_predict(
code=status_code_pb2.SUCCESS,
description="Success",
)
status_str = "200 OK"
elif any(successes):
status = status_pb2.Status(
code=status_code_pb2.MIXED_STATUS,
description="Mixed Status",
)
status_str = "207 MIXED"
else:
status = status_pb2.Status(
code=status_code_pb2.FAILURE,
description="Failed",
)
status_str = "500 FAIL"

resp.status.CopyFrom(status)
duration_ms = (time.time() - start_time) * 1000
logger.info(f"{endpoint} | {status_str} | {duration_ms:.2f}ms | req_id={req_id}")
return service_pb2.RunnerItemOutput(multi_output_response=resp)

def runner_item_generate(
Expand All @@ -150,12 +165,21 @@ def runner_item_generate(
request = runner_item.post_model_outputs_request
ensure_urls_downloaded(request, auth_helper=self._auth_helper)

# --- Live logging additions ---
start_time = time.time()
req_id = get_req_id_from_context()
status_str = "UNKNOWN"
endpoint = "POST /v2/.../outputs/generate"
Copy link
Preview

Copilot AI Jul 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

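[nitpick] The hardcoded endpoint string uses a literal '...' placeholder, which makes it unclear which route the log line refers to. Consider extracting these endpoint strings to constants to avoid duplication across the predict, generate, and stream handlers and to improve maintainability.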

Suggested change
endpoint = "POST /v2/.../outputs/generate"
endpoint = self.ENDPOINT_GENERATE

Copilot uses AI. Check for mistakes.


for resp in self.model.generate_wrapper(request):
# if we have any non-successful code already it's an error we can return.
if (
resp.status.code != status_code_pb2.SUCCESS
and resp.status.code != status_code_pb2.ZERO
):
status_str = f"{resp.status.code} ERROR"
duration_ms = (time.time() - start_time) * 1000
logger.info(f"{endpoint} | {status_str} | {duration_ms:.2f}ms | req_id={req_id}")
yield service_pb2.RunnerItemOutput(multi_output_response=resp)
continue
successes = []
Expand All @@ -170,30 +194,44 @@ def runner_item_generate(
code=status_code_pb2.SUCCESS,
description="Success",
)
status_str = "200 OK"
elif any(successes):
status = status_pb2.Status(
code=status_code_pb2.MIXED_STATUS,
description="Mixed Status",
)
status_str = "207 MIXED"
else:
status = status_pb2.Status(
code=status_code_pb2.FAILURE,
description="Failed",
)
status_str = "500 FAIL"
resp.status.CopyFrom(status)

yield service_pb2.RunnerItemOutput(multi_output_response=resp)

duration_ms = (time.time() - start_time) * 1000
logger.info(f"{endpoint} | {status_str} | {duration_ms:.2f}ms | req_id={req_id}")

def runner_item_stream(
self, runner_item_iterator: Iterator[service_pb2.RunnerItem]
) -> Iterator[service_pb2.RunnerItemOutput]:
# Call the generate() method the underlying model implements.
start_time = time.time()
req_id = get_req_id_from_context()
status_str = "UNKNOWN"
endpoint = "POST /v2/.../outputs/stream "

for resp in self.model.stream_wrapper(pmo_iterator(runner_item_iterator)):
# if we have any non-successful code already it's an error we can return.
if (
resp.status.code != status_code_pb2.SUCCESS
and resp.status.code != status_code_pb2.ZERO
):
status_str = f"{resp.status.code} ERROR"
duration_ms = (time.time() - start_time) * 1000
logger.info(f"{endpoint} | {status_str} | {duration_ms:.2f}ms | req_id={req_id}")
yield service_pb2.RunnerItemOutput(multi_output_response=resp)
continue
successes = []
Expand All @@ -208,20 +246,26 @@ def runner_item_stream(
code=status_code_pb2.SUCCESS,
description="Success",
)
status_str = "200 OK"
elif any(successes):
status = status_pb2.Status(
code=status_code_pb2.MIXED_STATUS,
description="Mixed Status",
)
status_str = "207 MIXED"
else:
status = status_pb2.Status(
code=status_code_pb2.FAILURE,
description="Failed",
)
status_str = "500 FAIL"
resp.status.CopyFrom(status)

yield service_pb2.RunnerItemOutput(multi_output_response=resp)

duration_ms = (time.time() - start_time) * 1000
logger.info(f"{endpoint} | {status_str} | {duration_ms:.2f}ms | req_id={req_id}")


def pmo_iterator(runner_item_iterator, auth_helper=None):
for runner_item in runner_item_iterator:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ workflow:
description: Custom crop model
output_info:
params:
margin: 1.33
margin: 1.3
node_inputs:
- node_id: detector
Loading