Skip to content

Commit 59735d0

Browse files
committed
1. 修复绘图代码中的一些问题。
2. 完善对于 E-measure 绘图的支持。 3. 补充一些绘图的展示,这里以我自己的 RGB-D SOD 论文 CAVER (TIP 2023) 的论文结果为例。
1 parent 6c97b0c commit 59735d0

File tree

5 files changed

+57
-14
lines changed

5 files changed

+57
-14
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -282,3 +282,4 @@ gen
282282
/*.sh
283283
/results/rgb_sod.md
284284
/results/htmls/*.html
285+
!/.github/assets/*.jpg

metrics/draw_curves.py

+25-10
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
def draw_curves(
12-
for_pr: bool = True,
12+
mode: str,
1313
axes_setting: dict = None,
1414
curves_npy_path: list = None,
1515
row_num: int = 1,
@@ -20,14 +20,13 @@ def draw_curves(
2020
ncol_of_legend: int = 1,
2121
separated_legend: bool = False,
2222
sharey: bool = False,
23-
line_styles=("-", "--"),
2423
line_width=3,
2524
save_name=None,
2625
):
2726
"""A better curve painter!
2827
2928
Args:
30-
for_pr (bool, optional): Plot for PR curves or FM curves. Defaults to True.
29+
mode (str): `pr` for PR curves, `fm` for F-measure curves, and `em' for E-measure curves.
3130
axes_setting (dict, optional): Setting for axes. Defaults to None.
3231
curves_npy_path (list, optional): Paths of curve npy files. Defaults to None.
3332
row_num (int, optional): Number of rows. Defaults to 1.
@@ -38,11 +37,10 @@ def draw_curves(
3837
ncol_of_legend (int, optional): Number of columns for the legend. Defaults to 1.
3938
separated_legend (bool, optional): Use the separated legend. Defaults to False.
4039
sharey (bool, optional): Use a shared y-axis. Defaults to False.
41-
line_styles (tuple, optional): Styles of lines. Defaults to ("-", "--").
4240
line_width (int, optional): Width of lines. Defaults to 3.
4341
save_name (str, optional): Name or path (without the extension format). Defaults to None.
4442
"""
45-
mode = "pr" if for_pr else "fm"
43+
assert mode in ["pr", "fm", "em"]
4644
save_name = save_name or mode
4745
mode_axes_setting = axes_setting[mode]
4846

@@ -97,23 +95,36 @@ def draw_curves(
9795
# assert len(our_methods) <= len(line_styles)
9896
else:
9997
our_methods = []
98+
num_our_methods = len(our_methods)
10099

101-
# Give each method a unique color.
100+
# Give each method a unique color and style.
102101
color_table = sorted(
103102
[
104103
color
105104
for name, color in colors.cnames.items()
106105
if name not in ["red", "white"] or not name.startswith("light") or "gray" in name
107106
]
108107
)
108+
style_table = ["-", "--", "-.", ":", "."]
109+
109110
unique_method_settings = OrderedDict()
110111
for i, method_name in enumerate(target_unique_method_names):
112+
if i < num_our_methods:
113+
line_color = "red"
114+
line_style = style_table[i % len(style_table)]
115+
else:
116+
other_idx = i - num_our_methods
117+
line_color = color_table[other_idx]
118+
line_style = style_table[other_idx % 2]
119+
111120
unique_method_settings[method_name] = {
112-
"line_color": "red" if i < len(our_methods) else color_table[i],
121+
"line_color": line_color,
113122
"line_label": method_aliases.get(method_name, method_name),
114-
"line_style": line_styles[i % len(line_styles)],
123+
"line_style": line_style,
115124
"line_width": line_width,
116125
}
126+
# ensure that our methods are drawn last to avoid being overwritten by other methods
127+
target_unique_method_names.reverse()
117128

118129
curve_drawer = CurveDrawer(
119130
row_num=row_num,
@@ -135,9 +146,13 @@ def draw_curves(
135146
y_ticks=y_ticks,
136147
)
137148

138-
for method_name, method_setting in unique_method_settings.items():
149+
for method_name in target_unique_method_names:
150+
method_setting = unique_method_settings[method_name]
151+
139152
if method_name not in dataset_results:
140-
raise KeyError(f"{method_name} not in {sorted(dataset_results.keys())}")
153+
print(f"{method_name} will be skipped for {dataset_name}!")
154+
continue
155+
141156
method_results = dataset_results[method_name]
142157
if mode == "pr":
143158
y_data = method_results.get("p")

plot.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def get_args():
7373
parser.add_argument(
7474
"--mode",
7575
type=str,
76-
choices=["pr", "fm"],
76+
choices=["pr", "fm", "em"],
7777
default="pr",
7878
help="Mode for plotting. Default: pr",
7979
)
@@ -96,7 +96,7 @@ def main(args):
9696
dataset_aliases = aliases.get("dataset")
9797

9898
draw_curves.draw_curves(
99-
for_pr=args.mode == "pr",
99+
mode=args.mode,
100100
# 不同曲线的绘图配置
101101
axes_setting={
102102
# pr曲线的配置
@@ -113,6 +113,13 @@ def main(args):
113113
"x_ticks": np.linspace(0, 1, 6),
114114
"y_ticks": np.linspace(0.6, 1, 6),
115115
},
116+
# em曲线的配置
117+
"em": {
118+
"x_label": "Threshold",
119+
"y_label": r"E$_{m}$",
120+
"x_ticks": np.linspace(0, 1, 6),
121+
"y_ticks": np.linspace(0.7, 1, 6),
122+
},
116123
},
117124
curves_npy_path=args.curves_npys,
118125
row_num=args.num_rows,

readme.md

+19-1
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ A Python-based image binary segmentation evaluation toolbox.
176176

177177
### 为灰度图像的评估绘制曲线
178178

179-
可以使用 `plot.py` 来读取 `.npy` 文件按需对指定方法和数据集的结果整理并绘制 `PR` 曲线和 `Fm` 曲线. 该脚本用法可见 `python plot.py --help` 的输出. 按照自己需求添加配置项并执行即可.
179+
可以使用 `plot.py` 来读取 `.npy` 文件按需对指定方法和数据集的结果整理并绘制 `PR` , `F-measure``E-measure` 曲线. 该脚本用法可见 `python plot.py --help` 的输出. 按照自己需求添加配置项并执行即可.
180180

181181
最基本的一条是请按照子图数量, 合理地指定配置文件中的 `figure.figsize` 项的数值.
182182

@@ -223,6 +223,20 @@ python plot.py --style-cfg examples/single_row_style.yml --num-rows 1 --curves-n
223223
python plot.py --style-cfg examples/single_row_style.yml --num-rows 1 --curves-npys output/rgb_sod/curves.npy --our-methods MINet_R50_2020 --num-col-legend 1 --mode pr --separated-legend --sharey --save-name output/rgb_sod/complex_curve_pr
224224
```
225225

226+
## 绘图示例
227+
228+
**Precision-Recall Curve**:
229+
230+
![PRCurves](https://user-images.githubusercontent.com/26847524/227249768-a41ef076-6355-4b96-a291-fc0e071d9d35.jpg)
231+
232+
**F-measure Curve**:
233+
234+
![fm-curves](https://user-images.githubusercontent.com/26847524/227249746-f61d7540-bb73-464d-bccf-9a36323dec47.jpg)
235+
236+
**E-measure Curve**:
237+
238+
![em-curves](https://user-images.githubusercontent.com/26847524/227249727-8323d5cf-ddd7-427b-8152-b8f47781c4e3.jpg)
239+
226240
## 相关文献
227241

228242
```text
@@ -282,6 +296,10 @@ python plot.py --style-cfg examples/single_row_style.yml --num-rows 1 --curves-n
282296

283297
## 更新日志
284298

299+
* 2023年3月23日
300+
1. 修复绘图代码中的一些问题。
301+
2. 完善对于 E-measure 绘图的支持。
302+
3. 补充一些绘图的展示,这里以我自己的 RGB-D SOD 论文 CAVER (TIP 2023) 的论文结果为例。
285303
* 2023年3月20日
286304
1. 提供更丰富的指标的支持。
287305
2. 更新`readme.md`和示例文件。

utils/print_formatter.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ def formatter_for_tabulate(
9090
table = []
9191
headers = ["methods"]
9292
for method_name in method_names:
93-
metric_info = dataset_metrics[method_name]
93+
metric_info = dataset_metrics.get(method_name)
94+
if metric_info is None:
95+
continue
9496

9597
if method_name_length:
9698
method_name = clip_string(method_name, max_length=method_name_length, mode="left")

0 commit comments

Comments
 (0)