Skip to content

Commit df05a11

Browse files
committed
Add bar chart for feature importance per position to CPPPlot().feature_map
1 parent 0ac72bd commit df05a11

File tree

14 files changed

+234
-101
lines changed

14 files changed

+234
-101
lines changed

aaanalysis/_utils/check_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def check_array_like(name=None, val=None, dtype=None, ensure_2d=False, allow_nan
6262
check_type.check_str(name="dtype", val=dtype, accept_none=True)
6363
valid_dtypes = ["numeric", "int", "float", "bool", None]
6464
if dtype not in valid_dtypes:
65-
str_error = add_str(str_error=f"'dtype' should be one of the following: {valid_dtypes}", str_add=str_add)
65+
str_error = add_str(str_error=f"'dtype' should be one of: {valid_dtypes}", str_add=str_add)
6666
raise ValueError(str_error)
6767
dict_expected_dtype = {"numeric": "numeric", "int": "int64", "float": "float64", "bool": "bool"}
6868
expected_dtype = dict_expected_dtype[dtype] if dtype is not None else None

aaanalysis/_utils/check_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def check_str_options(name=None, val=None, accept_none=False, list_str_options=N
7777
"""Check if valid string option"""
7878
if accept_none and val is None:
7979
return None # Skip test
80-
str_add = add_str(str_error=f"'{name}' ({val}) should be one of following: {list_str_options}")
80+
str_add = add_str(str_error=f"'{name}' ({val}) should be one of: {list_str_options}")
8181
check_str(name=name, val=val, accept_none=accept_none, str_add=str_add)
8282
if val not in list_str_options:
8383
raise ValueError(str_add)

aaanalysis/_utils/plotting.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def _check_hatches(marker=None, hatch=None, list_cat=None):
9696
wrong_hatch = [x for x in hatch if x not in valid_hatches]
9797
if len(wrong_hatch) != 0:
9898
raise ValueError(
99-
f"'hatch' contains wrong values ('{wrong_hatch}')! Should be one of following: {valid_hatches}")
99+
f"'hatch' contains wrong values ('{wrong_hatch}')! Should be one of: {valid_hatches}")
100100
if len(hatch) != len(list_cat):
101101
raise ValueError(f"Length must match of 'hatch' ({hatch}) and categories ({list_cat}).") # Check if hatch can be chosen
102102
# Warn for parameter conflicts
@@ -117,7 +117,7 @@ def _check_marker(marker=None, list_cat=None, lw=0):
117117
if isinstance(marker, list):
118118
wrong_markers = [x for x in marker if x not in valid_markers]
119119
if len(wrong_markers) != 0:
120-
raise ValueError(f"'marker' contains wrong values ('{wrong_markers}'). Should be one of following: {valid_markers}")
120+
raise ValueError(f"'marker' contains wrong values ('{wrong_markers}'). Should be one of: {valid_markers}")
121121
if len(marker) != len(list_cat):
122122
raise ValueError(f"Length must match of 'marker' ({marker}) and categories ({list_cat}).")
123123
# Warn for parameter conflicts
@@ -155,7 +155,7 @@ def _check_linestyle(linestyle=None, list_cat=None, marker=None):
155155
wrong_mls = [x for x in linestyle if x not in valid_mls]
156156
if len(wrong_mls) != 0:
157157
raise ValueError(
158-
f"'marker_linestyle' contains wrong values ('{wrong_mls}')! Should be one of following: {valid_mls}")
158+
f"'marker_linestyle' contains wrong values ('{wrong_mls}')! Should be one of: {valid_mls}")
159159
if len(linestyle) != len(list_cat):
160160
raise ValueError(f"Length must match of 'marker_linestyle' ({linestyle}) and categories ({list_cat}).")
161161
# Check if marker_linestyle is conflicting with other settings

aaanalysis/_utils/utils_plot_elements.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
"""
44
import seaborn as sns
55
import matplotlib.patches as mpatches
6-
import matplotlib.pyplot as plt
76
import matplotlib.ticker as mticker
87
import matplotlib as mpl
8+
import matplotlib.pyplot as plt
9+
import numpy as np
910

1011

1112
# Helper functions
@@ -177,15 +178,53 @@ def adjust_spine_to_middle(ax=None):
177178
return ax
178179

179180

180-
def x_ticks_0(ax):
181+
def ticks_0(ax, show_zero=True, axis="x", precision=2):
181182
"""Apply custom formatting for x-axis ticks."""
182-
def custom_x_ticks(x, pos):
183+
if axis not in ["x", "y"]:
184+
raise ValueError("'axis' should be 'x' or 'y'")
185+
def custom_ticks(x, pos):
183186
"""Format x-axis ticks."""
184-
if x % 1 == 0: # Check if number is an integer
187+
if x == 0 and not show_zero:
188+
return ''
189+
elif x % 1 == 0: # Check if number is an integer
190+
return f'{int(x)}' # Format as integer
191+
else:
192+
# Format as float with two decimal places
193+
return f'{x:.{precision}f}'
194+
if axis == "x":
195+
ax.xaxis.set_major_formatter(mticker.FuncFormatter(custom_ticks))
196+
else:
197+
ax.yaxis.set_major_formatter(mticker.FuncFormatter(custom_ticks))
198+
199+
200+
def ticks_0(ax, show_zero=True, show_only_max=False, axis="x", precision=2):
201+
"""Apply custom formatting for axis ticks and ensure max value is shown."""
202+
if axis not in ["x", "y"]:
203+
raise ValueError("'axis' should be 'x' or 'y'")
204+
# Format tick labels
205+
def custom_ticks(x, pos):
206+
"""Format axis ticks."""
207+
if x == 0 and not show_zero:
208+
return ''
209+
elif x % 1 == 0: # Check if number is an integer
185210
return f'{int(x)}' # Format as integer
186211
else:
187-
return f'{x:.2f}' # Format as float with two decimal places
188-
ax.xaxis.set_major_formatter(mticker.FuncFormatter(custom_x_ticks))
212+
# Format as float with specified precision
213+
return f'{x:.{precision}f}'
214+
215+
# Get the current axis object
216+
axis_obj = ax.xaxis if axis == "x" else ax.yaxis
217+
axis_obj.set_major_formatter(mticker.FuncFormatter(custom_ticks))
218+
219+
# Get the current limits
220+
if show_only_max:
221+
vmax = ax.get_xlim()[1] if axis == "x" else ax.get_ylim()[1]
222+
max_val = int(np.ceil(vmax))
223+
new_ticks = [max_val]
224+
if axis == "x":
225+
ax.set_xticks(new_ticks)
226+
else:
227+
ax.set_yticks(new_ticks)
189228

190229

191230
def adjust_tuple_elements(tuple_in=None, tuple_default=None):

aaanalysis/data_handling/_load_features.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
def check_name(name=None):
1313
"""Check provided names of dataset"""
1414
if name not in LIST_DATASETS_WITH_FEATURES:
15-
raise ValueError(f"'name' should be one of the following: {LIST_DATASETS_WITH_FEATURES}")
15+
raise ValueError(f"'name' should be one of: {LIST_DATASETS_WITH_FEATURES}")
1616

1717

1818
# II Main Functions

aaanalysis/feature_engineering/_aaclust_plot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def check_method(method=None):
6868
"""Validate the method parameter against a list of valid hierarchical clustering methods"""
6969
valid_methods = ["single", "complete", "average", "weighted", "centroid", "median", "ward"]
7070
if method not in valid_methods:
71-
raise ValueError(f"'method' ({method}) should be one of following: {valid_methods}")
71+
raise ValueError(f"'method' ({method}) should be one of: {valid_methods}")
7272

7373

7474
def check_match_df_corr_clust_x(df_corr=None, cluster_x=None):

aaanalysis/feature_engineering/_backend/aaclust/aaclust_plot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def plot_eval(df_eval=None, dict_xlims=None, figsize=None, colors=None):
7171
ax.set_title("Quality measures", weight="bold")
7272
ax.tick_params(axis='y', which='both', left=False)
7373
if i != 0:
74-
ut.x_ticks_0(ax=ax)
74+
ut.ticks_0(ax=ax)
7575
# Set xlims
7676
if dict_xlims is not None:
7777
for i in dict_xlims:

0 commit comments

Comments
 (0)