Skip to content

Commit 1f7b8f1

Browse files
committed
Added frameinfo generator; added cellpose; changes to scan_pipeline
1 parent 85bfe2c commit 1f7b8f1

File tree

8 files changed

+356
-23
lines changed

8 files changed

+356
-23
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from loguru import logger
2+
3+
import numpy as np
4+
5+
import torch
6+
from cellpose import models
7+
8+
from csi_analysis.pipelines.scan_pipeline import MaskType, TileSegmenter
9+
from csi_images.csi_scans import Scan
10+
from csi_images.csi_images import make_rgb
11+
12+
13+
class CellposeSegmenter(TileSegmenter):
    """
    Tile segmenter that runs a Cellpose model on an RGB composite image.

    The composite is built with make_rgb() from the AF555, AF647, DAPI, and
    AF488 channels, colored red, green, blue, and white respectively. The
    mask returned by the model is reported under MaskType.EVENT.
    """

    MASK_TYPE = MaskType.EVENT

    def __init__(
        self,
        scan: Scan,
        model_path: str = None,
        use_gpu: bool = False,
        save: bool = False,
    ):
        """
        :param scan: scan used to look up the indices of the required channels
        :param model_path: model to pass to Cellpose as pretrained_model;
            the built-in "cyto3" model is used if None
        :param use_gpu: if True, run the model on CUDA when it is available;
            otherwise a warning is logged and the default device is used
        :param save: stored on the instance; not read by this class itself
        """
        self.scan = scan
        # Preset: RGBW of AF555, AF647, DAPI, AF488
        self.colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 1)]
        channels = ["AF555", "AF647", "DAPI", "AF488"]
        self.frame_order = scan.get_channel_indices(channels)
        # Use the built-in model if none was specified
        self.model_path = "cyto3" if model_path is None else model_path
        self.use_gpu = use_gpu
        self.save = save

        # Build the model once; only pass a device when CUDA is both
        # requested and actually available
        model_kwargs = {"pretrained_model": self.model_path}
        if self.use_gpu:
            if torch.cuda.is_available():
                model_kwargs["device"] = torch.device("cuda")
            else:
                logger.warning("GPU requested but not available; using CPU")
        self.model = models.CellposeModel(**model_kwargs)

    def __repr__(self):
        return f"{self.__class__.__name__}-{self.model_path}"

    def segment(self, images: list[np.ndarray]) -> dict[MaskType, np.ndarray]:
        """
        Segments one tile.
        :param images: list of channel images for the tile, indexed by the
            scan's channel order
        :return: a dict mapping MASK_TYPE to the mask returned by the model
        """
        # Reorder the channels to match self.colors before compositing
        ordered_frames = [images[i] for i in self.frame_order]
        rgb_image = make_rgb(ordered_frames, self.colors)
        # Only the first of the three values returned by eval() is used
        mask, _, _ = self.model.eval(rgb_image, diameter=15, channels=[0, 0])
        return {self.MASK_TYPE: mask}
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
import os
2+
import warnings
3+
4+
import numpy as np
5+
import pandas as pd
6+
from csi_images.csi_scans import Scan
7+
from csi_images.csi_tiles import Tile
8+
from csi_images.csi_frames import Frame
9+
from csi_images.csi_events import EventArray
10+
11+
12+
class OCULARFrameInfoGenerator:
    """
    A pared-down "feature extractor" that extracts the quality of the
    DAPI and CD45 channels for each event found in a tile.
    """

    def __init__(self, scan: Scan, threshold=0.05, save: bool = False):
        """
        :param scan: scan used to look up the DAPI and AF647 channel indices
            and the tile layout
        :param threshold: normalized intensities below this value are
            zeroed before the gradient is computed
        :param save: stored on the instance; not read by this class itself
        """
        self.scan = scan
        self.dapi_idx, self.cy5_idx = scan.get_channel_indices(["DAPI", "AF647"])
        self.threshold = threshold
        self.save = save

    def __repr__(self):
        return f"{self.__class__.__name__}-{self.threshold}"

    def _line_quality(self, line: np.ndarray) -> float:
        """
        Scores a single row or column of pixels.
        :param line: 1D slice of an image
        :return: 100 times the mean absolute non-zero step between
            neighboring thresholded pixels, or 0 if there are no such steps
        """
        # Division allocates a new float array, so the source image is not
        # modified by the thresholding below.
        # NOTE(review): dividing by 65535 assumes 16-bit images; confirm
        line = line / 65535
        line[line < self.threshold] = 0
        steps = np.abs(np.diff(line))
        steps = steps[steps > 0]
        if len(steps) == 0:
            return 0
        return np.mean(steps) * 100

    def extract_event_quality(self, x, y, dapi, cy5) -> tuple[float, float]:
        """
        Finds the quality of the DAPI and "CY5" channels for a single event.
        :param x: x-coordinate of the event
        :param y: y-coordinate of the event
        :param dapi: DAPI channel image
        :param cy5: CD45 channel image
        :return: a tuple of the DAPI and CD45 quality
        """
        # For each channel, average the score of the column at x and the
        # score of the row at y
        dapi_quality = np.mean(
            [self._line_quality(dapi[:, x]), self._line_quality(dapi[y, :])]
        )
        cy5_quality = np.mean(
            [self._line_quality(cy5[:, x]), self._line_quality(cy5[y, :])]
        )
        return np.float16(dapi_quality), np.float16(cy5_quality)

    def extract_tile_quality(
        self, events: EventArray, images: list[np.ndarray] = None
    ) -> EventArray:
        """
        Finds the quality of the DAPI and "CY5" channels for each event in a tile
        :param events: EventArray for one tile
        :param images: list of numpy arrays representing each channel; will load if None
        :return: an EventArray with "dapi_quality" and "cy5_quality" in metadata
        """
        # Copy to avoid modifying the original EventArray
        events = events.copy()
        n_events = len(events)
        dapi_quality = np.zeros(n_events, dtype=np.float16)
        cy5_quality = np.zeros(n_events, dtype=np.float16)
        if n_events == 0:
            # Nothing to score, and no tile to load images for
            events.metadata["dapi_quality"] = dapi_quality
            events.metadata["cy5_quality"] = cy5_quality
            return events
        # Populate images if needed
        if images is not None:
            dapi = images[self.dapi_idx]
            cy5 = images[self.cy5_idx]
        else:
            # Positional access: the index labels may not start at 0
            tile = Tile(self.scan, events.info["tile"].iloc[0])
            dapi = Frame.get_frames(tile, (self.dapi_idx,))[0].get_image()
            cy5 = Frame.get_frames(tile, (self.cy5_idx,))[0].get_image()
        # Loop through events by position, so the result does not depend on
        # the index labels of events.info or events.metadata
        xs = events.info["x"].to_numpy()
        ys = events.info["y"].to_numpy()
        for i, (x, y) in enumerate(zip(xs, ys)):
            dapi_quality[i], cy5_quality[i] = self.extract_event_quality(
                x, y, dapi, cy5
            )
        # Assigning plain arrays fills the columns in row order
        events.metadata["dapi_quality"] = dapi_quality
        events.metadata["cy5_quality"] = cy5_quality
        return events

    def extract_scan_quality(
        self, events: EventArray, images: list[list[np.ndarray]] = None
    ) -> EventArray:
        """
        Contrary to normal feature extractors, this extracts image quality as metadata
        :param events: EventArray for full scan
        :param images: per-tile lists of numpy arrays representing each
            channel, indexed by tile number; will load if None
        :return: an EventArray with "dapi_quality" and "cy5_quality" in metadata
        """
        # Copy to avoid modifying the original EventArray
        events = events.copy()
        dapi_quality = np.zeros(len(events), dtype=np.float16)
        cy5_quality = np.zeros(len(events), dtype=np.float16)
        # Loop through each tile
        n_tiles = self.scan.roi[0].tile_rows * self.scan.roi[0].tile_cols
        for i in range(n_tiles):
            rows = events.info["tile"] == i
            # Skip if no relevant events
            if not rows.any():
                continue
            # Load in images
            if images is not None:
                tile_images = images[i]
            else:
                # Only the two channels that are actually read get loaded;
                # the others stay None so the channel indices remain valid
                tile = Tile(self.scan, i)
                tile_images = [None] * len(self.scan.channels)
                for idx in (self.dapi_idx, self.cy5_idx):
                    tile_images[idx] = Frame.get_frames(tile, (idx,))[0].get_image()
            # Get quality
            tile_events = self.extract_tile_quality(events.rows(rows), tile_images)
            # Move it into the full-scan arrays by position; assigning the
            # Series directly would align on index labels instead
            mask = rows.to_numpy()
            dapi_quality[mask] = tile_events.metadata["dapi_quality"].to_numpy()
            cy5_quality[mask] = tile_events.metadata["cy5_quality"].to_numpy()
        events.metadata["dapi_quality"] = dapi_quality
        events.metadata["cy5_quality"] = cy5_quality
        return events

    def save_frameinfo_csv(self, output_path: str, events: EventArray) -> bool:
        """
        Save the frame info to frameinfo.csv, in the OCULAR format.
        :param output_path: Folder to save frameinfo.csv in
        :param events: EventArray for the whole scan with
        "dapi_quality" and "cy5_quality" in metadata
        :return: True if frameinfo.csv exists after writing
        """
        # Reorganize the required metadata in frameinfo.csv
        n_tiles = self.scan.roi[0].tile_rows * self.scan.roi[0].tile_cols
        cell_counts = []
        dapi_means = []
        cy5_means = []
        # Gather the event count and average quality of DAPI and CD45 in
        # each tile, ignoring empty slice warnings (some frames may have
        # no events)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            for i in range(n_tiles):
                # Build the tile mask once and reuse it for all three columns
                in_tile = events.info["tile"] == i
                cell_counts.append(in_tile.sum())
                dapi_means.append(np.mean(events.metadata["dapi_quality"][in_tile]))
                cy5_means.append(np.mean(events.metadata["cy5_quality"][in_tile]))
        # Generate a dataframe with one row for each tile
        df = pd.DataFrame(
            {
                "frame_id": list(range(n_tiles)),
                "cell_count": cell_counts,
                "dapi_quality": dapi_means,
                "cy5_quality": cy5_means,
            }
        )
        # Tiles without events have NaN quality
        df = df.fillna(0)
        # Find the average
        df["avg_quality"] = (df["dapi_quality"] + df["cy5_quality"]) / 2
        # Round to 3 decimal places
        df = df.round(3)
        # Classically 1-indexed because of R
        df["frame_id"] += 1
        csv_path = os.path.join(output_path, "frameinfo.csv")
        df.to_csv(csv_path)
        return os.path.exists(csv_path)

csi_analysis/modules/ocular_report_clusterer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,18 @@ class OcularReportClusterer(EventClassifier):
2222
- cluster_id: the cluster id for each event
2323
"""
2424

25-
COLUMN_NAME = "cluster_id"
26-
2725
def __init__(
2826
self,
2927
columns: list[str] = None,
28+
column_name: str = "cluster_id",
3029
max_cluster_size: int = 20,
3130
sort_by: str = None,
3231
ascending: bool = True,
3332
copy: bool = False,
3433
save: bool = False,
3534
):
3635
self.columns = columns
36+
self.column_name = column_name
3737
self.sort_by = sort_by
3838
self.max_cluster_size = max_cluster_size
3939
self.ascending = ascending
@@ -47,10 +47,10 @@ def classify_events(self, events: EventArray) -> EventArray:
4747
if self.copy:
4848
events = events.copy()
4949
cluster_labels = self.cluster_with_max_size(events)
50-
events.metadata["cluster_id"] = cluster_labels
50+
events.metadata[self.column_name] = cluster_labels
5151
if self.sort_by is not None:
5252
cluster_mapping = self.sort_clusters(events)
53-
events.metadata["cluster_id"] = events.metadata["cluster_id"].map(
53+
events.metadata[self.column_name] = events.metadata[self.column_name].map(
5454
cluster_mapping
5555
)
5656
return events
@@ -124,9 +124,9 @@ def sort_clusters(self, events):
124124
# Get the average "interesting" p-value for each cluster
125125
sort_data = events.get(self.sort_by)
126126
cluster_means = []
127-
cluster_ids = pd.unique(events.metadata["cluster_id"])
127+
cluster_ids = pd.unique(events.metadata[self.column_name])
128128
for cluster_id in cluster_ids:
129-
cluster_indices = events.metadata["cluster_id"] == cluster_id
129+
cluster_indices = events.metadata[self.column_name] == cluster_id
130130
cluster_means.append(sort_data.loc[cluster_indices].mean()[0])
131131

132132
# Sort the cluster_ids by their average p-values, descending (reverse=True)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import os.path
2+
3+
import numpy as np
4+
5+
from csi_images.csi_events import EventArray
6+
from csi_images.csi_scans import Scan
7+
8+
from ..pipelines.scan_pipeline import ReportGenerator
9+
10+
11+
class OCULARReportGenerator(ReportGenerator):
    """
    Placeholder report generator that creates empty OCULAR report files.
    """

    # Names of the dummy files created by make_report()
    REPORT_FILES = ("out.rds", "cc-final.csv", "others-final.csv")

    def __init__(self, scan: Scan, save: bool = False):
        """
        :param scan: stored on the instance; not read by this class itself
        :param save: stored on the instance; not read by this class itself
        """
        self.scan = scan
        self.save = save

    def make_report(
        self,
        output_path: str,
        events: EventArray,
        images: list[list[np.ndarray]] = None,
    ) -> bool:
        """
        Creates the dummy report files inside output_path.
        :param output_path: folder to create the files in
        :param events: unused by this placeholder implementation
        :param images: unused by this placeholder implementation
        :return: True if all of the files exist afterwards
        """
        paths = [os.path.join(output_path, name) for name in self.REPORT_FILES]
        # Create dummy files; append mode creates a missing file without
        # truncating one that already exists
        for path in paths:
            with open(path, "a"):
                pass
        # Confirm all files were created
        return all(os.path.exists(path) for path in paths)

    def __repr__(self):
        return f"{self.__class__.__name__}"
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""
2+
Module for tophat preprocessing of frame images.
3+
"""
4+
5+
import numpy as np
6+
import cv2
7+
8+
from csi_images.csi_scans import Scan
9+
from ..pipelines.scan_pipeline import TilePreprocessor
10+
11+
12+
class TophatPreprocessor(TilePreprocessor):
    """
    Tile preprocessor that applies a morphological tophat filter to
    selected channels.
    """

    def __init__(
        self,
        scan: Scan,
        channels: list[int | str],
        tophat_size: int = 0,
    ):
        """
        Preprocess frame images with tophat filtering.
        :param scan: scan used to translate channel names into indices
        :param channels: channels to filter, as indices or names; the type
            of the first entry decides whether names are translated
        :param tophat_size: width and height of the elliptical structuring
            element; 0 disables filtering
        """
        self.scan = scan
        self.channels = channels
        # Guard against an empty list before inspecting the first entry
        if channels and isinstance(channels[0], str):
            self.channels = scan.get_channel_indices(channels)
        self.tophat_size = tophat_size

    def preprocess(self, images: list[np.ndarray]) -> list[np.ndarray]:
        """
        Applies the tophat filter to the configured channels.
        :param images: list of channel images for one tile; filtered
            channels are replaced in this list
        :return: the same list that was passed in
        """
        if self.tophat_size == 0:
            return images

        tophat_kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, (self.tophat_size, self.tophat_size)
        )

        for i in self.channels:
            images[i] = cv2.morphologyEx(images[i], cv2.MORPH_TOPHAT, tophat_kernel)
        # Return the list in both branches, as the annotation promises
        return images

    def __repr__(self):
        return f"{self.__class__.__name__}-{self.tophat_size}"

0 commit comments

Comments
 (0)