Skip to content

Commit 76f4b98

Browse files
authored
Merge pull request #59 from janelia-cellmap/rhoadesScholar_update
2 parents 36f64a9 + acdd45c commit 76f4b98

File tree

10 files changed

+160
-61
lines changed

10 files changed

+160
-61
lines changed

cellmap_flow/cli/multiple_cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def main():
7171
logger.error(
7272
"Example: cellmap_flow_multiple --data-path /some/shared/path --dacapo -r run_1 -it 60 --dacapo -r run_2 -it 50 --script -s /path/to/script"
7373
)
74-
logger.error("Now we will just open the raw data ..")
74+
logger.error("Now we will just open the raw data ...")
7575

7676
# Extract data path
7777
data_path = None

cellmap_flow/dashboard/app.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,19 @@
2727
import time
2828

2929
logger = logging.getLogger(__name__)
30-
app = Flask(__name__)
30+
# Explicitly set template and static folder paths for package installation
31+
template_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates")
32+
static_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "static")
33+
app = Flask(__name__, template_folder=template_dir, static_folder=static_dir)
3134
CORS(app)
3235
NEUROGLANCER_URL = None
3336
INFERENCE_SERVER = None
34-
CustomCodeFolder = "/Users/zouinkhim/Desktop/cellmap/cellmap-flow/example/example_norm"
37+
CUSTOM_CODE_FOLDER = os.path.expanduser(
38+
os.environ.get(
39+
"CUSTOM_CODE_FOLDER",
40+
"~/Desktop/cellmap/cellmap-flow/example/example_norm",
41+
)
42+
)
3543

3644

3745
@app.route("/")
@@ -147,7 +155,7 @@ def process():
147155
# Save custom code to a file with date and time
148156
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
149157
filename = f"custom_code_{timestamp}.py"
150-
filepath = os.path.join(CustomCodeFolder, filename)
158+
filepath = os.path.join(CUSTOM_CODE_FOLDER, filename)
151159

152160
with open(filepath, "w") as file:
153161
file.write(custom_code)
Lines changed: 25 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,46 @@
11
<!DOCTYPE html>
22
<html lang="en">
3+
34
<head>
45
<meta charset="UTF-8" />
56
<title>CellMap Flow Dashboard</title>
67
<!-- Bootstrap 5 (optional, for styling) -->
7-
<link
8-
href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css"
9-
rel="stylesheet"
10-
/>
8+
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css" rel="stylesheet" />
119
<link rel="icon" href="{{ url_for('static', filename='img/favicon.ico') }}" type="image/x-icon" />
1210
<link rel="stylesheet" href="{{ url_for('static', filename='css/dark.css') }}" />
1311
</head>
12+
1413
<body>
1514
<nav class="navbar navbar-expand-lg navbar-dark bg-dark">
1615
<div class="container-fluid">
17-
<img
18-
src="https://raw.githubusercontent.com/janelia-cellmap/cellmap-flow/refs/heads/main/img/CMFLOW_dark.png"
19-
alt="CellMap Flow"
20-
width="200"
21-
/>
22-
<button
23-
class="btn btn-outline-light"
24-
type="button"
25-
onclick="toggleDashboard()"
26-
>
16+
<img src="https://raw.githubusercontent.com/janelia-cellmap/cellmap-flow/refs/heads/main/img/CMFLOW_dark.png"
17+
alt="CellMap Flow" width="200" />
18+
<button class="btn btn-outline-light" type="button" onclick="toggleDashboard()">
2719
Toggle Dashboard
2820
</button>
2921
</div>
3022
</nav>
3123

3224
<div class="row bg-dark text-light" style="height: 100vh;">
3325
<div class="col-9" id="iframe-column">
34-
<iframe src="{{ neuroglancer_url }}" style="width:100%;height:100vh;" title="Example IFrame" id="my_iframe">
26+
<iframe src="{{ neuroglancer_url }}" style="width:100%;height:100vh;" title="Neuroglancer Data Viewer"
27+
id="my_iframe">
3528
Your browser does not support iframes.
3629
</iframe>
3730
</div>
3831
<div class="col-3" id="dashboard-column">
3932
{% include "_dashboard.html" %}
4033
</div>
41-
34+
4235
</div>
4336

44-
<script
45-
src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/js/bootstrap.bundle.min.js"
46-
></script>
47-
<script src="{{ url_for('static', filename='js/dashboard_script.js') }}"></script>
48-
<script src="https://cdn.jsdelivr.net/npm/skulpt@1.2.0/dist/skulpt.min.js"></script>
49-
<script src="https://cdn.jsdelivr.net/npm/skulpt@1.2.0/dist/skulpt-stdlib.js"></script>
50-
<!-- shared_submit.js or inside a <script> tag in your base layout -->
37+
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/js/bootstrap.bundle.min.js"></script>
38+
<script src="{{ url_for('static', filename='js/dashboard_script.js') }}"></script>
39+
<script src="https://cdn.jsdelivr.net/npm/skulpt@1.2.0/dist/skulpt.min.js"></script>
40+
<script src="https://cdn.jsdelivr.net/npm/skulpt@1.2.0/dist/skulpt-stdlib.js"></script>
41+
<!-- shared_submit.js or inside a <script> tag in your base layout -->
5142
<script>
52-
document.addEventListener('keydown', function(event) {
43+
document.addEventListener('keydown', function (event) {
5344
// Check if the key pressed is "Enter"
5445
if (event.key === "Enter") {
5546
// Trigger a click event on the button
@@ -59,27 +50,27 @@
5950
document.addEventListener("DOMContentLoaded", function () {
6051
// Find all elements with id="submitAll" (there could be 2 if both partials exist)
6152
const submitAllButtons = document.querySelectorAll("#submitAll");
62-
53+
6354
// A single function that merges data & sends the fetch
6455
function handleSubmitAll() {
6556
// 1) Gather data from input_norm partial
6657
const inputNorm = (window.gatherInputNormData)
6758
? window.gatherInputNormData()
6859
: {};
69-
60+
7061
// 2) Gather data from postprocess partial
7162
const postprocess = (window.gatherPostProcessData)
7263
? window.gatherPostProcessData()
7364
: {};
74-
65+
7566
// 3) Build final combined payload
7667
const finalPayload = {
7768
input_norm: inputNorm,
7869
postprocess: postprocess,
7970
};
80-
71+
8172
console.log("Combined Payload:", finalPayload);
82-
73+
8374
// 4) Submit via fetch
8475
fetch("/api/process", {
8576
method: "POST",
@@ -94,7 +85,7 @@
9485
if (logAreaNorm) {
9586
logAreaNorm.value += "Server response:\n" + JSON.stringify(data, null, 2) + "\n";
9687
}
97-
88+
9889
const logAreaPost = document.getElementById("submissionLog_postProcess");
9990
if (logAreaPost) {
10091
logAreaPost.value += "Server response:\n" + JSON.stringify(data, null, 2) + "\n";
@@ -105,13 +96,14 @@
10596
alert("Error submitting combined data");
10697
});
10798
}
108-
99+
109100
// Attach the same handler to each #submitAll
110101
submitAllButtons.forEach((btn) => {
111102
btn.addEventListener("click", handleSubmitAll);
112103
});
113104
});
114-
</script>
115-
105+
</script>
106+
116107
</body>
108+
117109
</html>

cellmap_flow/globals.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,15 @@ def __new__(cls):
2929
cls._instance.postprocess = []
3030
cls._instance.viewer = None
3131
cls._instance.dataset_path = None
32-
cls._instance.model_catalog = {}
32+
# cls._instance.model_catalog = {}
3333
# Uncomment and adjust if you want to load the model catalog:
34-
# cls._instance.model_catalog = load_model_paths(
35-
# os.path.normpath(os.path.join(os.path.dirname(__file__), os.pardir, "models", "models.yaml"))
36-
# )
34+
cls._instance.model_catalog = load_model_paths(
35+
os.path.normpath(
36+
os.path.join(
37+
os.path.dirname(__file__), os.pardir, "models", "models.yaml"
38+
)
39+
)
40+
)
3741
cls._instance.queue = "gpu_h100"
3842
cls._instance.charge_group = "cellmap"
3943
cls._instance.neuroglancer_thread = None

cellmap_flow/image_data_interface.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1+
import os
2+
import zarr
13
from cellmap_flow.utils.ds import (
4+
find_closest_scale,
25
get_ds_info,
36
open_ds_tensorstore,
47
to_ndarray_tensorstore,
58
)
9+
import logging
10+
11+
logger = logging.getLogger(__name__)
612

713

814
class ImageDataInterface:
@@ -16,12 +22,19 @@ def __init__(
1622
concurrency_limit=1,
1723
normalize=True,
1824
):
25+
# if multiscale dataset, get scale for voxel size
26+
if not isinstance(zarr.open(dataset_path, mode="r"), zarr.core.Array):
27+
scale, _, _ = find_closest_scale(dataset_path, voxel_size)
28+
logger.info(f"found scale {scale} for voxel size {voxel_size}")
29+
dataset_path = os.path.join(dataset_path, scale)
30+
logger.info(f"using dataset path {dataset_path}")
1931
self.path = dataset_path
2032
self.filetype = (
2133
"zarr" if dataset_path.rfind(".zarr") > dataset_path.rfind(".n5") else "n5"
2234
)
2335
self.swap_axes = self.filetype == "n5"
2436
self._ts = None
37+
2538
self.voxel_size, self.chunk_shape, self.shape, self.roi, self.swap_axes = (
2639
get_ds_info(dataset_path)
2740
)

cellmap_flow/norm/input_normalize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def __init__(self, mean=0.0, std=1.0):
182182
def dtype(self):
183183
return np.float32
184184

185-
def normalize(self, data: np.ndarray) -> np.ndarray:
185+
def _process(self, data: np.ndarray) -> np.ndarray:
186186
return (data - self.mean) / self.std
187187

188188

cellmap_flow/utils/data.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,13 @@ def _get_config(self):
6666
from cellmap_flow.utils.load_py import load_safe_config
6767

6868
config = load_safe_config(self.script_path)
69+
if not hasattr(config, "block_shape"):
70+
setattr(
71+
config,
72+
"block_shape",
73+
np.array(tuple(config.write_shape) + (config.output_channels,)),
74+
)
75+
6976
return config
7077

7178

cellmap_flow/utils/ds.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,20 @@
11
# %%
2-
import zarr
3-
from funlib.geometry import Coordinate
2+
import json
43
import logging
5-
import tensorstore as ts
6-
import numpy as np
7-
from funlib.geometry import Coordinate
8-
from funlib.geometry import Roi
94
import os
105
import re
11-
import zarr
12-
from skimage.measure import block_reduce
13-
from funlib.geometry import Coordinate, Roi
6+
from typing import Sequence, Union
147

8+
import h5py
9+
import numpy as np
10+
import s3fs
11+
import tensorstore as ts
1512
import zarr
13+
from funlib.geometry import Coordinate, Roi
14+
from skimage.measure import block_reduce
1615
from zarr.n5 import N5FSStore
17-
import h5py
18-
import json
19-
import logging
20-
import os
21-
from typing import Union, Sequence
16+
2217
from cellmap_flow.globals import g
23-
import s3fs
2418

2519

2620
def get_scale_info(zarr_grp):
@@ -49,6 +43,26 @@ def find_target_scale(zarr_grp_path, target_resolution):
4943
return target_scale, offsets[target_scale], shapes[target_scale]
5044

5145

46+
def find_closest_scale(zarr_grp_path, target_resolution):
47+
zarr_grp = zarr.open(zarr_grp_path, mode="r")
48+
offsets, resolutions, shapes = get_scale_info(zarr_grp)
49+
target_scale = None
50+
last_scale = None
51+
for scale, res in resolutions.items():
52+
if last_scale is None:
53+
last_scale = scale
54+
if Coordinate(res) == Coordinate(target_resolution):
55+
target_scale = scale
56+
break
57+
elif any((r > t for r, t in zip(res, target_resolution))):
58+
target_scale = last_scale
59+
break
60+
last_scale = scale
61+
if target_scale is None:
62+
target_scale = last_scale
63+
return target_scale, offsets[target_scale], shapes[target_scale]
64+
65+
5266
# Ensure tensorstore does not attempt to use GCE credentials
5367
os.environ["GCE_METADATA_ROOT"] = "metadata.google.internal.invalid"
5468

cellmap_flow/utils/load_py.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
## copied from https://github.com/janelia-cellmap/cellmap-segmentation-challenge/blob/a9525b31502abb7ea01e10c16340bbc1056cf1fc/src/cellmap_segmentation_challenge/utils/security.py
2-
1+
# copied from https://github.com/janelia-cellmap/cellmap-segmentation-challenge/blob/6e9d842b9a90b0df22aa07946a4d1deed5c27504/src/cellmap_segmentation_challenge/utils/security.py
32
import ast
43
import os
4+
import inspect
5+
from typing import Any
56

67
from upath import UPath
78

@@ -67,7 +68,9 @@ def load_safe_config(config_path, force_safe=os.getenv("FORCE_SAFE_CONFIG", Fals
6768
for issue in issues:
6869
print(f" - {issue}")
6970
if force_safe:
70-
raise ValueError("Unsafe script detected; loading aborted.")
71+
raise ValueError(
72+
"Unsafe script detected; loading aborted. You can set the environment variable FORCE_SAFE_CONFIG=False or pass force_safe=False to override."
73+
)
7174

7275
# Load the config module if script is safe
7376
config_path = UPath(config_path)
@@ -76,6 +79,22 @@ def load_safe_config(config_path, force_safe=os.getenv("FORCE_SAFE_CONFIG", Fals
7679
try:
7780
with open(config_path, "r") as config_file:
7881
code = config_file.read()
82+
# Parse the code into an AST
83+
tree = ast.parse(code)
84+
85+
# Define a node transformer to replace __file__ with the config path
86+
class ReplaceFileNode(ast.NodeTransformer):
87+
def visit_Name(self, node):
88+
if node.id == "__file__":
89+
return ast.Constant(value=str(config_path), kind=None)
90+
return node
91+
92+
# Transform the AST
93+
transformer = ReplaceFileNode()
94+
tree = transformer.visit(tree)
95+
96+
# Convert the modified AST back to source code
97+
code = ast.unparse(tree)
7998
exec(code, config_namespace)
8099
# Extract the config object from the namespace
81100
config = Config(**config_namespace)
@@ -91,3 +110,39 @@ def load_safe_config(config_path, force_safe=os.getenv("FORCE_SAFE_CONFIG", Fals
91110
class Config:
92111
def __init__(self, **kwargs):
93112
self.__dict__.update(kwargs)
113+
self.kwargs = kwargs
114+
115+
def to_dict(self):
116+
"""
117+
Returns the configuration as a dictionary.
118+
"""
119+
return self.kwargs
120+
121+
def serialize(self):
122+
"""
123+
Serializes the configuration to a string representation.
124+
"""
125+
serialized = {}
126+
for key, value in self.kwargs.items():
127+
if (
128+
inspect.ismodule(value)
129+
or inspect.isclass(value)
130+
or inspect.isfunction(value)
131+
or inspect.isbuiltin(value)
132+
):
133+
# Skip modules, classes, and functions
134+
continue
135+
elif "__" in key:
136+
# Skip private attributes
137+
continue
138+
elif not isinstance(value, (int, float, str, bool)):
139+
serialized[key] = str(value)
140+
else:
141+
serialized[key] = value
142+
return serialized
143+
144+
def get(self, key: str, default: Any = None) -> Any:
145+
"""
146+
Gets the value of a configuration key.
147+
"""
148+
return self.kwargs.get(key, default)

0 commit comments

Comments
 (0)