Skip to content

Commit 5c7af56

Browse files
committed
Address review comments
1 parent 3d3b820 commit 5c7af56

File tree

2 files changed

+22
-14
lines changed

2 files changed

+22
-14
lines changed

ami/ml/models/pipeline.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -194,24 +194,23 @@ def process_images(
194194
task_logger.info(f"Sending {len(images)} images to Pipeline {pipeline}")
195195
urls = [source_image.public_url() for source_image in images if source_image.public_url()]
196196

197-
source_images: list[SourceImageRequest] = []
197+
source_image_requests: list[SourceImageRequest] = []
198198
detection_requests: list[DetectionRequest] = []
199199

200200
for source_image, url in zip(images, urls):
201201
if url:
202-
source_images.append(
203-
SourceImageRequest(
204-
id=str(source_image.pk),
205-
url=url,
206-
)
202+
source_image_request = SourceImageRequest(
203+
id=str(source_image.pk),
204+
url=url,
207205
)
208-
# Only re-process detections created by the pipeline's detector
206+
source_image_requests.append(source_image_request)
207+
# Re-process all existing detections if they exist
209208
for detection in source_image.detections.all():
210209
bbox = detection.get_bbox()
211210
if bbox and detection.detection_algorithm:
212211
detection_requests.append(
213212
DetectionRequest(
214-
source_image=source_images[-1],
213+
source_image=source_image_request,
215214
bbox=bbox,
216215
crop_image_url=detection.url(),
217216
algorithm=AlgorithmReference(
@@ -231,7 +230,7 @@ def process_images(
231230

232231
request_data = PipelineRequest(
233232
pipeline=pipeline.slug,
234-
source_images=source_images,
233+
source_images=source_image_requests,
235234
config=config,
236235
detections=detection_requests,
237236
)
@@ -253,7 +252,8 @@ def process_images(
253252
pipeline=pipeline.slug,
254253
total_time=0,
255254
source_images=[
256-
SourceImageResponse(id=source_image.id, url=source_image.url) for source_image in source_images
255+
SourceImageResponse(id=source_image_request.id, url=source_image_request.url)
256+
for source_image_request in source_image_requests
257257
],
258258
detections=[],
259259
errors=msg,
@@ -992,7 +992,7 @@ def collect_images(
992992
)
993993

994994
def choose_processing_service_for_pipeline(
995-
self, job_id: int, pipeline_name: str, project_id: int
995+
self, job_id: int | None, pipeline_name: str, project_id: int
996996
) -> ProcessingService:
997997
# @TODO use the cached `last_checked_latency` and a max age to avoid checking every time
998998

ami/ml/tests.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,10 +207,10 @@ def test_pipeline_reprocessing(self):
207207
"detection_algorithm",
208208
"detection_algorithm__category_map",
209209
)
210+
initial_detection_ids = sorted([det.pk for det in detections])
210211
assert detections.count() > 0
211-
initial_num_detections = detections.count()
212212

213-
# Reprocess the same images
213+
# Reprocess the same images using a different pipeline
214214
pipeline = self.processing_service_instance.pipelines.all().get(slug="constant")
215215
pipeline_response = pipeline.process_images(self.test_images, project_id=self.project.pk)
216216
reprocessed_results = save_results(pipeline_response, return_created=True)
@@ -222,7 +222,15 @@ def test_pipeline_reprocessing(self):
222222
"detection_algorithm",
223223
"detection_algorithm__category_map",
224224
)
225-
assert initial_num_detections == detections.count(), "Expected no new detections to be created."
225+
226+
# Check detections were re-processed, and not re-created
227+
reprocessed_detection_ids = sorted([det.pk for det in detections])
228+
assert initial_detection_ids == reprocessed_detection_ids, (
229+
"Expected the same detections to be returned after reprocessing with a different pipeline, "
230+
f"but found {initial_detection_ids} != {reprocessed_detection_ids}"
231+
)
232+
233+
# The constant pipeline produces 1 classification per detection
226234
for detection in detections:
227235
assert (
228236
detection.classifications.count() == 3

0 commit comments

Comments
 (0)