Skip to content

Commit 8c39643

Browse files
committed
Update CPPPlot().feature_map
1 parent df05a11 commit 8c39643

File tree

3 files changed

+36
-7
lines changed

3 files changed

+36
-7
lines changed

aaanalysis/feature_engineering/_backend/cpp/cpp_plot_feature_map.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,6 @@ def add_feat_importance_legend(ax=None,
160160

161161

162162
# II Main Functions
163-
# TODO in frontend: test (add_imp_bar_top, imp_bar_label_type)
164163
# TODO stacked bar charts for SHAP (later)
165164
def plot_feature_map(df_feat=None, df_cat=None,
166165
col_cat="subcategory", col_val="mean_dif", col_imp="feat_importance",
@@ -313,7 +312,7 @@ def plot_feature_map(df_feat=None, df_cat=None,
313312
fontsize_annotations=fs_annotations)
314313
if add_imp_bar_top:
315314
ax_empty.axis("off")
316-
plt.sca(ax_hm)
317315
plt.subplots_adjust(wspace=0.0, hspace=0.0)
318-
ax = [ax_hm, ax_br, ax_bt] if add_imp_bar_top else [ax_hm, ax_br]
316+
plt.sca(ax_hm)
317+
ax = ax_hm
319318
return fig, ax

tests/unit/cpp_plot_tests/test_cpp_plot_feature_map.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -318,8 +318,15 @@ def test_valid_imp_marker_sizes(self):
318318
imp_marker_sizes=valid_imp_marker_sizes)
319319
assert isinstance(fig, plt.Figure) and isinstance(ax, plt.Axes)
320320
plt.close()
321-
322-
321+
322+
def test_valid_add_imp_bar_top(self):
323+
cpp_plot = aa.CPPPlot()
324+
df_feat = get_df_feat()
325+
for add_imp_bar_top in [True, False]:
326+
fig, ax = cpp_plot.feature_map(df_feat=df_feat, add_imp_bar_top=add_imp_bar_top)
327+
assert isinstance(fig, plt.Figure) and isinstance(ax, plt.Axes)
328+
plt.close()
329+
323330
def test_valid_imp_bar_th(self):
324331
cpp_plot = aa.CPPPlot()
325332
df_feat = get_df_feat()
@@ -328,6 +335,14 @@ def test_valid_imp_bar_th(self):
328335
assert isinstance(fig, plt.Figure) and isinstance(ax, plt.Axes)
329336
plt.close()
330337

338+
def test_valid_imp_bar_label_type(self):
339+
cpp_plot = aa.CPPPlot()
340+
df_feat = get_df_feat()
341+
for imp_bar_label_type in ["short", "long", None]:
342+
fig, ax = cpp_plot.feature_map(df_feat=df_feat, imp_bar_label_type=imp_bar_label_type)
343+
assert isinstance(fig, plt.Figure) and isinstance(ax, plt.Axes)
344+
plt.close()
345+
331346
@settings(max_examples=3, deadline=5000)
332347
@given(xtick_size=st.floats(min_value=8.0, max_value=14.0), xtick_width=st.floats(min_value=0.5, max_value=2.0),
333348
xtick_length=st.floats(min_value=3.0, max_value=10.0))
@@ -656,6 +671,14 @@ def test_invalid_imp_marker_sizes(self):
656671
cpp_plot.feature_map(df_feat=df_feat, imp_marker_sizes=invalid_imp_marker_sizes)
657672
plt.close()
658673

674+
def test_invalid_add_imp_bar_top(self):
675+
cpp_plot = aa.CPPPlot()
676+
df_feat = get_df_feat()
677+
for add_imp_bar_top in [None, "adsf", 123, pd.DataFrame, {}]:
678+
with pytest.raises(ValueError):
679+
cpp_plot.feature_map(df_feat=df_feat, add_imp_bar_top=add_imp_bar_top)
680+
plt.close()
681+
659682
def test_invalid_imp_bar_th(self):
660683
cpp_plot = aa.CPPPlot()
661684
df_feat = get_df_feat()
@@ -664,6 +687,14 @@ def test_invalid_imp_bar_th(self):
664687
cpp_plot.feature_map(df_feat=df_feat, imp_bar_th=invalid_imp_bar_th)
665688
plt.close()
666689

690+
def test_invalid_imp_bar_label_type(self):
691+
cpp_plot = aa.CPPPlot()
692+
df_feat = get_df_feat()
693+
for imp_bar_label_type in ["adsf", 123, pd.DataFrame, {}]:
694+
with pytest.raises(ValueError):
695+
cpp_plot.feature_map(df_feat=df_feat, imp_bar_label_type=imp_bar_label_type)
696+
plt.close()
697+
667698
@settings(max_examples=3, deadline=5000)
668699
@given(xtick_size=st.just(-1), xtick_width=st.just(-1), xtick_length=st.just(-1))
669700
def test_invalid_tick_styling(self, xtick_size, xtick_width, xtick_length):
@@ -734,4 +765,4 @@ def test_complex_negative_positive(self):
734765
xtick_width=-1, # Invalid xtick_width
735766
xtick_length=-1, # Invalid xtick_length
736767
**args_seq)
737-
plt.close()
768+
plt.close()

tests/unit/cpp_plot_tests/test_cpp_plot_heatmap.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def adjust_vmin_vmax(vmin=None, vmax=None):
4343
return vmin, vmax
4444

4545

46-
4746
def get_args_seq(n=0):
4847
aa.options["verbose"] = False
4948
df_seq = aa.load_dataset(name="DOM_GSEC", n=N_SEQ)

0 commit comments

Comments
 (0)