Skip to content

Commit 1f2f063

Browse files
committed
fix hovertemplate in sensitivity plot
1 parent 64bfaa4 commit 1f2f063

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

doubleml/_utils_plots.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,10 @@ def _sensitivity_contour_plot(x,
6363
if benchmarks is not None:
6464
fig.add_trace(go.Scatter(x=benchmarks['cf_d'],
6565
y=benchmarks['cf_y'],
66+
customdata=benchmarks['value'].reshape(-1, 1),
6667
mode="markers+text",
6768
marker=dict(size=10, color='red', line=dict(width=2, color=text_col)),
68-
hovertemplate=hov_temp + f': {benchmarks["value"]}' + '</b>',
69+
hovertemplate=hov_temp + ': %{customdata[0]:.3f}' + '</b>',
6970
name="Benchmarks",
7071
textfont=dict(color=text_col, size=14),
7172
text=list(map(lambda s: "<b>" + s + "</b>", benchmarks['name'])),

doubleml/double_ml.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1849,13 +1849,13 @@ def sensitivity_plot(self, idx_treatment=0, value='theta', include_scenario=True
18491849
benchmark_dict = copy.deepcopy(benchmarks)
18501850
if benchmarks is not None:
18511851
n_benchmarks = len(benchmarks['name'])
1852-
benchmark_values = n_benchmarks * [np.nan]
1852+
benchmark_values = np.full(shape=(n_benchmarks,), fill_value=np.nan)
18531853
for benchmark_idx in range(len(benchmarks['name'])):
18541854
sens_dict_bench = self._calc_sensitivity_analysis(cf_y=benchmarks['cf_y'][benchmark_idx],
18551855
cf_d=benchmarks['cf_y'][benchmark_idx],
18561856
rho=self.sensitivity_params['input']['rho'],
18571857
level=self.sensitivity_params['input']['level'])
1858-
benchmark_values[benchmark_idx] = round(sens_dict_bench[value][bound][idx_treatment], 3)
1858+
benchmark_values[benchmark_idx] = sens_dict_bench[value][bound][idx_treatment]
18591859
benchmark_dict['value'] = benchmark_values
18601860
fig = _sensitivity_contour_plot(x=cf_d_vec,
18611861
y=cf_y_vec,
@@ -1868,7 +1868,7 @@ def sensitivity_plot(self, idx_treatment=0, value='theta', include_scenario=True
18681868
benchmarks=benchmark_dict,
18691869
fill=fill)
18701870
return fig
1871-
1871+
18721872
def sensitivity_benchmark(self, benchmarking_set):
18731873
"""
18741874
Computes a benchmark for a given set of features.

0 commit comments

Comments
 (0)