Skip to content

Commit 88ffba8

Browse files
chore: rebase feat/postprocessing-class-masking onto feat/postprocessing-framework
1 parent 5e85b75 commit 88ffba8

File tree

4 files changed

+299
-21
lines changed

4 files changed

+299
-21
lines changed

ami/ml/post_processing/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from . import rank_rollup, small_size_filter # noqa: F401
1+
from . import class_masking, rank_rollup, small_size_filter # noqa: F401

ami/ml/post_processing/base.py

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import abc
44
import logging
5-
from typing import Any, Optional
5+
from typing import Any
66

77
from ami.jobs.models import Job
88
from ami.ml.models import Algorithm
@@ -39,34 +39,54 @@ class BasePostProcessingTask(abc.ABC):
3939
Abstract base class for all post-processing tasks.
4040
"""
4141

42+
# Each task must override these
4243
key: str = ""
4344
name: str = ""
4445

4546
def __init__(
    self,
    job: Job | None = None,
    logger: logging.Logger | None = None,
    **config: Any,
):
    """
    Store job/config context, resolve a logger, and ensure a backing
    Algorithm record exists for this task class.
    """
    self.job = job
    self.config = config

    # Logger precedence: explicit logger > the job's logger > a module-keyed default.
    if logger is None:
        if job is None:
            logger = logging.getLogger(f"ami.post_processing.{self.key}")
        else:
            logger = job.logger
    self.logger = logger

    # One Algorithm row per concrete task class, created on first use.
    algo, _created = Algorithm.objects.get_or_create(
        name=self.__class__.__name__,
        defaults={
            "description": f"Post-processing task: {self.key}",
            "task_type": AlgorithmTaskType.POST_PROCESSING.value,
        },
    )
    self.algorithm: Algorithm = algo

    self.logger.info(f"Initialized {self.__class__.__name__} with config={self.config}, job={job}")
72+
73+
def update_progress(self, progress: float):
    """
    Record task progress on the attached job, or just log it when no job exists.
    """
    job = self.job
    if not job:
        # No job object — fallback to plain logging
        self.logger.info(f"[{self.name}] Progress {progress:.0%}")
        return

    job.progress.update_stage(job.job_type_key, progress=progress)
    job.save(update_fields=["progress"])
6485

6586
@abc.abstractmethod
def run(self) -> None:
    """Execute the task's logic; every concrete subclass must override this."""
    raise NotImplementedError("BasePostProcessingTask subclasses must implement run()")
Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
import logging
2+
3+
from django.db.models import QuerySet
4+
from django.utils import timezone
5+
6+
from ami.main.models import Classification, Occurrence, SourceImageCollection, TaxaList
7+
from ami.ml.models import Algorithm, AlgorithmCategoryMap
8+
from ami.ml.post_processing.base import BasePostProcessingTask, register_postprocessing_task
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
def update_single_occurrence(
    occurrence: Occurrence,
    algorithm: Algorithm,
    taxa_list: TaxaList,
    task_logger: logging.Logger = logger,
):
    """
    Re-apply taxa-list class masking to one occurrence's classifications.
    """
    task_logger.info(f"Recalculating classifications for occurrence {occurrence.pk}.")

    # Terminal, scored classifications for this occurrence made by the algorithm.
    qs = Classification.objects.filter(
        detection__occurrence=occurrence,
        terminal=True,
        algorithm=algorithm,
        scores__isnull=False,
    ).distinct()

    make_classifications_filtered_by_taxa_list(
        classifications=qs,
        taxa_list=taxa_list,
        algorithm=algorithm,
    )
34+
35+
36+
def update_occurrences_in_collection(
    collection: SourceImageCollection,
    taxa_list: TaxaList,
    algorithm: Algorithm,
    params: dict,
    task_logger: logging.Logger = logger,
    job=None,
):
    """
    Re-apply taxa-list class masking to every occurrence in a collection.
    """
    task_logger.info(f"Recalculating classifications based on a taxa list. Params: {params}")

    # Make new AlgorithmCategoryMap with the taxa in the list
    # @TODO

    # Terminal, scored classifications produced by the algorithm on any
    # source image belonging to this collection.
    qs = Classification.objects.filter(
        detection__source_image__collections=collection,
        terminal=True,
        algorithm=algorithm,
        scores__isnull=False,
    ).distinct()

    make_classifications_filtered_by_taxa_list(
        classifications=qs,
        taxa_list=taxa_list,
        algorithm=algorithm,
    )
62+
63+
64+
def make_classifications_filtered_by_taxa_list(
    classifications: QuerySet[Classification],
    taxa_list: TaxaList,
    algorithm: Algorithm,
):
    """
    Mask out categories not present in ``taxa_list`` and re-derive predictions.

    For each classification (which must have both scores and logits), the
    logits of excluded categories are suppressed, softmax scores are
    recomputed, the old classification is demoted to non-terminal, and a new
    terminal classification is created for the resulting top taxon. Finally,
    the determinations of all affected occurrences are refreshed.

    Raises:
        ValueError: if no classifications are given, the algorithm has no
            category map, or a classification's scores/logits are malformed.
    """
    # Hoisted out of the per-classification loop (was imported on every iteration).
    import numpy as np

    taxa_in_list = taxa_list.taxa.all()

    occurrences_to_update: set[Occurrence] = set()
    logger.info(f"Found {len(classifications)} terminal classifications with scores to update.")

    if not classifications:
        raise ValueError("No terminal classifications with scores found to update.")

    if not algorithm.category_map:
        raise ValueError(f"Algorithm {algorithm} does not have a category map.")
    category_map: AlgorithmCategoryMap = algorithm.category_map

    # @TODO find a more efficient way to get the category map with taxa. This is slow!
    logger.info(f"Retrieving category map with Taxa instances for algorithm {algorithm}")
    category_map_with_taxa = category_map.with_taxa()

    # Categories whose taxon is NOT in the taxa list get masked out below.
    excluded_category_map_with_taxa = [
        category for category in category_map_with_taxa if category["taxon"] not in taxa_in_list
    ]
    excluded_category_indices = [
        int(category["index"]) for category in excluded_category_map_with_taxa  # type: ignore
    ]

    logger.info(
        f"Category map has {len(category_map_with_taxa)} categories, "
        f"{len(excluded_category_map_with_taxa)} categories excluded, "
        f"{len(classifications)} classifications to check"
    )

    classifications_to_add = []
    classifications_to_update = []

    timestamp = timezone.now()
    for classification in classifications:
        scores, logits = classification.scores, classification.logits

        # Validate stored values before doing any math on them.
        if not isinstance(scores, list) or not all(isinstance(score, (int, float)) for score in scores):
            raise ValueError(f"Scores for classification {classification.pk} are not a list of numbers: {scores}")
        if not isinstance(logits, list) or not all(isinstance(logit, (int, float)) for logit in logits):
            raise ValueError(f"Logits for classification {classification.pk} are not a list of numbers: {logits}")

        logger.debug(f"Processing classification {classification.pk} with {len(scores)} scores")

        # Suppress excluded categories: a large negative logit drives the
        # softmax probability toward zero without introducing NaNs (which a
        # 0.0 or np.nan mask would — zero does NOT softmax correctly).
        # @TODO consider deleting excluded categories instead of masking them.
        logits_np = np.array(logits)
        logits_np[excluded_category_indices] = -100
        logits = logits_np.tolist()

        # Recompute softmax over the masked logits; subtract the max for
        # numerical stability. @TODO add a test confirming this is correct.
        scores_np = np.exp(logits_np - np.max(logits_np))
        scores_np /= np.sum(scores_np)
        scores = scores_np.tolist()

        # Top prediction after masking.
        top_index = scores.index(max(scores))
        # @TODO this lookup doesn't work if the taxon has never been classified
        top_taxon = category_map_with_taxa[top_index]["taxon"]
        logger.debug(f"Top category after masking: index={top_index}, taxon={top_taxon}")

        # Skip classifications the masking did not change.
        if classification.scores == scores and classification.logits == logits:
            logger.debug(f"Classification {classification.pk} does not need updating")
            continue

        # Demote the existing classification to an intermediate one.
        classification.terminal = False
        classification.updated_at = timestamp

        # New terminal classification for the recalculated top taxon.
        new_classification = Classification(
            taxon=top_taxon,
            algorithm=classification.algorithm,
            score=max(scores),
            scores=scores,
            logits=logits,
            detection=classification.detection,
            timestamp=classification.timestamp,
            terminal=True,
            category_map=None,  # @TODO need a new category map with the filtered taxa
            created_at=timestamp,
            updated_at=timestamp,
        )
        if new_classification.taxon is None:
            raise ValueError("Classification isn't registered yet. Aborting")  # @TODO remove or fail gracefully

        classifications_to_update.append(classification)
        classifications_to_add.append(new_classification)

        assert new_classification.detection is not None
        assert new_classification.detection.occurrence is not None
        occurrences_to_update.add(new_classification.detection.occurrence)

        # NOTE: was ``logging.info`` (root logger) — use the module logger.
        logger.info(
            f"Adding new classification for Taxon {top_taxon} to occurrence {new_classification.detection.occurrence}"
        )

    # Bulk update the existing (demoted) classifications.
    if classifications_to_update:
        logger.info(f"Bulk updating {len(classifications_to_update)} existing classifications")
        Classification.objects.bulk_update(classifications_to_update, ["terminal", "updated_at"])
        logger.info(f"Updated {len(classifications_to_update)} existing classifications")

    # Bulk create the new terminal classifications.
    if classifications_to_add:
        logger.info(f"Bulk creating {len(classifications_to_add)} new classifications")
        Classification.objects.bulk_create(classifications_to_add)
        logger.info(f"Added {len(classifications_to_add)} new classifications")

    # Refresh the determination of every affected occurrence.
    logger.info(f"Updating the determinations for {len(occurrences_to_update)} occurrences")
    for occurrence in occurrences_to_update:
        occurrence.save(update_determination=True)
    logger.info(f"Updated determinations for {len(occurrences_to_update)} occurrences")
212+
213+
214+
@register_postprocessing_task
class ClassMaskingTask(BasePostProcessingTask):
    """
    Post-processing task that masks classification scores to a taxa list.

    Expected config keys:
        collection_id: pk of the SourceImageCollection to process.
        taxa_list_id: pk of the TaxaList used as the class mask.
        algorithm_id: pk of the Algorithm whose classifications are masked.
    """

    key = "class_masking"
    name = "Class masking"

    def run(self) -> None:
        """Apply class masking on a source image collection using a taxa list."""
        self.logger.info(f"=== Starting {self.name} ===")

        collection_id = self.config.get("collection_id")
        taxa_list_id = self.config.get("taxa_list_id")
        algorithm_id = self.config.get("algorithm_id")

        # Validate config parameters
        if not all([collection_id, taxa_list_id, algorithm_id]):
            self.logger.error("Missing required configuration: collection_id, taxa_list_id, algorithm_id")
            return

        # Narrowed from a bare ``except Exception`` so that unrelated errors
        # (e.g. database outages) propagate instead of being swallowed.
        try:
            collection = SourceImageCollection.objects.get(pk=collection_id)
            taxa_list = TaxaList.objects.get(pk=taxa_list_id)
            algorithm = Algorithm.objects.get(pk=algorithm_id)
        except (
            SourceImageCollection.DoesNotExist,
            TaxaList.DoesNotExist,
            Algorithm.DoesNotExist,
        ) as e:
            self.logger.exception(f"Failed to load objects: {e}")
            return

        self.logger.info(f"Applying class masking on collection {collection_id} using taxa list {taxa_list_id}")

        update_occurrences_in_collection(
            collection=collection,
            taxa_list=taxa_list,
            algorithm=algorithm,
            params=self.config,
            task_logger=self.logger,
            job=self.job,
        )

        self.logger.info("Class masking completed successfully.")
        self.logger.info(f"=== Completed {self.name} ===")

ami/ml/post_processing/small_size_filter.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def run(self) -> None:
3131

3232
try:
3333
collection = SourceImageCollection.objects.get(pk=collection_id)
34-
self.logger.info(f"Loaded SourceImageCollection {collection_id} " f"(Project={collection.project})")
34+
self.logger.info(f"Loaded SourceImageCollection {collection_id} (Project={collection.project})")
3535
except SourceImageCollection.DoesNotExist:
3636
msg = f"SourceImageCollection {collection_id} not found"
3737
self.logger.error(msg)
@@ -85,6 +85,11 @@ def run(self) -> None:
8585
comment=f"Auto-set by {self.name} post-processing task",
8686
)
8787
modified += 1
88-
self.logger.debug(f"Detection {det.pk}: marked as 'Not identifiable'")
88+
self.logger.info(f"Detection {det.pk}: marked as 'Not identifiable'")
89+
90+
# Update progress every 10 detections
91+
if i % 10 == 0 or i == total:
92+
progress = i / total if total > 0 else 1.0
93+
self.update_progress(progress)
8994

9095
self.logger.info(f"=== Completed {self.name}: {modified}/{total} detections modified ===")

0 commit comments

Comments
 (0)