from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import KFold, GridSearchCV, RandomizedSearchCV
from sklearn.metrics import mean_squared_error
-from sklearn.utils.multiclass import type_of_target

from statsmodels.nonparametric.kde import KDEUnivariate

from joblib import Parallel, delayed

+from ._utils_checks import _check_is_partition
+

def _assure_2d_array(x):
    if x.ndim == 1:
@@ -40,63 +41,6 @@ def _get_cond_smpls_2d(smpls, bin_var1, bin_var2):
    return smpls_00, smpls_01, smpls_10, smpls_11


-def _check_is_partition(smpls, n_obs):
-    test_indices = np.concatenate([test_index for _, test_index in smpls])
-    if len(test_indices) != n_obs:
-        return False
-    hit = np.zeros(n_obs, dtype=bool)
-    hit[test_indices] = True
-    if not np.all(hit):
-        return False
-    return True
-
-
-def _check_all_smpls(all_smpls, n_obs, check_intersect=False):
-    all_smpls_checked = list()
-    for smpl in all_smpls:
-        all_smpls_checked.append(_check_smpl_split(smpl, n_obs, check_intersect))
-    return all_smpls_checked
-
-
-def _check_smpl_split(smpl, n_obs, check_intersect=False):
-    smpl_checked = list()
-    for tpl in smpl:
-        smpl_checked.append(_check_smpl_split_tpl(tpl, n_obs, check_intersect))
-    return smpl_checked
-
-
-def _check_smpl_split_tpl(tpl, n_obs, check_intersect=False):
-    train_index = np.sort(np.array(tpl[0]))
-    test_index = np.sort(np.array(tpl[1]))
-
-    if not issubclass(train_index.dtype.type, np.integer):
-        raise TypeError('Invalid sample split. Train indices must be of type integer.')
-    if not issubclass(test_index.dtype.type, np.integer):
-        raise TypeError('Invalid sample split. Test indices must be of type integer.')
-
-    if check_intersect:
-        if set(train_index) & set(test_index):
-            raise ValueError('Invalid sample split. Intersection of train and test indices is not empty.')
-
-    if len(np.unique(train_index)) != len(train_index):
-        raise ValueError('Invalid sample split. Train indices contain non-unique entries.')
-    if len(np.unique(test_index)) != len(test_index):
-        raise ValueError('Invalid sample split. Test indices contain non-unique entries.')
-
-    # we sort the indices above
-    # if not np.all(np.diff(train_index) > 0):
-    #     raise NotImplementedError('Invalid sample split. Only sorted train indices are supported.')
-    # if not np.all(np.diff(test_index) > 0):
-    #     raise NotImplementedError('Invalid sample split. Only sorted test indices are supported.')
-
-    if not set(train_index).issubset(range(n_obs)):
-        raise ValueError('Invalid sample split. Train indices must be in [0, n_obs).')
-    if not set(test_index).issubset(range(n_obs)):
-        raise ValueError('Invalid sample split. Test indices must be in [0, n_obs).')
-
-    return train_index, test_index
-
-
def _fit(estimator, x, y, train_index, idx=None):
    estimator.fit(x[train_index, :], y[train_index])
    return estimator, idx
@@ -238,13 +182,6 @@ def _draw_weights(method, n_rep_boot, n_obs):
    return weights


-def _check_finite_predictions(preds, learner, learner_name, smpls):
-    test_indices = np.concatenate([test_index for _, test_index in smpls])
-    if not np.all(np.isfinite(preds[test_indices])):
-        raise ValueError(f'Predictions from learner {str(learner)} for {learner_name} are not finite.')
-    return
-
-
def _trimm(preds, trimming_rule, trimming_threshold):
    if trimming_rule == 'truncate':
        preds[preds < trimming_threshold] = trimming_threshold
@@ -261,14 +198,6 @@ def _normalize_ipw(propensity, treatment):
    return normalized_weights


-def _check_is_propensity(preds, learner, learner_name, smpls, eps=1e-12):
-    test_indices = np.concatenate([test_index for _, test_index in smpls])
-    if any((preds[test_indices] < eps) | (preds[test_indices] > 1 - eps)):
-        warnings.warn(f'Propensity predictions from learner {str(learner)} for'
-                      f' {learner_name} are close to zero or one (eps={eps}).')
-    return
-
-
def _rmse(y_true, y_pred):
    subset = np.logical_not(np.isnan(y_true))
    rmse = mean_squared_error(y_true[subset], y_pred[subset], squared=False)
@@ -285,77 +214,6 @@ def _predict_zero_one_propensity(learner, X):
    return res


-def _check_contains_iv(obj_dml_data):
-    if obj_dml_data.z_cols is not None:
-        raise ValueError('Incompatible data. ' +
-                         ' and '.join(obj_dml_data.z_cols) +
-                         ' have been set as instrumental variable(s). '
-                         'To fit a local model see the documentation.')
-
-
-def _check_zero_one_treatment(obj_dml):
-    one_treat = (obj_dml._dml_data.n_treat == 1)
-    binary_treat = (type_of_target(obj_dml._dml_data.d) == 'binary')
-    zero_one_treat = np.all((np.power(obj_dml._dml_data.d, 2) - obj_dml._dml_data.d) == 0)
-    if not (one_treat & binary_treat & zero_one_treat):
-        raise ValueError('Incompatible data. '
-                         f'To fit an {str(obj_dml.score)} model with DML '
-                         'exactly one binary variable with values 0 and 1 '
-                         'needs to be specified as treatment variable.')
-
-
-def _check_quantile(quantile):
-    if not isinstance(quantile, float):
-        raise TypeError('Quantile has to be a float. ' +
-                        f'Object of type {str(type(quantile))} passed.')
-
-    if (quantile <= 0) | (quantile >= 1):
-        raise ValueError('Quantile has to be between 0 and 1. ' +
-                         f'Quantile {str(quantile)} passed.')
-    return
-
-
-def _check_treatment(treatment):
-    if not isinstance(treatment, int):
-        raise TypeError('Treatment indicator has to be an integer. ' +
-                        f'Object of type {str(type(treatment))} passed.')
-
-    if (treatment != 0) & (treatment != 1):
-        raise ValueError('Treatment indicator has to be either 0 or 1. ' +
-                         f'Treatment indicator {str(treatment)} passed.')
-    return
-
-
-def _check_trimming(trimming_rule, trimming_threshold):
-    valid_trimming_rule = ['truncate']
-    if trimming_rule not in valid_trimming_rule:
-        raise ValueError('Invalid trimming_rule ' + str(trimming_rule) + '. ' +
-                         'Valid trimming_rule ' + ' or '.join(valid_trimming_rule) + '.')
-    if not isinstance(trimming_threshold, float):
-        raise TypeError('trimming_threshold has to be a float. ' +
-                        f'Object of type {str(type(trimming_threshold))} passed.')
-    if (trimming_threshold <= 0) | (trimming_threshold >= 0.5):
-        raise ValueError('Invalid trimming_threshold ' + str(trimming_threshold) + '. ' +
-                         'trimming_threshold has to be between 0 and 0.5.')
-    return
-
-
-def _check_score(score, valid_score, allow_callable=True):
-    if isinstance(score, str):
-        if score not in valid_score:
-            raise ValueError('Invalid score ' + score + '. ' +
-                             'Valid score ' + ' or '.join(valid_score) + '.')
-    else:
-        if allow_callable:
-            if not callable(score):
-                raise TypeError('score should be either a string or a callable. '
-                                '%r was passed.' % score)
-        else:
-            raise TypeError('score should be a string. '
-                            '%r was passed.' % score)
-    return
-
-
def _get_bracket_guess(score, coef_start, coef_bounds):
    max_bracket_length = coef_bounds[1] - coef_bounds[0]
    b_guess = coef_bounds
@@ -388,3 +246,90 @@ def abs_ipw_score(theta):
                          method='brent')
    ipw_est = res.x
    return ipw_est
+
+
+def _aggregate_coefs_and_ses(all_coefs, all_ses, var_scaling_factor):
+    # aggregation is done over dimension 1, such that the coefs and ses have to be of shape (n_coefs, n_rep)
+    n_rep = all_coefs.shape[1]
+    coefs = np.median(all_coefs, 1)
+
+    xx = np.tile(coefs.reshape(-1, 1), n_rep)
+    ses = np.sqrt(np.divide(np.median(np.multiply(np.power(all_ses, 2), var_scaling_factor) +
+                                      np.power(all_coefs - xx, 2), 1), var_scaling_factor))
+
+    return coefs, ses
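+# Note on the aggregation above: across repeated sample splits, the point
+# estimate is the per-coefficient median, and each repetition's variance
+# (ses**2 rescaled by var_scaling_factor) is inflated by the squared deviation
+# of that repetition's coefficient from the median coefficient before the
+# median variance is taken and rescaled back. This is the usual adjustment for
+# the extra variability introduced by repeated cross-fitting.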
+
+
+def _var_est(psi, psi_deriv, apply_cross_fitting, smpls, is_cluster_data,
+             cluster_vars=None, smpls_cluster=None, n_folds_per_cluster=None):
+
+    if not is_cluster_data:
+        # psi and psi_deriv should be of shape (n_obs, ...)
+        if apply_cross_fitting:
+            var_scaling_factor = psi.shape[0]
+        else:
+            # without cross-fitting, the score function was only evaluated on a single test set
+            test_index = smpls[0][1]
+            psi_deriv = psi_deriv[test_index]
+            psi = psi[test_index]
+            var_scaling_factor = len(test_index)
+
+        J = np.mean(psi_deriv)
+        gamma_hat = np.mean(np.square(psi))
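+        # J estimates E[psi_deriv] (the Jacobian of the score) and gamma_hat
+        # estimates E[psi^2]; the sandwich variance J^{-2} * gamma_hat, scaled
+        # by var_scaling_factor, is assembled at the end of the function.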
+
+    else:
+        assert cluster_vars is not None
+        assert smpls_cluster is not None
+        assert n_folds_per_cluster is not None
+        n_folds = len(smpls)
+
+        # one cluster variable
+        if cluster_vars.shape[1] == 1:
+            first_cluster_var = cluster_vars[:, 0]
+            clusters = np.unique(first_cluster_var)
+            gamma_hat = 0
+            j_hat = 0
+            for i_fold in range(n_folds):
+                test_inds = smpls[i_fold][1]
+                test_cluster_inds = smpls_cluster[i_fold][1]
+                I_k = test_cluster_inds[0]
+                const = 1 / len(I_k)
+                for cluster_value in I_k:
+                    ind_cluster = (first_cluster_var == cluster_value)
+                    gamma_hat += const * np.sum(np.outer(psi[ind_cluster], psi[ind_cluster]))
+                j_hat += np.sum(psi_deriv[test_inds]) / len(I_k)
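+            # Within each fold, gamma_hat accumulates sum_{i,j} psi_i * psi_j
+            # over all pairs of observations sharing a cluster (the np.outer
+            # block sum), the cluster-robust analogue of E[psi^2]; j_hat
+            # averages psi_deriv per cluster in the fold.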
+
+            var_scaling_factor = len(clusters)
+            J = np.divide(j_hat, n_folds_per_cluster)
+            gamma_hat = np.divide(gamma_hat, n_folds_per_cluster)
+
+        else:
+            assert cluster_vars.shape[1] == 2
+            first_cluster_var = cluster_vars[:, 0]
+            second_cluster_var = cluster_vars[:, 1]
+            gamma_hat = 0
+            j_hat = 0
+            for i_fold in range(n_folds):
+                test_inds = smpls[i_fold][1]
+                test_cluster_inds = smpls_cluster[i_fold][1]
+                I_k = test_cluster_inds[0]
+                J_l = test_cluster_inds[1]
+                const = np.divide(min(len(I_k), len(J_l)), np.square(len(I_k) * len(J_l)))
+                for cluster_value in I_k:
+                    ind_cluster = (first_cluster_var == cluster_value) & np.in1d(second_cluster_var, J_l)
+                    gamma_hat += const * np.sum(np.outer(psi[ind_cluster], psi[ind_cluster]))
+                for cluster_value in J_l:
+                    ind_cluster = (second_cluster_var == cluster_value) & np.in1d(first_cluster_var, I_k)
+                    gamma_hat += const * np.sum(np.outer(psi[ind_cluster], psi[ind_cluster]))
+                j_hat += np.sum(psi_deriv[test_inds]) / (len(I_k) * len(J_l))
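+            # With two-way clustering, the score covariance aggregates pairs
+            # sharing the first cluster variable and pairs sharing the second;
+            # below, the effective sample size is the smaller number of
+            # clusters, which drives the convergence rate.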
+
+            n_first_clusters = len(np.unique(first_cluster_var))
+            n_second_clusters = len(np.unique(second_cluster_var))
+            var_scaling_factor = min(n_first_clusters, n_second_clusters)
+            J = np.divide(j_hat, np.square(n_folds_per_cluster))
+            gamma_hat = np.divide(gamma_hat, np.square(n_folds_per_cluster))
+
+    scaling = np.divide(1.0, np.multiply(var_scaling_factor, np.square(J)))
+    sigma2_hat = np.multiply(scaling, gamma_hat)
+
+    return sigma2_hat, var_scaling_factor
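+# Illustrative usage (a sketch, not part of this diff; assumes psi and
+# psi_deriv hold the evaluated score and its derivative for one treatment):
+#     sigma2_hat, n = _var_est(psi, psi_deriv, apply_cross_fitting=True,
+#                              smpls=smpls, is_cluster_data=False)
+#     se = np.sqrt(sigma2_hat)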