DynaCell Metrics #242
…. one for pred one for target
if self.dtype is not None:
    _pred = _pred.astype(self.dtype)
    _target = _target.astype(self.dtype)
pred = torch.from_numpy(_pred.astype(self.dtype))
target = torch.from_numpy(_target.astype(self.dtype))
Why do this twice?
applications/DynaCell/demo_script.py
Outdated
@@ -0,0 +1,178 @@
"""
This script is a demo script for the DynaCell application.
It loads the ome-zarr 0.4v format, calculates metrics and saves the results as csv files
Suggested change:
- It loads the ome-zarr 0.4v format, calculates metrics and saves the results as csv files
+ It loads the ome-zarr v0.4 format, calculates metrics and saves the results as csv files
applications/DynaCell/demo_script.py
Outdated
csv_database_path = Path(
    "/home/eduardo.hirata/repos/viscy/applications/DynaCell/dynacell_summary_table.csv"
).expanduser()
`Path.home() / "rel/path"` is likely what you want.
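A minimal sketch of this suggestion, assuming the CSV sits somewhere under the user's home directory. The relative path below is a placeholder and not the real repository layout.

```python
from pathlib import Path

# Hypothetical relative location; substitute the real path under your home directory.
RELATIVE_CSV = Path("repos") / "example" / "summary_table.csv"

# Path.home() resolves the current user's home directory at runtime,
# so no user name is hard-coded into the script.
csv_database_path = Path.home() / RELATIVE_CSV

# expanduser() only changes a path that starts with "~";
# on an absolute path such as "/home/<user>/..." it is a no-op.
same_path = (Path("~") / RELATIVE_CSV).expanduser()
```

Both spellings avoid baking one developer's home directory into the demo script.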
viscy/data/dynacell.py
Outdated
    return sample


class DynaCellDataBase:
Capital `XxxBase` reads like the name of a base class. Also, this is not a database with a runtime, but a dataframe normalizer.
viscy/data/dynacell.py
Outdated
# Extract zarr store paths
self._filtered_db["Zarr path"] = self._filtered_db["Path"].apply(
    lambda x: Path(*Path(x).parts[:-3])
Suggested change:
- lambda x: Path(*Path(x).parts[:-3])
+ lambda x: x.parent.parent.parent
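A quick sketch comparing the two spellings. It assumes the "Path" column may hold plain strings, in which case the value has to be wrapped in `Path` before `.parent` or `.parents` is available. The example path and the function names are made up for illustration.

```python
from pathlib import Path


def zarr_store_from_parts(position_path):
    """Original approach: rebuild the path from all but the last three parts."""
    return Path(*Path(position_path).parts[:-3])


def zarr_store_from_parents(position_path):
    """Suggested approach: walk up three levels; parents[2] == parent.parent.parent."""
    return Path(position_path).parents[2]


# Hypothetical position path: <store>/<row>/<column>/<field of view>
example = "/data/plate.zarr/A/1/0"
store_a = zarr_store_from_parts(example)
store_b = zarr_store_from_parents(example)
```

For a path with at least three components below the store, both functions return the same store path.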
"zarr_path": self.zarr_paths[idx], | ||
"cell_type": self.cell_types_per_store[idx], | ||
"organelle": self.organelles_per_store[idx], | ||
"infection_condition": self.infection_per_store[idx], |
Why are these first converted to lists?
pred_data = self.pred_database[i]

# Ensure target and prediction metadata match
self._validate_matching_metadata(target_data, pred_data, i)
How long is this loop?
For now we don't parallelize. Are you thinking of splitting it into batches?
I'm trying to understand the size of the loop. If it's long-running then vectorizing with pandas operations could be helpful.
# Check cell type
if target_data["cell_type"] != pred_data["cell_type"]:
    raise ValueError(
        f"Cell type mismatch at index {idx}: "
        f"target={target_data['cell_type']}, pred={pred_data['cell_type']}"
    )

# Check organelle
if target_data["organelle"] != pred_data["organelle"]:
    raise ValueError(
        f"Organelle mismatch at index {idx}: "
        f"target={target_data['organelle']}, pred={pred_data['organelle']}"
    )

# Check infection condition
if target_data["infection_condition"] != pred_data["infection_condition"]:
    raise ValueError(
        f"Infection condition mismatch at index {idx}: "
        f"target={target_data['infection_condition']}, pred={pred_data['infection_condition']}"
    )
This can be a loop over the string keys.
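One way the loop could look, as a sketch only. It is written as a free function rather than a method, and the constant name `METADATA_KEYS` is an assumption; the keys and error messages follow the three checks in the diff.

```python
# Metadata keys that must agree between target and prediction entries.
METADATA_KEYS = ("cell_type", "organelle", "infection_condition")


def validate_matching_metadata(target_data, pred_data, idx):
    """Raise ValueError on the first metadata key that differs."""
    for key in METADATA_KEYS:
        if target_data[key] != pred_data[key]:
            # "cell_type" -> "Cell type", matching the original messages.
            label = key.replace("_", " ").capitalize()
            raise ValueError(
                f"{label} mismatch at index {idx}: "
                f"target={target_data[key]}, pred={pred_data[key]}"
            )
```

Adding a new metadata field then only requires extending the tuple.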
A PyTorch Dataset providing paired target and prediction images / volumes from OME-Zarr
datasets.
Attributes:
Please use numpy-style docstring: #133
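For reference, a sketch of what a numpy-style class docstring looks like. The class name, parameters, and attribute are hypothetical stand-ins and do not reflect the actual dataset signature in this PR.

```python
class PairedImageDatasetSketch:
    """Paired target and prediction images or volumes from OME-Zarr datasets.

    Parameters
    ----------
    target_paths : list of str
        Hypothetical list of target store paths.
    pred_paths : list of str
        Hypothetical list of prediction store paths, in the same order.

    Attributes
    ----------
    n_pairs : int
        Number of target/prediction pairs.
    """

    def __init__(self, target_paths, pred_paths):
        if len(target_paths) != len(pred_paths):
            raise ValueError("target_paths and pred_paths must have equal length")
        self.target_paths = list(target_paths)
        self.pred_paths = list(pred_paths)
        self.n_pairs = len(self.target_paths)
```

The section titles are underlined with dashes of the same length, and each entry uses the `name : type` form followed by an indented description.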
cell_type: str
organelle: str
infection: str
This key is not the same as the CSV column name?
    Batch size for processing, by default 1
num_workers : int, optional
    Number of workers for data loading, by default 0
version : str, optional
What version is this?
target_database: Path,
pred_database: Path,
output_dir: Path,
method: str = "intensity",
Can this be all supported literals rather than a free string? Then lightning CLI can check the input type.
Yes, we will update this once we have all the possible ones. We might make it its own list for the organelles. Thanks.
# Add metadata if available
if "cell_type" in batch:
    metrics_dict["cell_type"] = batch["cell_type"]
if "organelle" in batch:
    metrics_dict["organelle"] = batch["organelle"]
if "infection_condition" in batch:
    metrics_dict["infection_condition"] = batch["infection_condition"]
if "infection_condition" in batch:
    metrics_dict["infection_condition"] = batch["infection_condition"]

self.logger.log_metrics(metrics_dict)
This is repeated in 2D and 3D. The function should return the metrics (as the 'compute' in the method name would suggest) and let the calling function log once.
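A rough sketch of that refactor, under the assumption that the 2D and 3D compute functions can each return a plain dict. The names `attach_metadata`, `compute_and_log`, `compute_fn`, and `log_fn` are hypothetical; in the PR, `log_fn` would presumably be `self.logger.log_metrics`.

```python
METADATA_KEYS = ("cell_type", "organelle", "infection_condition")


def attach_metadata(metrics_dict, batch):
    """Copy whichever metadata keys are present in the batch into the metrics."""
    for key in METADATA_KEYS:
        if key in batch:
            metrics_dict[key] = batch[key]
    return metrics_dict


def compute_and_log(batch, compute_fn, log_fn):
    """Compute metrics with the 2D or 3D variant, then log exactly once."""
    metrics = compute_fn(batch)  # returns a dict, does not log
    log_fn(attach_metadata(metrics, batch))  # single logging call site
    return metrics
```

With this split, the metadata handling and the logging call each live in one place instead of being duplicated per dimensionality.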
def _validate_metrics(self):
    """Validate the metrics parameter."""
    valid_metrics = ["mae", "mse", "ssim", "ms_ssim", "pearson", "cosine"]
These should be documented in the signature of `__init__`, i.e. `metrics: list[Literal["mae", "mse", "ssim", "ms_ssim", "pearson", "cosine"]]`.
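A sketch of how the `Literal` annotation could double as the runtime check, so the list of valid names is written only once. The alias `MetricName` and the class name are assumptions for illustration; the metric names are the ones in `valid_metrics`.

```python
from typing import Literal, get_args

# Single source of truth for the supported metric names.
MetricName = Literal["mae", "mse", "ssim", "ms_ssim", "pearson", "cosine"]


class IntensityMetricsSketch:
    """Hypothetical stand-in for the module whose __init__ takes `metrics`."""

    def __init__(self, metrics: list[MetricName]) -> None:
        valid_metrics = list(get_args(MetricName))
        for metric in metrics:
            if metric not in valid_metrics:
                raise ValueError(f"Metric '{metric}' not in {valid_metrics}")
        self.metrics = metrics
```

Type checkers and CLI parsers that understand `Literal` can then reject an unsupported name before the runtime check is reached.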
        if metric not in valid_metrics:
            raise ValueError(f"Metric '{metric}' not in {valid_metrics}")

def test_step(self, batch, batch_idx: int) -> None:
What is the type of `batch`?
from torchmetrics.functional import (
    cosine_similarity,
    mean_absolute_error,
    mean_squared_error,
    pearson_corrcoef,
    structural_similarity_index_measure,
)

from viscy.translation.evaluation_metrics import ms_ssim_25d
Why import separately from other metrics?
# Add metadata if available
if "cell_type" in batch:
    metrics_dict["cell_type"] = batch["cell_type"]
if "organelle" in batch:
    metrics_dict["organelle"] = batch["organelle"]
if "infection_condition" in batch:
    metrics_dict["infection_condition"] = batch["infection_condition"]
    )
elif metric == "ms_ssim":
    if pred.ndim > 1:
        metrics_dict["ms_ssim"] = ms_ssim_25d(pred, target)
This function is defined only for use in a loss function. (I made it up...)
elif metric == "pearson":
    metrics_dict["pearson"] = pearson_corrcoef(
        pred.flatten(), target.flatten()
    )
elif metric == "cosine":
    metrics_dict["cosine"] = cosine_similarity(
        pred.flatten(), target.flatten()
    )
Maybe use more precise naming?
(
    target.squeeze(2)
    if target.shape[2] == 1
    else target[:, :, target.shape[2] // 2]
Potentially dangerous: this could compute SSIM on a different region than the other metrics.
…e number of devices and nodes.
TODO:
torch.set_float32_matmul_precision("high")

# Suppress Lightning warnings for intentional CPU usage
os.environ["SLURM_NTASKS"] = "1"  # Suppress SLURM warning
Will this have side effects for the current shell?
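For context: assigning to `os.environ` changes the environment of the running Python process and of any child processes it starts afterwards, but it does not modify the shell that launched the script. If the override should be limited to one block of code, a small context manager is one option. This is a sketch; the helper name `temporary_env` is hypothetical.

```python
import os
from contextlib import contextmanager


@contextmanager
def temporary_env(name, value):
    """Set an environment variable for the duration of a with-block."""
    previous = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        # Restore the prior state, whether the variable existed or not.
        if previous is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = previous
```

Usage would look like `with temporary_env("SLURM_NTASKS", "1"): ...` around the code that triggers the warning, leaving the rest of the process unaffected.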
This PR adds: