@@ -44,21 +44,21 @@ def split(self, dataset: Dataset, frac_train: float = 0.8, frac_valid: float = 0
4444
4545 if seed is not None :
4646 np .random .seed (seed )
47-
47+
4848 if frac_valid == 0 :
4949 stratifier = IterativeStratification (n_splits = 2 , order = 1 ,
50- sample_distribution_per_fold = [frac_test , frac_train ])
50+ sample_distribution_per_fold = [frac_test , frac_train ], random_state = seed )
5151 train_indexes , test_indexes = next (stratifier .split (dataset .smiles , dataset .y ))
5252
5353 return train_indexes , [], test_indexes
5454 else :
5555 stratifier = IterativeStratification (n_splits = 2 , order = 1 , sample_distribution_per_fold = [frac_test ,
56- 1 - frac_test ])
56+ 1 - frac_test ], random_state = seed )
5757 train_indexes , test_indexes = next (stratifier .split (dataset .smiles , dataset .y ))
5858
5959 new_frac_train = frac_train / (1 - frac_test )
6060 stratifier = IterativeStratification (n_splits = 2 , order = 1 ,
61- sample_distribution_per_fold = [1 - new_frac_train , new_frac_train ])
61+ sample_distribution_per_fold = [1 - new_frac_train , new_frac_train ], random_state = seed )
6262
6363 new_train_indexes , valid_indexes = next (stratifier .split (dataset .smiles [train_indexes ],
6464 dataset .y [train_indexes ]))
0 commit comments