|
| 1 | +import os |
| 2 | +import warnings |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import pandas as pd |
| 6 | +from csi_images.csi_scans import Scan |
| 7 | +from csi_images.csi_tiles import Tile |
| 8 | +from csi_images.csi_frames import Frame |
| 9 | +from csi_images.csi_events import EventArray |
| 10 | + |
| 11 | + |
class OCULARFrameInfoGenerator:
    """
    A pared-down "feature extractor" that extracts the quality of the
    DAPI and CD45 channels for each event found in a tile.

    "Quality" is the mean absolute non-zero intensity step (as a percentage of
    full scale) along the pixel row and pixel column that pass through an event.
    The CD45 channel is looked up in the scan under the name "AF647" and is
    referred to as "cy5" throughout.
    """

    # Divisor used to normalize pixel intensities; 65535 is the maximum value
    # of an unsigned 16-bit integer (assumes 16-bit images -- TODO confirm).
    _MAX_PIXEL_VALUE = 65535

    def __init__(self, scan: Scan, threshold=0.05, save: bool = False):
        """
        :param scan: scan whose "DAPI" and "AF647" channels will be assessed
        :param threshold: normalized intensity below which pixels are zeroed
        before gradients are computed
        :param save: stored on the instance; not read by any method of this class
        """
        self.scan = scan
        self.dapi_idx, self.cy5_idx = scan.get_channel_indices(["DAPI", "AF647"])
        self.threshold = threshold
        self.save = save

    def __repr__(self):
        return f"{self.__class__.__name__}-{self.threshold}"

    def _line_quality(self, line) -> float:
        """
        Scores a single line of pixels (one image row or one image column).
        :param line: 1D array of raw pixel intensities
        :return: mean absolute non-zero step between neighboring pixels,
        as a percentage of full scale; 0.0 if the thresholded line is flat
        """
        # The division allocates a new float array, so zeroing values below
        # never touches the caller's image
        scaled = line / self._MAX_PIXEL_VALUE
        scaled[scaled < self.threshold] = 0
        steps = np.abs(np.diff(scaled))
        steps = steps[steps > 0]
        if len(steps) == 0:
            return 0.0
        return float(np.mean(steps) * 100)

    def extract_event_quality(self, x, y, dapi, cy5) -> tuple[float, float]:
        """
        Finds the quality of the DAPI and "CY5" channels for a single event.
        :param x: x-coordinate of the event (column index into the images)
        :param y: y-coordinate of the event (row index into the images)
        :param dapi: DAPI channel image
        :param cy5: CD45 channel image
        :return: a tuple of the DAPI and CD45 quality (as np.float16 values)
        """
        # Each channel's quality averages the column through x and the row through y
        dapi_quality = np.mean(
            [self._line_quality(dapi[:, x]), self._line_quality(dapi[y, :])]
        )
        cy5_quality = np.mean(
            [self._line_quality(cy5[:, x]), self._line_quality(cy5[y, :])]
        )
        return np.float16(dapi_quality), np.float16(cy5_quality)

    def extract_tile_quality(
        self, events: EventArray, images: list[np.ndarray] = None
    ) -> EventArray:
        """
        Finds the quality of the DAPI and "CY5" channels for each event in a tile
        :param events: EventArray for one tile
        :param images: list of numpy arrays representing each channel; will load if None
        :return: an EventArray with "dapi_quality" and "cy5_quality" in metadata
        """
        # Copy to avoid modifying the original EventArray
        events = events.copy()
        n_events = len(events)
        dapi_quality = np.zeros(n_events, dtype=np.float16)
        cy5_quality = np.zeros(n_events, dtype=np.float16)
        if n_events > 0:
            # Populate images if needed
            if images is not None:
                dapi = images[self.dapi_idx]
                cy5 = images[self.cy5_idx]
            else:
                # Positional lookup: the first event's tile identifies the tile,
                # regardless of how the info table happens to be indexed
                tile = Tile(self.scan, np.asarray(events.info["tile"])[0])
                dapi = Frame.get_frames(tile, (self.dapi_idx,))[0].get_image()
                cy5 = Frame.get_frames(tile, (self.cy5_idx,))[0].get_image()
            # Loop through events, collecting qualities by position
            xs = np.asarray(events.info["x"])
            ys = np.asarray(events.info["y"])
            for i, (x, y) in enumerate(zip(xs, ys)):
                dapi_quality[i], cy5_quality[i] = self.extract_event_quality(
                    x, y, dapi, cy5
                )
        # Whole-column assignment is positional, so it cannot mis-align or add
        # rows the way per-label .loc writes can on a non-default index
        events.metadata["dapi_quality"] = dapi_quality
        events.metadata["cy5_quality"] = cy5_quality
        return events

    def extract_scan_quality(
        self, events: EventArray, images: list[list[np.ndarray]] = None
    ) -> EventArray:
        """
        Contrary to normal feature extractors, this extracts image quality as metadata
        :param events: EventArray for full scan
        :param images: per-tile lists of channel images, indexed as
        images[tile][channel]; each tile's images will be loaded if None
        :return: an EventArray with "dapi_quality" and "cy5_quality" in metadata
        """
        # Copy to avoid modifying the original EventArray
        events = events.copy()
        dapi_quality = np.zeros(len(events), dtype=np.float16)
        cy5_quality = np.zeros(len(events), dtype=np.float16)
        # Only the first ROI's tile grid is considered
        n_tiles = self.scan.roi[0].tile_rows * self.scan.roi[0].tile_cols
        # Loop through each tile
        for i in range(n_tiles):
            rows = events.info["tile"] == i
            mask = np.asarray(rows)
            # Skip if no relevant events
            if not mask.any():
                continue
            # Use the provided images, or let the tile-level method load them
            tile_images = images[i] if images is not None else None
            tile_events = self.extract_tile_quality(events.rows(rows), tile_images)
            # Move it into the scan-wide arrays by position; the tile-level
            # result keeps the order of the selected rows
            dapi_quality[mask] = np.asarray(tile_events.metadata["dapi_quality"])
            cy5_quality[mask] = np.asarray(tile_events.metadata["cy5_quality"])
        events.metadata["dapi_quality"] = dapi_quality
        events.metadata["cy5_quality"] = cy5_quality
        return events

    def save_frameinfo_csv(self, output_path: str, events: EventArray) -> bool:
        """
        Save the frame info to frameinfo.csv, in the OCULAR format.
        :param output_path: Folder to save frameinfo.csv in
        :param events: EventArray for the whole scan with
        "dapi_quality" and "cy5_quality" in metadata
        :return: True once the file has been written; errors from writing
        are raised rather than reported through the return value
        """
        # Reorganize the required metadata in frameinfo.csv
        # Generate a dataframe with one row for each tile
        n_tiles = self.scan.roi[0].tile_rows * self.scan.roi[0].tile_cols
        df = pd.DataFrame({"frame_id": list(range(n_tiles))})
        # One boolean mask per tile, reused for the count and both averages
        tile_ids = events.info["tile"]
        masks = [tile_ids == i for i in df["frame_id"]]
        # Gather the number of events in each tile
        df["cell_count"] = [int(mask.sum()) for mask in masks]
        # Gather the average quality of DAPI and CD45 in each tile, ignoring
        # empty slice warnings (some frames may have no events)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            df["dapi_quality"] = [
                np.mean(events.metadata["dapi_quality"][mask]) for mask in masks
            ]
            df["cy5_quality"] = [
                np.mean(events.metadata["cy5_quality"][mask]) for mask in masks
            ]
        # Tiles without events average to NaN; report them as 0
        df = df.fillna(0)
        # Find the average
        df["avg_quality"] = (df["dapi_quality"] + df["cy5_quality"]) / 2
        # Round to 3 decimal places
        df = df.round(3)
        # Classically 1-indexed because of R
        df["frame_id"] += 1
        df.to_csv(os.path.join(output_path, "frameinfo.csv"))
        return True
0 commit comments