9
9
10
10
11
11
def draw_curves (
12
- for_pr : bool = True ,
12
+ mode : str ,
13
13
axes_setting : dict = None ,
14
14
curves_npy_path : list = None ,
15
15
row_num : int = 1 ,
@@ -20,14 +20,13 @@ def draw_curves(
20
20
ncol_of_legend : int = 1 ,
21
21
separated_legend : bool = False ,
22
22
sharey : bool = False ,
23
- line_styles = ("-" , "--" ),
24
23
line_width = 3 ,
25
24
save_name = None ,
26
25
):
27
26
"""A better curve painter!
28
27
29
28
Args:
30
- for_pr (bool, optional ): Plot for PR curves or FM curves. Defaults to True .
29
+ mode (str ): `pr` for PR curves, `fm` for F-measure curves, and `em' for E-measure curves .
31
30
axes_setting (dict, optional): Setting for axes. Defaults to None.
32
31
curves_npy_path (list, optional): Paths of curve npy files. Defaults to None.
33
32
row_num (int, optional): Number of rows. Defaults to 1.
@@ -38,11 +37,10 @@ def draw_curves(
38
37
ncol_of_legend (int, optional): Number of columns for the legend. Defaults to 1.
39
38
separated_legend (bool, optional): Use the separated legend. Defaults to False.
40
39
sharey (bool, optional): Use a shared y-axis. Defaults to False.
41
- line_styles (tuple, optional): Styles of lines. Defaults to ("-", "--").
42
40
line_width (int, optional): Width of lines. Defaults to 3.
43
41
save_name (str, optional): Name or path (without the extension format). Defaults to None.
44
42
"""
45
- mode = "pr" if for_pr else "fm"
43
+ assert mode in [ "pr" , "fm" , "em" ]
46
44
save_name = save_name or mode
47
45
mode_axes_setting = axes_setting [mode ]
48
46
@@ -97,23 +95,36 @@ def draw_curves(
97
95
# assert len(our_methods) <= len(line_styles)
98
96
else :
99
97
our_methods = []
98
+ num_our_methods = len (our_methods )
100
99
101
- # Give each method a unique color.
100
+ # Give each method a unique color and style .
102
101
color_table = sorted (
103
102
[
104
103
color
105
104
for name , color in colors .cnames .items ()
106
105
if name not in ["red" , "white" ] or not name .startswith ("light" ) or "gray" in name
107
106
]
108
107
)
108
+ style_table = ["-" , "--" , "-." , ":" , "." ]
109
+
109
110
unique_method_settings = OrderedDict ()
110
111
for i , method_name in enumerate (target_unique_method_names ):
112
+ if i < num_our_methods :
113
+ line_color = "red"
114
+ line_style = style_table [i % len (style_table )]
115
+ else :
116
+ other_idx = i - num_our_methods
117
+ line_color = color_table [other_idx ]
118
+ line_style = style_table [other_idx % 2 ]
119
+
111
120
unique_method_settings [method_name ] = {
112
- "line_color" : "red" if i < len ( our_methods ) else color_table [ i ] ,
121
+ "line_color" : line_color ,
113
122
"line_label" : method_aliases .get (method_name , method_name ),
114
- "line_style" : line_styles [ i % len ( line_styles )] ,
123
+ "line_style" : line_style ,
115
124
"line_width" : line_width ,
116
125
}
126
+ # ensure that our methods are drawn last to avoid being overwritten by other methods
127
+ target_unique_method_names .reverse ()
117
128
118
129
curve_drawer = CurveDrawer (
119
130
row_num = row_num ,
@@ -135,9 +146,13 @@ def draw_curves(
135
146
y_ticks = y_ticks ,
136
147
)
137
148
138
- for method_name , method_setting in unique_method_settings .items ():
149
+ for method_name in target_unique_method_names :
150
+ method_setting = unique_method_settings [method_name ]
151
+
139
152
if method_name not in dataset_results :
140
- raise KeyError (f"{ method_name } not in { sorted (dataset_results .keys ())} " )
153
+ print (f"{ method_name } will be skipped for { dataset_name } !" )
154
+ continue
155
+
141
156
method_results = dataset_results [method_name ]
142
157
if mode == "pr" :
143
158
y_data = method_results .get ("p" )
0 commit comments