Skip to content

Commit 47397de

Browse files
committed
1. Update typos.
2. Support name_prefix. 3. Optimize the settings of axes for plotting curves. 4. Support the setting of the `sharey` for plotting curves. 5. Better support for `include_` and `exclude_`. 6. Other updates.
1 parent 9248a0d commit 47397de

File tree

8 files changed

+199
-113
lines changed

8 files changed

+199
-113
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -275,5 +275,6 @@ gen
275275
/untracked/
276276
/configs/
277277
/*.py
278+
/*.sh
278279
/results/rgb_sod.md
279280
/results/htmls/*.html

metrics/cal_sod_matrics.py

Lines changed: 23 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -127,9 +127,9 @@ def cal_sod_matrics(
127127
# 真值名字列表
128128
gt_index_file = dataset_path.get("index_file")
129129
if gt_index_file:
130-
gt_name_list = get_name_list(data_path=gt_index_file, file_ext=gt_ext)
130+
gt_name_list = get_name_list(data_path=gt_index_file, name_suffix=gt_ext)
131131
else:
132-
gt_name_list = get_name_list(data_path=gt_root, file_ext=gt_ext)
132+
gt_name_list = get_name_list(data_path=gt_root, name_suffix=gt_ext)
133133
assert len(gt_name_list) > 0, "there is not ground truth."
134134

135135
# ==>> test the intersection between pre and gt for each method <<==
@@ -147,9 +147,14 @@ def cal_sod_matrics(
147147
continue
148148

149149
# 预测结果存放路径下的图片文件名字列表和扩展名称
150-
pre_ext = method_dataset_info["suffix"]
150+
pre_prefix = method_dataset_info.get("prefix", "")
151+
pre_suffix = method_dataset_info["suffix"]
151152
pre_root = method_dataset_info["path"]
152-
pre_name_list = get_name_list(data_path=pre_root, file_ext=pre_ext)
153+
pre_name_list = get_name_list(
154+
data_path=pre_root,
155+
name_prefix=pre_prefix,
156+
name_suffix=pre_suffix,
157+
)
153158

154159
# get the intersection
155160
eval_name_list = sorted(list(set(gt_name_list).intersection(pre_name_list)))
@@ -163,7 +168,8 @@ def cal_sod_matrics(
163168
names=eval_name_list,
164169
num_bits=num_bits,
165170
pre_root=pre_root,
166-
pre_ext=pre_ext,
171+
pre_prefix=pre_prefix,
172+
pre_suffix=pre_suffix,
167173
gt_root=gt_root,
168174
gt_ext=gt_ext,
169175
desc=f"[{dataset_name}({len(gt_name_list)}):{method_name}({len(pre_name_list)})]",
@@ -186,7 +192,16 @@ def cal_sod_matrics(
186192

187193

188194
def evaluate_data(
189-
names, num_bits, gt_root, gt_ext, pre_root, pre_ext, desc="", proc_idx=None, blocking=True
195+
names,
196+
num_bits,
197+
gt_root,
198+
gt_ext,
199+
pre_root,
200+
pre_prefix,
201+
pre_suffix,
202+
desc="",
203+
proc_idx=None,
204+
blocking=True,
190205
):
191206
metric_recoder = MetricRecorder()
192207
# https://github.com/tqdm/tqdm#parameters
@@ -204,7 +219,8 @@ def evaluate_data(
204219
gt_root=gt_root,
205220
pre_root=pre_root,
206221
img_name=name,
207-
pre_ext=pre_ext,
222+
pre_prefix=pre_prefix,
223+
pre_suffix=pre_suffix,
208224
gt_ext=gt_ext,
209225
to_normalize=False,
210226
)

metrics/draw_curves.py

Lines changed: 16 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,8 @@ def draw_curves(
1818
dataset_alias: dict = None,
1919
font_cfg: dict = None,
2020
subplots_cfg: dict = None,
21+
separated_legend: bool = False,
22+
sharey: bool = False,
2123
):
2224
if dataset_alias is None:
2325
dataset_alias = {}
@@ -38,14 +40,25 @@ def draw_curves(
3840

3941
curve_drawer = CurveDrawer(
4042
row_num=row_num,
41-
col_num=math.ceil(len(dataset_info.keys()) / row_num),
43+
num_subplots=len(dataset_info.keys()),
4244
font_cfg=font_cfg,
4345
subplots_cfg=subplots_cfg,
46+
separated_legend=separated_legend,
47+
sharey=sharey,
4448
)
4549

4650
for idx, dataset_name in enumerate(dataset_info.keys()):
4751
# 与cfg[dataset_info]中的key保持一致
4852
dataset_results = curves[dataset_name]
53+
curve_drawer.set_axis_property(
54+
idx=idx,
55+
title=dataset_alias.get(dataset_name, dataset_name).upper(),
56+
x_label=x_label,
57+
y_label=y_label,
58+
x_lim=x_lim,
59+
y_lim=y_lim,
60+
)
61+
4962
for method_name, method_info in drawing_info.items():
5063
# 与cfg[drawing_info]中的key保持一致
5164
method_results = dataset_results.get(method_name, None)
@@ -65,15 +78,10 @@ def draw_curves(
6578
y_data = method_results["fm"]
6679
x_data = np.linspace(0, 1, 256)
6780

68-
curve_drawer.draw_method_curve(
69-
curr_idx=idx,
70-
dataset_name=dataset_alias.get(dataset_name, dataset_name).upper(),
81+
curve_drawer.plot_at_axis(
82+
idx=idx,
7183
method_curve_setting=method_info["curve_setting"],
72-
x_label=x_label,
73-
y_label=y_label,
7484
x_data=x_data,
7585
y_data=y_data,
76-
x_lim=x_lim,
77-
y_lim=y_lim,
7886
)
7987
curve_drawer.show()

tools/append_results.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,8 @@ def get_args():
88
parser = argparse.ArgumentParser(description="A simple tool for merging two npy file.")
99
parser.add_argument("--old-npy", type=str, required=True)
1010
parser.add_argument("--new-npy", type=str, required=True)
11+
parser.add_argument("--method-names", type=str, nargs="+")
12+
parser.add_argument("--dataset-names", type=str, nargs="+")
1113
parser.add_argument("--out-npy", type=str, required=True)
1214
args = parser.parse_args()
1315
return args
@@ -19,13 +21,21 @@ def main():
1921
old_npy: dict = np.load(args.old_npy, allow_pickle=True).item()
2022

2123
for dataset_name, methods_info in new_npy.items():
24+
if args.dataset_names and dataset_name not in args.dataset_names:
25+
continue
26+
2227
print(f"[PROCESSING INFORMATION ABOUT DATASET {dataset_name}...]")
2328
old_methods_info = old_npy.get(dataset_name)
2429
if not old_methods_info:
2530
raise KeyError(f"{old_npy} doesn't contain the information about {dataset_name}.")
31+
2632
print(f"OLD_NPY: {list(old_methods_info.keys())}")
2733
print(f"NEW_NPY: {list(methods_info.keys())}")
34+
2835
for method_name, method_info in methods_info.items():
36+
if args.method_names and method_name not in args.method_names:
37+
continue
38+
2939
if method_name not in old_npy[dataset_name]:
3040
old_methods_info[method_name] = method_info
3141
print(f"MERGED_NPY: {list(old_methods_info.keys())}")

tools/check_path.py

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -22,16 +22,17 @@
2222

2323
total_msgs = []
2424
for method_name, method_info in methods_info.items():
25-
for dataset_name, resutls_info in method_info.items():
26-
if resutls_info is None:
25+
for dataset_name, results_info in method_info.items():
26+
if results_info is None:
2727
continue
2828

2929
dataset_mask_info = datasets_info[dataset_name]["mask"]
3030
mask_path = dataset_mask_info["path"]
3131
mask_suffix = dataset_mask_info["suffix"]
3232

33-
dir_path = resutls_info["path"]
34-
file_suffix = resutls_info["suffix"]
33+
dir_path = results_info["path"]
34+
file_prefix = results_info.get("prefix", "")
35+
file_suffix = results_info["suffix"]
3536

3637
if not os.path.exists(dir_path):
3738
total_msgs.append(f"{dir_path} 不存在")
@@ -41,12 +42,12 @@
4142
continue
4243
else:
4344
pred_names = [
44-
name[: -len(file_suffix)]
45+
name[len(file_prefix) : -len(file_suffix)]
4546
for name in os.listdir(dir_path)
46-
if name.endswith(file_suffix)
47+
if name.startswith(file_prefix) and name.endswith(file_suffix)
4748
]
4849
if len(pred_names) == 0:
49-
total_msgs.append(f"{dir_path} 中不包含后缀为{file_suffix}的文件")
50+
total_msgs.append(f"{dir_path} 中不包含前缀为{file_prefix}且后缀为{file_suffix}的文件")
5051
continue
5152

5253
mask_names = [
@@ -59,7 +60,9 @@
5960
total_msgs.append(f"{dir_path} 中数据名字与真值 {mask_path} 不匹配")
6061
elif len(intersection_names) != len(mask_names):
6162
difference_names = set(mask_names).difference(pred_names)
62-
total_msgs.append(f"{dir_path} 中数据{difference_names}与真值 {mask_path} 不一致")
63+
total_msgs.append(
64+
f"{dir_path} 中数据({len(list(pred_names))})与真值({len(list(mask_names))})不一致"
65+
)
6366

6467
if total_msgs:
6568
print(*total_msgs, sep="\n")

utils/generate_info.py

Lines changed: 51 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,30 @@ def _template_generator(method_info: dict, method_name: str) -> dict:
5454
return _template_generator
5555

5656

57+
def get_valid_elements(
58+
source: list,
59+
include_elements: list = None,
60+
exclude_elements: list = None,
61+
):
62+
if include_elements is None:
63+
include_elements = []
64+
if exclude_elements is None:
65+
exclude_elements = []
66+
assert not set(include_elements).intersection(
67+
exclude_elements
68+
), "`include_elements` and `exclude_elements` must have no intersection."
69+
70+
targeted = set(source).difference(exclude_elements)
71+
assert targeted, "`exclude_elements can not include all datasets."
72+
73+
if include_elements:
74+
# include_elements: [] or [dataset1_name, dataset2_name, ...]
75+
# only latter will be used to select datasets from `targeted`
76+
targeted = targeted.intersection(include_elements)
77+
78+
return list(targeted)
79+
80+
5781
def get_methods_info(
5882
methods_info_json: str,
5983
for_drawing: bool = False,
@@ -72,38 +96,31 @@ def get_methods_info(
7296
:return: methods_full_info
7397
"""
7498

75-
assert os.path.exists(methods_info_json) and os.path.isfile(
76-
methods_info_json
77-
), methods_info_json
78-
if include_methods and exclude_methods:
79-
raise ValueError("include_methods、exclude_methods 不可以同时非None")
80-
99+
assert os.path.isfile(methods_info_json), methods_info_json
81100
with open(methods_info_json, encoding="utf-8", mode="r") as f:
82-
methods_info = json.load(f, object_pairs_hook=OrderedDict) # 有序载入
83-
84-
if include_methods:
85-
for method_name in include_methods:
86-
if method_name not in methods_info:
87-
raise ValueError(f"The info of {method_name} is not in the methods_info_json.")
88-
if exclude_methods:
89-
for method_name in exclude_methods:
90-
if method_name not in methods_info:
91-
raise ValueError(f"The info of {method_name} is not in the methods_info_json.")
101+
methods_info = json.load(f, object_hook=OrderedDict) # 有序载入
92102

93103
if our_name:
94104
assert our_name in methods_info, f"{our_name} is not in json file."
95105

106+
targeted_methods = get_valid_elements(
107+
source=list(methods_info.keys()),
108+
include_elements=include_methods,
109+
exclude_elements=exclude_methods,
110+
)
111+
if our_name and our_name in targeted_methods:
112+
targeted_methods.pop(targeted_methods.index(our_name))
113+
targeted_methods.sort()
114+
targeted_methods.insert(0, our_name)
115+
96116
if for_drawing:
97117
info_generator = curve_info_generator()
98118
else:
99119
info_generator = simple_info_generator()
100120

101121
methods_full_info = []
102-
for method_name, method_path in methods_info.items():
103-
if include_methods and (method_name not in include_methods):
104-
continue
105-
if exclude_methods and (method_name in exclude_methods):
106-
continue
122+
for method_name in targeted_methods:
123+
method_path = methods_info[method_name]
107124

108125
if for_drawing and our_name and our_name == method_name:
109126
method_info = info_generator(method_path, method_name, line_color="red", line_width=3)
@@ -114,7 +131,9 @@ def get_methods_info(
114131

115132

116133
def get_datasets_info(
117-
datastes_info_json: str, include_datasets: list = None, exclude_datasets: list = None
134+
datastes_info_json: str,
135+
include_datasets: list = None,
136+
exclude_datasets: list = None,
118137
) -> OrderedDict:
119138
"""
120139
在json文件中存储的所有数据集的信息会被直接导出到一个字典中
@@ -125,30 +144,20 @@ def get_datasets_info(
125144
:return: datastes_full_info
126145
"""
127146

128-
assert os.path.exists(datastes_info_json) and os.path.isfile(
129-
datastes_info_json
130-
), datastes_info_json
131-
if include_datasets and exclude_datasets:
132-
raise ValueError("include_methods、exclude_methods 不可以同时非None")
133-
147+
assert os.path.isfile(datastes_info_json), datastes_info_json
134148
with open(datastes_info_json, encoding="utf-8", mode="r") as f:
135-
datasets_info = json.load(f, object_pairs_hook=OrderedDict) # 有序载入
149+
datasets_info = json.load(f, object_hook=OrderedDict) # 有序载入
136150

137-
if include_datasets:
138-
for dataset_name in include_datasets:
139-
if dataset_name not in datasets_info:
140-
raise ValueError(f"The info of {dataset_name} is not in the datasets_info_json.")
141-
if exclude_datasets:
142-
for dataset_name in exclude_datasets:
143-
if dataset_name not in datasets_info:
144-
raise ValueError(f"The info of {dataset_name} is not in the methods_info_json.")
151+
targeted_datasets = get_valid_elements(
152+
source=list(datasets_info.keys()),
153+
include_elements=include_datasets,
154+
exclude_elements=exclude_datasets,
155+
)
156+
targeted_datasets.sort()
145157

146158
datasets_full_info = []
147-
for dataset_name, data_path in datasets_info.items():
148-
if include_datasets and (dataset_name not in include_datasets):
149-
continue
150-
if exclude_datasets and (dataset_name in exclude_datasets):
151-
continue
159+
for dataset_name in targeted_datasets:
160+
data_path = datasets_info[dataset_name]
152161

153162
datasets_full_info.append((dataset_name, data_path))
154163
return OrderedDict(datasets_full_info)

utils/misc.py

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ def get_name_list_and_suffix(data_path: str) -> tuple:
4646
return name_list, file_ext
4747

4848

49-
def get_name_list(data_path: str, file_ext: str = None) -> list:
49+
def get_name_list(data_path: str, name_prefix: str = "", name_suffix: str = "") -> list:
5050
if os.path.isfile(data_path):
5151
assert data_path.endswith((".txt", ".lst"))
5252
data_list = []
@@ -55,15 +55,18 @@ def get_name_list(data_path: str, file_ext: str = None) -> list:
5555
while line:
5656
data_list.append(line)
5757
line = f.readline().strip()
58-
file_ext = None # 使用name list时不需要指定ext了
5958
else:
6059
data_list = os.listdir(data_path)
6160

62-
if file_ext is not None:
63-
# 如果提供file_ext,则基于file_ext来截断文件名,这可以用来应对具有额外名称后缀的数据
64-
name_list = [f[: -len(file_ext)] for f in data_list if f.endswith(file_ext)]
61+
name_list = data_list
62+
if not name_prefix and not name_suffix:
63+
name_list = [os.path.splitext(f)[0] for f in name_list]
6564
else:
66-
name_list = [os.path.splitext(f)[0] for f in data_list]
65+
name_list = [
66+
f[len(name_prefix) : -len(name_suffix)]
67+
for f in name_list
68+
if f.startswith(name_prefix) and f.endswith(name_suffix)
69+
]
6770

6871
name_list = list(set(name_list))
6972
return name_list
@@ -174,11 +177,12 @@ def get_gt_pre_with_name(
174177
gt_root: str,
175178
pre_root: str,
176179
img_name: str,
177-
pre_ext: str,
180+
pre_prefix: str,
181+
pre_suffix: str,
178182
gt_ext: str = ".png",
179183
to_normalize: bool = False,
180184
):
181-
img_path = os.path.join(pre_root, img_name + pre_ext)
185+
img_path = os.path.join(pre_root, pre_prefix + img_name + pre_suffix)
182186
gt_path = os.path.join(gt_root, img_name + gt_ext)
183187

184188
pre = imread_wich_checking(img_path, for_color=False)

0 commit comments

Comments (0)