
Commit df617fd

Merge pull request #201 from DoubleML/s-bias-bounds
Add Sensitivity Analysis
2 parents: 9bc6150 + ad636bd

37 files changed: +2408 −404 lines

doc/api/api.rst

Lines changed: 3 additions & 0 deletions
@@ -57,6 +57,9 @@ Dataset generators
    datasets.make_iivm_data
    datasets.make_plr_turrell2018
    datasets.make_pliv_multiway_cluster_CKMS2021
+   datasets.make_confounded_plr_data
+   datasets.make_confounded_irm_data
+

 Score mixin classes for double machine learning models
 ------------------------------------------------------
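
The two new API entries are the confounded data generators that back the sensitivity analysis. A hypothetical usage sketch follows; the import path matches the datasets module listed above, but the default arguments are assumed rather than confirmed, so check the function docstrings for the actual signatures:

    # Hypothetical sketch: generate confounded PLR/IRM data with assumed defaults.
    from doubleml.datasets import make_confounded_plr_data, make_confounded_irm_data

    plr_data = make_confounded_plr_data()  # partially linear model with an unobserved confounder
    irm_data = make_confounded_irm_data()  # interactive regression model with an unobserved confounder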

doubleml/_utils.py

Lines changed: 89 additions & 144 deletions
@@ -7,12 +7,13 @@
 from sklearn.preprocessing import LabelEncoder
 from sklearn.model_selection import KFold, GridSearchCV, RandomizedSearchCV
 from sklearn.metrics import mean_squared_error
-from sklearn.utils.multiclass import type_of_target

 from statsmodels.nonparametric.kde import KDEUnivariate

 from joblib import Parallel, delayed

+from ._utils_checks import _check_is_partition
+

 def _assure_2d_array(x):
     if x.ndim == 1:
@@ -40,63 +41,6 @@ def _get_cond_smpls_2d(smpls, bin_var1, bin_var2):
     return smpls_00, smpls_01, smpls_10, smpls_11


-def _check_is_partition(smpls, n_obs):
-    test_indices = np.concatenate([test_index for _, test_index in smpls])
-    if len(test_indices) != n_obs:
-        return False
-    hit = np.zeros(n_obs, dtype=bool)
-    hit[test_indices] = True
-    if not np.all(hit):
-        return False
-    return True
-
-
-def _check_all_smpls(all_smpls, n_obs, check_intersect=False):
-    all_smpls_checked = list()
-    for smpl in all_smpls:
-        all_smpls_checked.append(_check_smpl_split(smpl, n_obs, check_intersect))
-    return all_smpls_checked
-
-
-def _check_smpl_split(smpl, n_obs, check_intersect=False):
-    smpl_checked = list()
-    for tpl in smpl:
-        smpl_checked.append(_check_smpl_split_tpl(tpl, n_obs, check_intersect))
-    return smpl_checked
-
-
-def _check_smpl_split_tpl(tpl, n_obs, check_intersect=False):
-    train_index = np.sort(np.array(tpl[0]))
-    test_index = np.sort(np.array(tpl[1]))
-
-    if not issubclass(train_index.dtype.type, np.integer):
-        raise TypeError('Invalid sample split. Train indices must be of type integer.')
-    if not issubclass(test_index.dtype.type, np.integer):
-        raise TypeError('Invalid sample split. Test indices must be of type integer.')
-
-    if check_intersect:
-        if set(train_index) & set(test_index):
-            raise ValueError('Invalid sample split. Intersection of train and test indices is not empty.')
-
-    if len(np.unique(train_index)) != len(train_index):
-        raise ValueError('Invalid sample split. Train indices contain non-unique entries.')
-    if len(np.unique(test_index)) != len(test_index):
-        raise ValueError('Invalid sample split. Test indices contain non-unique entries.')
-
-    # we sort the indices above
-    # if not np.all(np.diff(train_index) > 0):
-    #     raise NotImplementedError('Invalid sample split. Only sorted train indices are supported.')
-    # if not np.all(np.diff(test_index) > 0):
-    #     raise NotImplementedError('Invalid sample split. Only sorted test indices are supported.')
-
-    if not set(train_index).issubset(range(n_obs)):
-        raise ValueError('Invalid sample split. Train indices must be in [0, n_obs).')
-    if not set(test_index).issubset(range(n_obs)):
-        raise ValueError('Invalid sample split. Test indices must be in [0, n_obs).')
-
-    return train_index, test_index
-
-
 def _fit(estimator, x, y, train_index, idx=None):
     estimator.fit(x[train_index, :], y[train_index])
     return estimator, idx
@@ -238,13 +182,6 @@ def _draw_weights(method, n_rep_boot, n_obs):
     return weights


-def _check_finite_predictions(preds, learner, learner_name, smpls):
-    test_indices = np.concatenate([test_index for _, test_index in smpls])
-    if not np.all(np.isfinite(preds[test_indices])):
-        raise ValueError(f'Predictions from learner {str(learner)} for {learner_name} are not finite.')
-    return
-
-
 def _trimm(preds, trimming_rule, trimming_threshold):
     if trimming_rule == 'truncate':
         preds[preds < trimming_threshold] = trimming_threshold
@@ -261,14 +198,6 @@ def _normalize_ipw(propensity, treatment):
     return normalized_weights


-def _check_is_propensity(preds, learner, learner_name, smpls, eps=1e-12):
-    test_indices = np.concatenate([test_index for _, test_index in smpls])
-    if any((preds[test_indices] < eps) | (preds[test_indices] > 1 - eps)):
-        warnings.warn(f'Propensity predictions from learner {str(learner)} for'
-                      f' {learner_name} are close to zero or one (eps={eps}).')
-    return
-
-
 def _rmse(y_true, y_pred):
     subset = np.logical_not(np.isnan(y_true))
     rmse = mean_squared_error(y_true[subset], y_pred[subset], squared=False)
@@ -285,77 +214,6 @@ def _predict_zero_one_propensity(learner, X):
     return res


-def _check_contains_iv(obj_dml_data):
-    if obj_dml_data.z_cols is not None:
-        raise ValueError('Incompatible data. ' +
-                         ' and '.join(obj_dml_data.z_cols) +
-                         ' have been set as instrumental variable(s). '
-                         'To fit an local model see the documentation.')
-
-
-def _check_zero_one_treatment(obj_dml):
-    one_treat = (obj_dml._dml_data.n_treat == 1)
-    binary_treat = (type_of_target(obj_dml._dml_data.d) == 'binary')
-    zero_one_treat = np.all((np.power(obj_dml._dml_data.d, 2) - obj_dml._dml_data.d) == 0)
-    if not (one_treat & binary_treat & zero_one_treat):
-        raise ValueError('Incompatible data. '
-                         f'To fit an {str(obj_dml.score)} model with DML '
-                         'exactly one binary variable with values 0 and 1 '
-                         'needs to be specified as treatment variable.')
-
-
-def _check_quantile(quantile):
-    if not isinstance(quantile, float):
-        raise TypeError('Quantile has to be a float. ' +
-                        f'Object of type {str(type(quantile))} passed.')
-
-    if (quantile <= 0) | (quantile >= 1):
-        raise ValueError('Quantile has be between 0 or 1. ' +
-                         f'Quantile {str(quantile)} passed.')
-    return
-
-
-def _check_treatment(treatment):
-    if not isinstance(treatment, int):
-        raise TypeError('Treatment indicator has to be an integer. ' +
-                        f'Object of type {str(type(treatment))} passed.')
-
-    if (treatment != 0) & (treatment != 1):
-        raise ValueError('Treatment indicator has be either 0 or 1. ' +
-                         f'Treatment indicator {str(treatment)} passed.')
-    return
-
-
-def _check_trimming(trimming_rule, trimming_threshold):
-    valid_trimming_rule = ['truncate']
-    if trimming_rule not in valid_trimming_rule:
-        raise ValueError('Invalid trimming_rule ' + str(trimming_rule) + '. ' +
-                         'Valid trimming_rule ' + ' or '.join(valid_trimming_rule) + '.')
-    if not isinstance(trimming_threshold, float):
-        raise TypeError('trimming_threshold has to be a float. ' +
-                        f'Object of type {str(type(trimming_threshold))} passed.')
-    if (trimming_threshold <= 0) | (trimming_threshold >= 0.5):
-        raise ValueError('Invalid trimming_threshold ' + str(trimming_threshold) + '. ' +
-                         'trimming_threshold has to be between 0 and 0.5.')
-    return
-
-
-def _check_score(score, valid_score, allow_callable=True):
-    if isinstance(score, str):
-        if score not in valid_score:
-            raise ValueError('Invalid score ' + score + '. ' +
-                             'Valid score ' + ' or '.join(valid_score) + '.')
-    else:
-        if allow_callable:
-            if not callable(score):
-                raise TypeError('score should be either a string or a callable. '
-                                '%r was passed.' % score)
-        else:
-            raise TypeError('score should be a string. '
-                            '%r was passed.' % score)
-    return
-
-
 def _get_bracket_guess(score, coef_start, coef_bounds):
     max_bracket_length = coef_bounds[1] - coef_bounds[0]
     b_guess = coef_bounds
@@ -388,3 +246,90 @@ def abs_ipw_score(theta):
                           method='brent')
     ipw_est = res.x
     return ipw_est
+
+
+def _aggregate_coefs_and_ses(all_coefs, all_ses, var_scaling_factor):
+    # aggregation is done over dimension 1, such that the coefs and ses have to be of shape (n_coefs, n_rep)
+    n_rep = all_coefs.shape[1]
+    coefs = np.median(all_coefs, 1)
+
+    xx = np.tile(coefs.reshape(-1, 1), n_rep)
+    ses = np.sqrt(np.divide(np.median(np.multiply(np.power(all_ses, 2), var_scaling_factor) +
+                                      np.power(all_coefs - xx, 2), 1), var_scaling_factor))
+
+    return coefs, ses
+
+
+def _var_est(psi, psi_deriv, apply_cross_fitting, smpls, is_cluster_data,
+             cluster_vars=None, smpls_cluster=None, n_folds_per_cluster=None):
+
+    if not is_cluster_data:
+        # psi and psi_deriv should be of shape (n_obs, ...)
+        if apply_cross_fitting:
+            var_scaling_factor = psi.shape[0]
+        else:
+            # In case of no-cross-fitting, the score function was only evaluated on the test data set
+            test_index = smpls[0][1]
+            psi_deriv = psi_deriv[test_index]
+            psi = psi[test_index]
+            var_scaling_factor = len(test_index)
+
+        J = np.mean(psi_deriv)
+        gamma_hat = np.mean(np.square(psi))
+
+    else:
+        assert cluster_vars is not None
+        assert smpls_cluster is not None
+        assert n_folds_per_cluster is not None
+        n_folds = len(smpls)
+
+        # one cluster
+        if cluster_vars.shape[1] == 1:
+            first_cluster_var = cluster_vars[:, 0]
+            clusters = np.unique(first_cluster_var)
+            gamma_hat = 0
+            j_hat = 0
+            for i_fold in range(n_folds):
+                test_inds = smpls[i_fold][1]
+                test_cluster_inds = smpls_cluster[i_fold][1]
+                I_k = test_cluster_inds[0]
+                const = 1 / len(I_k)
+                for cluster_value in I_k:
+                    ind_cluster = (first_cluster_var == cluster_value)
+                    gamma_hat += const * np.sum(np.outer(psi[ind_cluster], psi[ind_cluster]))
+                j_hat += np.sum(psi_deriv[test_inds]) / len(I_k)
+
+            var_scaling_factor = len(clusters)
+            J = np.divide(j_hat, n_folds_per_cluster)
+            gamma_hat = np.divide(gamma_hat, n_folds_per_cluster)
+
+        else:
+            assert cluster_vars.shape[1] == 2
+            first_cluster_var = cluster_vars[:, 0]
+            second_cluster_var = cluster_vars[:, 1]
+            gamma_hat = 0
+            j_hat = 0
+            for i_fold in range(n_folds):
+                test_inds = smpls[i_fold][1]
+                test_cluster_inds = smpls_cluster[i_fold][1]
+                I_k = test_cluster_inds[0]
+                J_l = test_cluster_inds[1]
+                const = np.divide(min(len(I_k), len(J_l)), (np.square(len(I_k) * len(J_l))))
+                for cluster_value in I_k:
+                    ind_cluster = (first_cluster_var == cluster_value) & np.in1d(second_cluster_var, J_l)
+                    gamma_hat += const * np.sum(np.outer(psi[ind_cluster], psi[ind_cluster]))
+                for cluster_value in J_l:
+                    ind_cluster = (second_cluster_var == cluster_value) & np.in1d(first_cluster_var, I_k)
+                    gamma_hat += const * np.sum(np.outer(psi[ind_cluster], psi[ind_cluster]))
+                j_hat += np.sum(psi_deriv[test_inds]) / (len(I_k) * len(J_l))
+
+            n_first_clusters = len(np.unique(first_cluster_var))
+            n_second_clusters = len(np.unique(second_cluster_var))
+            var_scaling_factor = min(n_first_clusters, n_second_clusters)
+            J = np.divide(j_hat, np.square(n_folds_per_cluster))
+            gamma_hat = np.divide(gamma_hat, np.square(n_folds_per_cluster))
+
+    scaling = np.divide(1.0, np.multiply(var_scaling_factor, np.square(J)))
+    sigma2_hat = np.multiply(scaling, gamma_hat)
+
+    return sigma2_hat, var_scaling_factor
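
As a side note (not part of the commit), the aggregation implemented by _aggregate_coefs_and_ses can be replayed in a few lines of NumPy. The numbers below are made up and only illustrate the shape convention (n_coefs, n_rep) and the median rule from the function above:

    import numpy as np

    # Made-up inputs: 3 coefficients estimated over 5 cross-fitting repetitions.
    rng = np.random.default_rng(42)
    all_coefs = 0.5 + 0.05 * rng.standard_normal((3, 5))  # shape (n_coefs, n_rep)
    all_ses = 0.1 + 0.01 * rng.random((3, 5))
    var_scaling_factor = 500.0  # e.g. the number of observations

    # Median of the coefficients over the repetitions.
    coefs = np.median(all_coefs, axis=1)
    # Each repetition's variance is recentered around the median coefficient
    # before the median is taken, then rescaled back.
    ses = np.sqrt(np.median(all_ses ** 2 * var_scaling_factor
                            + (all_coefs - coefs[:, None]) ** 2, axis=1)
                  / var_scaling_factor)
    print(coefs, ses)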
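Likewise, a minimal sketch of the non-clustered, cross-fitted branch of _var_est on synthetic scores (values again made up): the variance estimate is gamma_hat / (n * J^2), with J the mean score derivative and gamma_hat the mean squared score.

    import numpy as np

    # Synthetic score evaluations for n_obs observations.
    rng = np.random.default_rng(0)
    n_obs = 1000
    psi = rng.standard_normal(n_obs)  # score evaluations psi(W; theta_hat, eta_hat)
    psi_deriv = -np.ones(n_obs)       # derivative of the score w.r.t. theta

    J = np.mean(psi_deriv)                     # Jacobian estimate
    gamma_hat = np.mean(np.square(psi))        # variance of the score
    sigma2_hat = gamma_hat / (n_obs * J ** 2)  # scaling = 1 / (n * J^2)
    print(np.sqrt(sigma2_hat))                 # standard error, roughly 1/sqrt(n_obs) here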
