Skip to content

Commit 41ef9e1

Browse files
committed
feat: docs and convenience methods to explain category map schema
1 parent 42a9ebf commit 41ef9e1

File tree

2 files changed

+63
-1
lines changed

2 files changed

+63
-1
lines changed

ami/ml/models/algorithm.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,41 @@
2020
class AlgorithmCategoryMap(BaseModel):
2121
"""
2222
A list of classification labels for a given algorithm version
23+
24+
Expected schema for the `data` field. This is the primary "category map" used by the model
25+
to map from the category index in the model output to a human-readable label and other metadata.
26+
27+
IMPORTANT: Currently only `label` & `taxon_rank` are imported to the Taxon model if the taxon does
28+
not already exist in the Antenna database. But the Taxon model can store any metadata, so this is
29+
extensible in the future.
30+
[
31+
{
32+
"index": 0,
33+
"gbif_key": 123456,
34+
"label": "Vanessa atalanta",
35+
"taxon_rank": "SPECIES",
36+
},
37+
{
38+
"index": 1,
39+
"gbif_key": 789012,
40+
"label": "Limenitis",
41+
"taxon_rank": "GENUS",
42+
},
43+
{
44+
"index": 2,
45+
"gbif_key": 345678,
46+
"label": "Nymphalis californica",
47+
"taxon_rank": "SPECIES",
48+
}
49+
]
50+
51+
The `labels` field is a simple list of string labels, in the same index order used by the model.
52+
[
53+
"Vanessa atalanta",
54+
"Limenitis",
55+
"Nymphalis californica",
56+
]
57+
2358
"""
2459

2560
data = models.JSONField(
@@ -55,6 +90,14 @@ def make_labels_hash(cls, labels):
5590
"""
5691
return hash("".join(labels))
5792

93+
@classmethod
94+
def labels_from_data(cls, data, label_field="label"):
95+
return [category[label_field] for category in data]
96+
97+
@classmethod
98+
def data_from_labels(cls, labels, label_field="label"):
99+
return [{"index": i, label_field: label} for i, label in enumerate(labels)]
100+
58101
def get_category(self, label, label_field="label"):
    """
    Return the position in `self.data` of the first category entry whose
    `label_field` value equals `label`.

    Raises StopIteration when no entry matches, and KeyError if an entry
    examined before the match does not contain `label_field`.
    """
    # Can use JSON containment operators
    matching_positions = (
        position
        for position, entry in enumerate(self.data)
        if entry[label_field] == label
    )
    return next(matching_positions)

ami/ml/tests.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -675,12 +675,12 @@ def test_labels_hash_auto_generation(self):
675675
from ami.ml.models import AlgorithmCategoryMap
676676

677677
# Test data
678-
test_labels = ["coleoptera", "diptera", "lepidoptera"]
679678
test_data = [
680679
{"index": 0, "label": "coleoptera"},
681680
{"index": 1, "label": "diptera"},
682681
{"index": 2, "label": "lepidoptera"},
683682
]
683+
test_labels = AlgorithmCategoryMap.labels_from_data(test_data)
684684

685685
# Create instance using objects.create()
686686
category_map = AlgorithmCategoryMap.objects.create(labels=test_labels, data=test_data, version="test-v1")
@@ -696,3 +696,22 @@ def test_labels_hash_auto_generation(self):
696696
category_map2 = AlgorithmCategoryMap.objects.create(labels=test_labels, data=test_data, version="test-v2")
697697

698698
self.assertEqual(category_map.labels_hash, category_map2.labels_hash)
699+
700+
def test_labels_data_conversion_methods(self):
    """Round-trip category data through labels_from_data and data_from_labels."""
    from ami.ml.models import AlgorithmCategoryMap

    # Starting point: category entries carrying only an index and a label
    original_data = [
        {"index": 0, "label": "coleoptera"},
        {"index": 1, "label": "diptera"},
        {"index": 2, "label": "lepidoptera"},
    ]
    original_labels = AlgorithmCategoryMap.labels_from_data(original_data)

    # labels -> data -> labels
    round_trip_data = AlgorithmCategoryMap.data_from_labels(original_labels)
    round_trip_labels = AlgorithmCategoryMap.labels_from_data(round_trip_data)

    # Both directions must reproduce the starting values
    self.assertEqual(original_data, round_trip_data)
    self.assertEqual(original_labels, round_trip_labels)

0 commit comments

Comments
 (0)