Skip to content

Commit 2ebaefe

Browse files
committed
fix: return category maps from all models in example processing service
1 parent 41ef9e1 commit 2ebaefe

File tree

2 files changed

+106
-41
lines changed

2 files changed

+106
-41
lines changed

processing_services/example/api/algorithms.py

Lines changed: 96 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,28 @@ def compile(self):
4242
def run(self, inputs: list[SourceImage] | list[Detection]) -> list[Detection]:
4343
raise NotImplementedError("Subclasses must implement the run method")
4444

45-
algorithm_config_response = AlgorithmConfigResponse(
46-
name="Base Algorithm",
47-
key="base",
48-
task_type="base",
49-
description="A base class for all algorithms.",
50-
version=1,
51-
version_name="v1",
52-
category_map=None,
53-
)
45+
def get_category_map(self) -> AlgorithmCategoryMapResponse:
46+
return AlgorithmCategoryMapResponse(
47+
data=[],
48+
labels=[],
49+
version="v1",
50+
description="A model without labels.",
51+
uri=None,
52+
)
53+
54+
def get_algorithm_config_response(self) -> AlgorithmConfigResponse:
55+
return AlgorithmConfigResponse(
56+
name="Base Algorithm",
57+
key="base",
58+
task_type="base",
59+
description="A base class for all algorithms.",
60+
version=1,
61+
version_name="v1",
62+
category_map=self.get_category_map(),
63+
)
64+
65+
def __init__(self):
66+
self.algorithm_config_response = self.get_algorithm_config_response()
5467

5568

5669
class ZeroShotObjectDetector(Algorithm):
@@ -141,33 +154,45 @@ def run(self, source_images: list[SourceImage], intermediate=False) -> list[Dete
141154

142155
return detector_responses
143156

144-
algorithm_config_response = AlgorithmConfigResponse(
145-
name="Zero Shot Object Detector",
146-
key="zero-shot-object-detector",
147-
task_type="detection",
148-
description=(
149-
"Huggingface Zero Shot Object Detection model."
150-
"Produces both a bounding box and a candidate label classification for each detection."
151-
),
152-
version=1,
153-
version_name="v1",
154-
category_map=None,
155-
)
157+
def get_category_map(self) -> AlgorithmCategoryMapResponse:
158+
return AlgorithmCategoryMapResponse(
159+
data=[{"index": i, "label": label} for i, label in enumerate(self.candidate_labels)],
160+
labels=self.candidate_labels,
161+
version="v1",
162+
description="Candidate labels used for zero-shot object detection.",
163+
uri=None,
164+
)
165+
166+
def get_algorithm_config_response(self) -> AlgorithmConfigResponse:
167+
return AlgorithmConfigResponse(
168+
name="Zero Shot Object Detector",
169+
key="zero-shot-object-detector",
170+
task_type="detection",
171+
description=(
172+
"Huggingface Zero Shot Object Detection model."
173+
"Produces both a bounding box and a candidate label classification for each detection."
174+
),
175+
version=1,
176+
version_name="v1",
177+
category_map=self.get_category_map(),
178+
)
156179

157180

158181
class HFImageClassifier(Algorithm):
159182
"""
160183
A local classifier that uses the Hugging Face pipeline to classify images.
161184
"""
162185

186+
model_name: str = "google/vit-base-patch16-224" # Vision Transformer model trained on ImageNet-1k
187+
163188
def compile(self):
164189
saved_models_key = "hf_image_classifier" # generate a key for each uniquely compiled algorithm
165190

166191
if saved_models_key not in SAVED_MODELS:
167192
from transformers import pipeline
168193

169194
logger.info(f"Compiling {self.algorithm_config_response.name} from scratch...")
170-
self.model = pipeline("image-classification", model="google/vit-base-patch16-224")
195+
self.model = pipeline("image-classification", model=self.model_name, device=get_best_device())
171196
SAVED_MODELS[saved_models_key] = self.model
172197
else:
173198
logger.info(f"Using saved model for {self.algorithm_config_response.name}...")
@@ -216,15 +241,55 @@ def run(self, detections: list[Detection]) -> list[Detection]:
216241

217242
return detections_to_return
218243

219-
algorithm_config_response = AlgorithmConfigResponse(
220-
name="HF Image Classifier",
221-
key="hf-image-classifier",
222-
task_type="classification",
223-
description="HF ViT for image classification.",
224-
version=1,
225-
version_name="v1",
226-
category_map=None,
227-
)
244+
def get_category_map(self) -> AlgorithmCategoryMapResponse:
245+
"""
246+
Extract the category map from the model.
247+
Returns an AlgorithmCategoryMapResponse with labels, data, and model information.
248+
"""
249+
from transformers.models.auto.configuration_auto import AutoConfig
250+
251+
logger.info(f"Loading configuration for {self.model_name}")
252+
config = AutoConfig.from_pretrained(self.model_name)
253+
254+
# Extract label information
255+
if not hasattr(config, "id2label") or not config.id2label:
256+
raise ValueError(
257+
f"Cannot create category map for model {self.model_name}, no id2label mapping found in config"
258+
)
259+
else:
260+
# Sort labels by index
261+
# Ensure keys are strings for consistent access
262+
id2label: dict[str, str] = {str(k): v for k, v in config.id2label.items()}
263+
indices = sorted([int(k) for k in id2label.keys()])
264+
265+
# Create labels and data
266+
labels = [id2label[str(i)] for i in indices]
267+
data = [{"label": label, "index": idx} for idx, label in zip(indices, labels)]
268+
269+
# Build description
270+
description_text = (
271+
f"Vision Transformer model trained on ImageNet-1k. "
272+
f"Contains {len(labels)} object classes. Model: {self.model_name}"
273+
)
274+
275+
return AlgorithmCategoryMapResponse(
276+
data=data,
277+
labels=labels,
278+
version="ImageNet-1k",
279+
description=description_text,
280+
uri=f"https://huggingface.co/{self.model_name}",
281+
)
282+
283+
def get_algorithm_config_response(self) -> AlgorithmConfigResponse:
284+
return AlgorithmConfigResponse(
285+
name="HF Image Classifier",
286+
key="hf-image-classifier",
287+
task_type="classification",
288+
description="HF ViT for image classification.",
289+
version=1,
290+
version_name="v1",
291+
category_map=self.get_category_map(),
292+
)
228293

229294

230295
class RandomSpeciesClassifier(Algorithm):

processing_services/example/api/pipelines.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,8 @@ class ZeroShotHFClassifierPipeline(Pipeline):
153153
description=("Zero Shot Object Detector with HF image classifier."),
154154
version=1,
155155
algorithms=[
156-
ZeroShotObjectDetector.algorithm_config_response,
157-
HFImageClassifier.algorithm_config_response,
156+
ZeroShotObjectDetector().algorithm_config_response,
157+
HFImageClassifier().algorithm_config_response,
158158
],
159159
)
160160

@@ -167,7 +167,7 @@ def get_stages(self) -> list[Algorithm]:
167167
zero_shot_object_detector.candidate_labels = self.request_config["candidate_labels"]
168168
self.config.algorithms = [
169169
zero_shot_object_detector.algorithm_config_response,
170-
HFImageClassifier.algorithm_config_response,
170+
HFImageClassifier().algorithm_config_response,
171171
]
172172

173173
return [zero_shot_object_detector, HFImageClassifier()]
@@ -212,7 +212,7 @@ class ZeroShotObjectDetectorPipeline(Pipeline):
212212
slug="zero-shot-object-detector-pipeline",
213213
description=("Zero shot object detector (bbox and classification)."),
214214
version=1,
215-
algorithms=[ZeroShotObjectDetector.algorithm_config_response],
215+
algorithms=[ZeroShotObjectDetector().algorithm_config_response],
216216
)
217217

218218
def get_stages(self) -> list[Algorithm]:
@@ -254,8 +254,8 @@ class ZeroShotObjectDetectorWithRandomSpeciesClassifierPipeline(Pipeline):
254254
description=("HF zero shot object detector with random species classifier."),
255255
version=1,
256256
algorithms=[
257-
ZeroShotObjectDetector.algorithm_config_response,
258-
RandomSpeciesClassifier.algorithm_config_response,
257+
ZeroShotObjectDetector().algorithm_config_response,
258+
RandomSpeciesClassifier().algorithm_config_response,
259259
],
260260
)
261261

@@ -266,7 +266,7 @@ def get_stages(self) -> list[Algorithm]:
266266

267267
self.config.algorithms = [
268268
zero_shot_object_detector.algorithm_config_response,
269-
RandomSpeciesClassifier.algorithm_config_response,
269+
RandomSpeciesClassifier().algorithm_config_response,
270270
]
271271

272272
return [zero_shot_object_detector, RandomSpeciesClassifier()]
@@ -307,8 +307,8 @@ class ZeroShotObjectDetectorWithConstantClassifierPipeline(Pipeline):
307307
description=("HF zero shot object detector with constant classifier."),
308308
version=1,
309309
algorithms=[
310-
ZeroShotObjectDetector.algorithm_config_response,
311-
ConstantClassifier.algorithm_config_response,
310+
ZeroShotObjectDetector().algorithm_config_response,
311+
ConstantClassifier().algorithm_config_response,
312312
],
313313
)
314314

@@ -319,7 +319,7 @@ def get_stages(self) -> list[Algorithm]:
319319

320320
self.config.algorithms = [
321321
zero_shot_object_detector.algorithm_config_response,
322-
ConstantClassifier.algorithm_config_response,
322+
ConstantClassifier().algorithm_config_response,
323323
]
324324

325325
return [zero_shot_object_detector, ConstantClassifier()]

0 commit comments

Comments
 (0)