@@ -42,15 +42,28 @@ def compile(self):
42
42
def run(self, inputs: list[SourceImage] | list[Detection]) -> list[Detection]:
    """
    Process the given inputs and return the resulting detections.

    This implementation is a placeholder: it always raises, so every concrete
    algorithm has to provide its own ``run``.

    Raises:
        NotImplementedError: Always.
    """
    message = "Subclasses must implement the run method"
    raise NotImplementedError(message)
44
44
45
- algorithm_config_response = AlgorithmConfigResponse (
46
- name = "Base Algorithm" ,
47
- key = "base" ,
48
- task_type = "base" ,
49
- description = "A base class for all algorithms." ,
50
- version = 1 ,
51
- version_name = "v1" ,
52
- category_map = None ,
53
- )
45
def get_category_map(self) -> AlgorithmCategoryMapResponse:
    """
    Return the category map for an algorithm that has no labels.

    Both ``data`` and ``labels`` are empty lists and no ``uri`` is set.
    """
    empty_category_map = AlgorithmCategoryMapResponse(
        data=[],
        labels=[],
        version="v1",
        description="A model without labels.",
        uri=None,
    )
    return empty_category_map
53
+
54
def get_algorithm_config_response(self) -> AlgorithmConfigResponse:
    """
    Build the configuration response describing the base algorithm.

    The category map is obtained from ``self.get_category_map()`` each time
    this method is called.
    """
    category_map = self.get_category_map()
    config_response = AlgorithmConfigResponse(
        name="Base Algorithm",
        key="base",
        task_type="base",
        description="A base class for all algorithms.",
        version=1,
        version_name="v1",
        category_map=category_map,
    )
    return config_response
64
+
65
def __init__(self):
    """
    Store the algorithm's configuration response on the instance.

    The value comes from ``self.get_algorithm_config_response()`` and is kept
    in ``self.algorithm_config_response``.
    """
    config_response = self.get_algorithm_config_response()
    self.algorithm_config_response = config_response
54
67
55
68
56
69
class ZeroShotObjectDetector (Algorithm ):
@@ -141,33 +154,45 @@ def run(self, source_images: list[SourceImage], intermediate=False) -> list[Dete
141
154
142
155
return detector_responses
143
156
144
- algorithm_config_response = AlgorithmConfigResponse (
145
- name = "Zero Shot Object Detector" ,
146
- key = "zero-shot-object-detector" ,
147
- task_type = "detection" ,
148
- description = (
149
- "Huggingface Zero Shot Object Detection model."
150
- "Produces both a bounding box and a candidate label classification for each detection."
151
- ),
152
- version = 1 ,
153
- version_name = "v1" ,
154
- category_map = None ,
155
- )
157
def get_category_map(self) -> AlgorithmCategoryMapResponse:
    """
    Build the category map from ``self.candidate_labels``.

    Each candidate label becomes one ``data`` entry holding its position in
    the list (``index``) and its text (``label``); the label list itself is
    passed through as ``labels``.
    """
    candidates = self.candidate_labels
    indexed_labels = [
        dict(index=position, label=candidate)
        for position, candidate in enumerate(candidates)
    ]
    category_map = AlgorithmCategoryMapResponse(
        data=indexed_labels,
        labels=candidates,
        version="v1",
        description="Candidate labels used for zero-shot object detection.",
        uri=None,
    )
    return category_map
165
+
166
def get_algorithm_config_response(self) -> AlgorithmConfigResponse:
    """
    Build the configuration response describing the zero-shot object detector.

    The category map is obtained from ``self.get_category_map()`` each time
    this method is called.
    """
    return AlgorithmConfigResponse(
        name="Zero Shot Object Detector",
        key="zero-shot-object-detector",
        task_type="detection",
        # Adjacent string literals are joined with no separator, so the first
        # sentence must end with a space to keep the two sentences apart.
        description=(
            "Huggingface Zero Shot Object Detection model. "
            "Produces both a bounding box and a candidate label classification for each detection."
        ),
        version=1,
        version_name="v1",
        category_map=self.get_category_map(),
    )
156
179
157
180
158
181
class HFImageClassifier (Algorithm ):
159
182
"""
160
183
A local classifier that uses the Hugging Face pipeline to classify images.
161
184
"""
162
185
186
+ model_name : str = "google/vit-base-patch16-224" # Vision Transformer model trained on ImageNet-1k
187
+
163
188
def compile (self ):
164
189
saved_models_key = "hf_image_classifier" # generate a key for each uniquely compiled algorithm
165
190
166
191
if saved_models_key not in SAVED_MODELS :
167
192
from transformers import pipeline
168
193
169
194
logger .info (f"Compiling { self .algorithm_config_response .name } from scratch..." )
170
- self .model = pipeline ("image-classification" , model = "google/vit-base-patch16-224" )
195
+ self .model = pipeline ("image-classification" , model = self . model_name , device = get_best_device () )
171
196
SAVED_MODELS [saved_models_key ] = self .model
172
197
else :
173
198
logger .info (f"Using saved model for { self .algorithm_config_response .name } ..." )
@@ -216,15 +241,55 @@ def run(self, detections: list[Detection]) -> list[Detection]:
216
241
217
242
return detections_to_return
218
243
219
- algorithm_config_response = AlgorithmConfigResponse (
220
- name = "HF Image Classifier" ,
221
- key = "hf-image-classifier" ,
222
- task_type = "classification" ,
223
- description = "HF ViT for image classification." ,
224
- version = 1 ,
225
- version_name = "v1" ,
226
- category_map = None ,
227
- )
244
def get_category_map(self) -> AlgorithmCategoryMapResponse:
    """
    Extract the category map from the model configuration.

    The configuration for ``self.model_name`` is loaded with
    ``AutoConfig.from_pretrained`` and its ``id2label`` mapping is converted
    into a list of labels ordered by numeric class index.

    Returns:
        An AlgorithmCategoryMapResponse with the ordered labels, one
        ``{"label", "index"}`` entry per class, a description, and the
        model's Hugging Face URL.

    Raises:
        ValueError: If the configuration has no ``id2label`` mapping or the
            mapping is empty.
    """
    from transformers.models.auto.configuration_auto import AutoConfig

    logger.info("Loading configuration for %s", self.model_name)
    config = AutoConfig.from_pretrained(self.model_name)

    raw_id2label = getattr(config, "id2label", None)
    if not raw_id2label:
        raise ValueError(
            f"Cannot create category map for model {self.model_name}, no id2label mapping found in config"
        )

    # Normalise keys to int once so that int and numeric-string keys are
    # handled alike and the labels can be ordered by class index.
    label_by_index = {int(index): label for index, label in raw_id2label.items()}
    ordered_items = sorted(label_by_index.items())

    labels = [label for _, label in ordered_items]
    data = [{"label": label, "index": index} for index, label in ordered_items]

    description_text = (
        "Vision Transformer model trained on ImageNet-1k. "
        f"Contains {len(labels)} object classes. Model: {self.model_name}"
    )

    return AlgorithmCategoryMapResponse(
        data=data,
        labels=labels,
        version="ImageNet-1k",
        description=description_text,
        uri=f"https://huggingface.co/{self.model_name}",
    )
282
+
283
def get_algorithm_config_response(self) -> AlgorithmConfigResponse:
    """
    Build the configuration response describing the HF image classifier.

    The category map is obtained from ``self.get_category_map()`` each time
    this method is called.
    """
    category_map = self.get_category_map()
    config_response = AlgorithmConfigResponse(
        name="HF Image Classifier",
        key="hf-image-classifier",
        task_type="classification",
        description="HF ViT for image classification.",
        version=1,
        version_name="v1",
        category_map=category_map,
    )
    return config_response
228
293
229
294
230
295
class RandomSpeciesClassifier (Algorithm ):
0 commit comments