 # -*- coding: utf-8 -*-
+import argparse
 import os
+import sys
+import warnings
+
+sys.path.append("..")

 from metrics import cal_sod_matrics
 from utils.generate_info import get_datasets_info, get_methods_info
 from utils.misc import make_dir

-"""
-Include: Fm Curve/PR Curves/MAE/(max/mean/weighted) Fmeasure/Smeasure/Emeasure

-NOTE:
-* Our method automatically calculates the intersection of `pre` and `gt`.
-  But it needs to have uniform naming rules for `pre` and `gt`.
-"""
+def get_args():
+    parser = argparse.ArgumentParser(
+        description="""Include: Fm Curve/PR Curves/MAE/(max/mean/weighted) Fmeasure/Smeasure/Emeasure
+        NOTE:
+        Our method automatically calculates the intersection of `pre` and `gt`.
+        Currently supported pre naming rules: `prefix + gt_name_wo_ext + suffix_w_ext`
+        """,
+        formatter_class=argparse.RawTextHelpFormatter,
+    )
+    parser.add_argument("--dataset-json", required=True, type=str, help="Json file for datasets.")
+    parser.add_argument("--method-json", required=True, type=str, help="Json file for methods.")
+    parser.add_argument("--metric-npy", type=str, help="Npy file for saving metric results.")
+    parser.add_argument("--curves-npy", type=str, help="Npy file for saving curve results.")
+    parser.add_argument("--record-txt", type=str, help="Txt file for saving metric results.")
+    parser.add_argument("--to-overwrite", action="store_true", help="To overwrite the txt file.")
+    parser.add_argument("--record-xlsx", type=str, help="Xlsx file for saving metric results.")
+    parser.add_argument(
+        "--include-methods",
+        type=str,
+        nargs="+",
+        help="Names of only specific methods you want to evaluate.",
+    )
+    parser.add_argument(
+        "--exclude-methods",
+        type=str,
+        nargs="+",
+        help="Names of some specific methods you do not want to evaluate.",
+    )
+    parser.add_argument(
+        "--include-datasets",
+        type=str,
+        nargs="+",
+        help="Names of only specific datasets you want to evaluate.",
+    )
+    parser.add_argument(
+        "--exclude-datasets",
+        type=str,
+        nargs="+",
+        help="Names of some specific datasets you do not want to evaluate.",
+    )
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        default=4,
+        help="Number of workers for multi-threading or multi-processing. Default: 4",
+    )
+    parser.add_argument(
+        "--num-bits",
+        type=int,
+        default=3,
+        help="Number of decimal places for showing results. Default: 3",
+    )
+    args = parser.parse_args()

-total_info = dict(
-    rgb_sod=dict(
-        dataset="/home/lart/Coding/GIT/PySODEvalToolkit/configs/datasets/json/rgb_sod.json",
-        method="/home/lart/Coding/GIT/PySODEvalToolkit/configs/methods/json/rgb_sod_methods.json",
-    ),
-    rgb_cod=dict(
-        dataset="/home/lart/Coding/GIT/PySODEvalToolkit/configs/datasets/json/rgb_cod.json",
-        method="/home/lart/Coding/GIT/PySODEvalToolkit/configs/methods/json/rgb_cod_methods.json",
-    ),
-    rgbd_sod=dict(
-        dataset="/home/lart/Coding/GIT/PySODEvalToolkit/configs/datasets/json/rgbd_sod.json",
-        method="/home/lart/Coding/GIT/PySODEvalToolkit/configs/methods/json/rgbd_sod_methods_ablation.json",
-    ),
-)
+    if args.metric_npy is not None:
+        make_dir(os.path.dirname(args.metric_npy))
+    if args.curves_npy is not None:
+        make_dir(os.path.dirname(args.curves_npy))
+    if args.record_txt is not None:
+        make_dir(os.path.dirname(args.record_txt))
+    if args.record_xlsx is not None:
+        make_dir(os.path.dirname(args.record_xlsx))
+    if args.to_overwrite and not args.record_txt:
+        warnings.warn("--to-overwrite only works with a valid --record-txt")
+    return args

-# Currently supported: rgb_cod, rgb_sod, rgbd_sod
-data_type = "rgbd_sod"
-data_info = total_info[data_type]

-# Folder for the output files
-output_path = "../output"
-make_dir(output_path)
+def main():
+    args = get_args()

-# Dict with the info of all datasets
-dataset_info = get_datasets_info(
-    datastes_info_json=data_info["dataset"],
-    include_datasets=["NJUD"],
-    # exclude_datasets=["LFSD"],
-)
-# Dict with the info and drawing configs of all methods to be compared
-drawing_info = get_methods_info(
-    methods_info_json=data_info["method"],
-    for_drawing=True,
-    our_name="",
-    include_methods=["CTMF_V16"],
-    # exclude_methods=["UCNet_ABP", "UCNet_CVAE"],
-)
+    # Dict with the info of all datasets
+    datasets_info = get_datasets_info(
+        datastes_info_json=args.dataset_json,
+        include_datasets=args.include_datasets,
+        exclude_datasets=args.exclude_datasets,
+    )
+    # Dict with the info of all methods to be compared
+    methods_info = get_methods_info(
+        methods_info_json=args.method_json,
+        include_methods=args.include_methods,
+        exclude_methods=args.exclude_methods,
+    )

-if __name__ == "__main__":
     # Make sure multiprocessing also works properly on Windows
     cal_sod_matrics.cal_sod_matrics(
-        data_type=data_type,
-        to_append=True,  # whether to keep previous evaluation records (only affects the txt_path file)
-        txt_path=os.path.join(output_path, f"{data_type}.txt"),
-        xlsx_path=os.path.join(output_path, f"{data_type}.xlsx"),
-        drawing_info=drawing_info,
-        dataset_info=dataset_info,
-        save_npy=True,  # whether to save results to npy files, which can be used to plot PR and Fm curves
-        # file paths for saving the curve metric data
-        curves_npy_path=os.path.join(output_path, data_type + "_" + "curves.npy"),
-        metrics_npy_path=os.path.join(output_path, data_type + "_" + "metrics.npy"),
-        num_bits=3,  # number of decimal places kept in the results
-        num_workers=4,
-        use_mp=False,  # using multi-threading
+        sheet_name="Results",
+        to_append=not args.to_overwrite,
+        txt_path=args.record_txt,
+        xlsx_path=args.record_xlsx,
+        methods_info=methods_info,
+        datasets_info=datasets_info,
+        curves_npy_path=args.curves_npy,
+        metrics_npy_path=args.metric_npy,
+        num_bits=args.num_bits,
+        num_workers=args.num_workers,
+        use_mp=False,
     )
+
+
+if __name__ == "__main__":
+    main()
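For reference, a minimal usage sketch of the refactored CLI. The script filename (eval_sod.py) and the relative output paths are placeholders, not part of this diff; the JSON filenames mirror the configs that the removed hard-coded dict pointed to.

    # Hypothetical driver: builds the command line for the refactored script and runs it.
    # "eval_sod.py" and all paths below are illustrative placeholders.
    import subprocess

    subprocess.run(
        [
            "python", "eval_sod.py",
            "--dataset-json", "configs/datasets/json/rgbd_sod.json",
            "--method-json", "configs/methods/json/rgbd_sod_methods.json",
            "--record-txt", "output/rgbd_sod.txt",
            "--record-xlsx", "output/rgbd_sod.xlsx",
            "--curves-npy", "output/rgbd_sod_curves.npy",
            "--metric-npy", "output/rgbd_sod_metrics.npy",
            "--include-datasets", "NJUD",
            "--num-workers", "4",
            "--num-bits", "3",
        ],
        check=True,  # raise if the evaluation script exits with a non-zero status
    )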