Skip to content

Commit 897612b

Browse files
committed
* 2021年11月18日
1. 改正拼写错误,调整命名。 2. 支持预测结果中使用名称前缀 (例子可见`examples`文件夹中的`config_method_json_example.json`),现在搭配后缀,基本上可以应对所有可能的情形了。但是需要说明的是,目前不支持使用文件提供的映射关系,请确保预测名字中包含真值(不包含扩展名)名字。 3. 优化了绘图中的axis的设置,由于这些设置属于非常细粒度的设定,目前暂不支持使用终端选项配置,之后可能会使用特定的配置文件,例如json等来配置相关选项。 4. 支持绘图中使用共享的纵轴,即`sharey`,这可以用来辅助绘制独立的示例图。具体使用可见`examples`中的`plot_results.py`文件。 5. 优化了下 `include_` 与 `exclude_` 类选项的相关函数. 6. 添加了数据集和方法配置的json的例子。并且针对`examples`中提供的配置文件统一命名为`config_`. 7. 绘图支持对数据名和方法名使用别名。之前都是直接从各自的 `json` 配置文件中读取键来作为绘图中显示的名字,这对于名字有特殊标记(例如名字中想补充年份或者会议名字)时的使用不太方便和灵活。所以当前支持了使用额外的 `json` 配置文件来配置映射关系。例子可见 `examples` 中的 `alias_for_plotting.json` 。 8. 由于核心文件`eval_all.py`和`plot_results.py`的配置和调用方式发生了变化,所以为了便于大家的使用和修改,我提供了两个简单调用的`sh`文件,里面提供了这怒地各个选项的基本配置案例。linux用户可以直接使用`bash <sh_name>.sh`来执行,而windows用户麻烦些,还是自己参考着其中的配置项在终端自行配置吧!有问题欢迎提问,当然,如果大家可以提供windows直接调用的`bat`文件倒也欢迎PR哦!
1 parent 47397de commit 897612b

12 files changed

+640
-190
lines changed

examples/alias_for_plotting.json

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
{
2+
"dataset": {
3+
"Name_In_Json": "Name_In_SubFigure",
4+
"NJUD": "NJUD",
5+
"NLPR": "NLPR",
6+
"DUTRGBD": "DUTRGBD",
7+
"STEREO1000": "SETERE",
8+
"RGBD135": "RGBD135",
9+
"SSD": "SSD",
10+
"SIP": "SIP"
11+
},
12+
"method": {
13+
"Name_In_Json": "Name_In_Legend",
14+
"GateNet_2020": "GateNet",
15+
"MINet_R50_2020": "MINet"
16+
}
17+
}
+101
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
{
2+
"LFSD": {
3+
"root": "Path_Of_RGBDSOD_Datasets/LFSD",
4+
"image": {
5+
"path": "Path_Of_RGBDSOD_Datasets/LFSD/Image",
6+
"suffix": ".jpg"
7+
},
8+
"mask": {
9+
"path": "Path_Of_RGBDSOD_Datasets/LFSD/Mask",
10+
"suffix": ".png"
11+
}
12+
},
13+
"NJUD": {
14+
"root": "Path_Of_RGBDSOD_Datasets/NJUD_FULL",
15+
"image": {
16+
"path": "Path_Of_RGBDSOD_Datasets/NJUD_FULL/Image",
17+
"suffix": ".jpg"
18+
},
19+
"mask": {
20+
"path": "Path_Of_RGBDSOD_Datasets/NJUD_FULL/Mask",
21+
"suffix": ".png"
22+
}
23+
},
24+
"NLPR": {
25+
"root": "Path_Of_RGBDSOD_Datasets/NLPR_FULL",
26+
"image": {
27+
"path": "Path_Of_RGBDSOD_Datasets/NLPR_FULL/Image",
28+
"suffix": ".jpg"
29+
},
30+
"mask": {
31+
"path": "Path_Of_RGBDSOD_Datasets/NLPR_FULL/Mask",
32+
"suffix": ".png"
33+
}
34+
},
35+
"RGBD135": {
36+
"root": "Path_Of_RGBDSOD_Datasets/RGBD135",
37+
"image": {
38+
"path": "Path_Of_RGBDSOD_Datasets/RGBD135/Image",
39+
"suffix": ".jpg"
40+
},
41+
"mask": {
42+
"path": "Path_Of_RGBDSOD_Datasets/RGBD135/Mask",
43+
"suffix": ".png"
44+
}
45+
},
46+
"SIP": {
47+
"root": "Path_Of_RGBDSOD_Datasets/SIP",
48+
"image": {
49+
"path": "Path_Of_RGBDSOD_Datasets/SIP/Image",
50+
"suffix": ".jpg"
51+
},
52+
"mask": {
53+
"path": "Path_Of_RGBDSOD_Datasets/SIP/Mask",
54+
"suffix": ".png"
55+
}
56+
},
57+
"SSD": {
58+
"root": "Path_Of_RGBDSOD_Datasets/SSD",
59+
"image": {
60+
"path": "Path_Of_RGBDSOD_Datasets/SSD/Image",
61+
"suffix": ".jpg"
62+
},
63+
"mask": {
64+
"path": "Path_Of_RGBDSOD_Datasets/SSD/Mask",
65+
"suffix": ".png"
66+
}
67+
},
68+
"STEREO797": {
69+
"root": "Path_Of_RGBDSOD_Datasets/STEREO797",
70+
"image": {
71+
"path": "Path_Of_RGBDSOD_Datasets/STEREO797/Image",
72+
"suffix": ".jpg"
73+
},
74+
"mask": {
75+
"path": "Path_Of_RGBDSOD_Datasets/STEREO797/Mask",
76+
"suffix": ".png"
77+
}
78+
},
79+
"STEREO1000": {
80+
"root": "Path_Of_RGBDSOD_Datasets/STEREO1000",
81+
"image": {
82+
"path": "Path_Of_RGBDSOD_Datasets/STEREO1000/Image",
83+
"suffix": ".jpg"
84+
},
85+
"mask": {
86+
"path": "Path_Of_RGBDSOD_Datasets/STEREO1000/Mask",
87+
"suffix": ".png"
88+
}
89+
},
90+
"DUTRGBD": {
91+
"root": "Path_Of_RGBDSOD_Datasets/DUT-RGBD/Test",
92+
"image": {
93+
"path": "Path_Of_RGBDSOD_Datasets/DUT-RGBD/Test/Image",
94+
"suffix": ".jpg"
95+
},
96+
"mask": {
97+
"path": "Path_Of_RGBDSOD_Datasets/DUT-RGBD/Test/Mask",
98+
"suffix": ".png"
99+
}
100+
}
101+
}
File renamed without changes.
+78
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
{
2+
"Method1": {
3+
"PASCAL-S": {
4+
"path": "Path_Of_Method1/PASCAL-S/DGRL",
5+
"suffix": ".png"
6+
},
7+
"ECSSD": {
8+
"path": "Path_Of_Method1/ECSSD/DGRL",
9+
"suffix": ".png"
10+
},
11+
"HKU-IS": {
12+
"path": "Path_Of_Method1/HKU-IS/DGRL",
13+
"suffix": ".png"
14+
},
15+
"DUT-OMRON": {
16+
"path": "Path_Of_Method1/DUT-OMRON/DGRL",
17+
"suffix": ".png"
18+
},
19+
"DUTS-TE": {
20+
"path": "Path_Of_Method1/DUTS-TE/DGRL",
21+
"suffix": ".png"
22+
}
23+
},
24+
"Method2": {
25+
"PASCAL-S": {
26+
"path": "Path_Of_Method2/pascal",
27+
"prefix": "pascal_",
28+
"suffix": ".png"
29+
},
30+
"ECSSD": {
31+
"path": "Path_Of_Method2/ecssd",
32+
"prefix": "ecssd_",
33+
"suffix": ".png"
34+
},
35+
"HKU-IS": {
36+
"path": "Path_Of_Method2/hku",
37+
"prefix": "hku_",
38+
"suffix": ".png"
39+
},
40+
"DUT-OMRON": {
41+
"path": "Path_Of_Method2/duto",
42+
"prefix": "duto_",
43+
"suffix": ".png"
44+
},
45+
"DUTS-TE": {
46+
"path": "Path_Of_Method2/dut_te",
47+
"prefix": "dut_te_",
48+
"suffix": ".png"
49+
}
50+
},
51+
"Method3": {
52+
"PASCAL-S": {
53+
"path": "Path_Of_Method3/pascal",
54+
"prefix": "pascal_",
55+
"suffix": "_fused_sod.png"
56+
},
57+
"ECSSD": {
58+
"path": "Path_Of_Method3/ecssd",
59+
"prefix": "ecssd_",
60+
"suffix": "_fused_sod.png"
61+
},
62+
"HKU-IS": {
63+
"path": "Path_Of_Method3/hku",
64+
"prefix": "hku_",
65+
"suffix": "_fused_sod.png"
66+
},
67+
"DUT-OMRON": {
68+
"path": "Path_Of_Method3/duto",
69+
"prefix": "duto_",
70+
"suffix": "_fused_sod.png"
71+
},
72+
"DUTS-TE": {
73+
"path": "Path_Of_Method3/dut_te",
74+
"prefix": "dut_te_",
75+
"suffix": "_fused_sod.png"
76+
}
77+
}
78+
}
File renamed without changes.

examples/eval_all.py

+98-54
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,114 @@
11
# -*- coding: utf-8 -*-
2+
import argparse
23
import os
4+
import sys
5+
import warnings
6+
7+
sys.path.append("..")
38

49
from metrics import cal_sod_matrics
510
from utils.generate_info import get_datasets_info, get_methods_info
611
from utils.misc import make_dir
712

8-
"""
9-
Include: Fm Curve/PR Curves/MAE/(max/mean/weighted) Fmeasure/Smeasure/Emeasure
1013

11-
NOTE:
12-
* Our method automatically calculates the intersection of `pre` and `gt`.
13-
But it needs to have uniform naming rules for `pre` and `gt`.
14-
"""
14+
def get_args():
15+
parser = argparse.ArgumentParser(
16+
description="""Include: Fm Curve/PR Curves/MAE/(max/mean/weighted) Fmeasure/Smeasure/Emeasure
17+
NOTE:
18+
Our method automatically calculates the intersection of `pre` and `gt`.
19+
Currently supported pre naming rules: `prefix + gt_name_wo_ext + suffix_w_ext`
20+
""",
21+
formatter_class=argparse.RawTextHelpFormatter,
22+
)
23+
parser.add_argument("--dataset-json", required=True, type=str, help="Json file for datasets.")
24+
parser.add_argument("--method-json", required=True, type=str, help="Json file for methods.")
25+
parser.add_argument("--metric-npy", type=str, help="Npy file for saving metric results.")
26+
parser.add_argument("--curves-npy", type=str, help="Npy file for saving curve results.")
27+
parser.add_argument("--record-txt", type=str, help="Txt file for saving metric results.")
28+
parser.add_argument("--to-overwrite", action="store_true", help="To overwrite the txt file.")
29+
parser.add_argument("--record-xlsx", type=str, help="Xlsx file for saving metric results.")
30+
parser.add_argument(
31+
"--include-methods",
32+
type=str,
33+
nargs="+",
34+
help="Names of only specific methods you want to evaluate.",
35+
)
36+
parser.add_argument(
37+
"--exclude-methods",
38+
type=str,
39+
nargs="+",
40+
help="Names of some specific methods you do not want to evaluate.",
41+
)
42+
parser.add_argument(
43+
"--include-datasets",
44+
type=str,
45+
nargs="+",
46+
help="Names of only specific datasets you want to evaluate.",
47+
)
48+
parser.add_argument(
49+
"--exclude-datasets",
50+
type=str,
51+
nargs="+",
52+
help="Names of some specific datasets you do not want to evaluate.",
53+
)
54+
parser.add_argument(
55+
"--num-workers",
56+
type=int,
57+
default=4,
58+
help="Number of workers for multi-threading or multi-processing. Default: 4",
59+
)
60+
parser.add_argument(
61+
"--num-bits",
62+
type=int,
63+
default=3,
64+
help="Number of decimal places for showing results. Default: 3",
65+
)
66+
args = parser.parse_args()
1567

16-
total_info = dict(
17-
rgb_sod=dict(
18-
dataset="/home/lart/Coding/GIT/PySODEvalToolkit/configs/datasets/json/rgb_sod.json",
19-
method="/home/lart/Coding/GIT/PySODEvalToolkit/configs/methods/json/rgb_sod_methods.json",
20-
),
21-
rgb_cod=dict(
22-
dataset="/home/lart/Coding/GIT/PySODEvalToolkit/configs/datasets/json/rgb_cod.json",
23-
method="/home/lart/Coding/GIT/PySODEvalToolkit/configs/methods/json/rgb_cod_methods.json",
24-
),
25-
rgbd_sod=dict(
26-
dataset="/home/lart/Coding/GIT/PySODEvalToolkit/configs/datasets/json/rgbd_sod.json",
27-
method="/home/lart/Coding/GIT/PySODEvalToolkit/configs/methods/json/rgbd_sod_methods_ablation.json",
28-
),
29-
)
68+
if args.metric_npy is not None:
69+
make_dir(os.path.dirname(args.metric_npy))
70+
if args.curves_npy is not None:
71+
make_dir(os.path.dirname(args.curves_npy))
72+
if args.record_txt is not None:
73+
make_dir(os.path.dirname(args.record_txt))
74+
if args.record_xlsx is not None:
75+
make_dir(os.path.dirname(args.record_xlsx))
76+
if args.to_overwrite and not args.record_txt:
77+
warnings.warn("--to-overwrite only works with a valid --record-txt")
78+
return args
3079

31-
# 当前支持rgb_cod, rgb_sod, rgbd_sod
32-
data_type = "rgbd_sod"
33-
data_info = total_info[data_type]
3480

35-
# 存放输出文件的文件夹
36-
output_path = "../output"
37-
make_dir(output_path)
81+
def main():
82+
args = get_args()
3883

39-
# 包含所有数据集信息的字典
40-
dataset_info = get_datasets_info(
41-
datastes_info_json=data_info["dataset"],
42-
include_datasets=["NJUD"],
43-
# exclude_datasets=["LFSD"],
44-
)
45-
# 包含所有待比较模型结果的信息和绘图配置的字典
46-
drawing_info = get_methods_info(
47-
methods_info_json=data_info["method"],
48-
for_drawing=True,
49-
our_name="",
50-
include_methods=["CTMF_V16"],
51-
# exclude_methods=["UCNet_ABP", "UCNet_CVAE"],
52-
)
84+
# 包含所有数据集信息的字典
85+
datasets_info = get_datasets_info(
86+
datastes_info_json=args.dataset_json,
87+
include_datasets=args.include_datasets,
88+
exclude_datasets=args.exclude_datasets,
89+
)
90+
# 包含所有待比较模型结果的信息的字典
91+
methods_info = get_methods_info(
92+
methods_info_json=args.method_json,
93+
include_methods=args.include_methods,
94+
exclude_methods=args.exclude_methods,
95+
)
5396

54-
if __name__ == "__main__":
5597
# 确保多进程在windows上也可以正常使用
5698
cal_sod_matrics.cal_sod_matrics(
57-
data_type=data_type,
58-
to_append=True, # 是否保留之前的评估记录(针对txt_path文件有效)
59-
txt_path=os.path.join(output_path, f"{data_type}.txt"),
60-
xlsx_path=os.path.join(output_path, f"{data_type}.xlsx"),
61-
drawing_info=drawing_info,
62-
dataset_info=dataset_info,
63-
save_npy=True, # 是否将评估结果到npy文件中,该文件可用来绘制pr和fm曲线
64-
# 保存曲线指标数据的文件路径
65-
curves_npy_path=os.path.join(output_path, data_type + "_" + "curves.npy"),
66-
metrics_npy_path=os.path.join(output_path, data_type + "_" + "metrics.npy"),
67-
num_bits=3, # 评估结果保留的小数点后数据的位数
68-
num_workers=4,
69-
use_mp=False, # using multi-threading
99+
sheet_name="Results",
100+
to_append=not args.to_overwrite,
101+
txt_path=args.record_txt,
102+
xlsx_path=args.record_xlsx,
103+
methods_info=methods_info,
104+
datasets_info=datasets_info,
105+
curves_npy_path=args.curves_npy,
106+
metrics_npy_path=args.metric_npy,
107+
num_bits=args.num_bits,
108+
num_workers=args.num_workers,
109+
use_mp=False,
70110
)
111+
112+
113+
if __name__ == "__main__":
114+
main()

0 commit comments

Comments
 (0)