@@ -52,7 +52,7 @@ def check_dataset(
5252
5353
5454def check_points (dataset , split_ratios , split_names , i : int ):
55- sorted_points = sorted ([dataset .weights [name ] for name in dataset .names ], key = lambda x : x [1 ], reverse = True )
55+ sorted_points = sorted ([( name , dataset .weights [name ]) for name in dataset .names ], key = lambda x : x [1 ], reverse = True )
5656 total_weight = sum (x [1 ] for x in sorted_points [i :])
5757 if [x [1 ] / total_weight for x in sorted_points [i :len (split_ratios )]] <= sorted (split_ratios , reverse = True ):
5858 return None
@@ -64,7 +64,7 @@ def check_points(dataset, split_ratios, split_names, i: int):
6464
6565
6666def check_clusters (dataset , split_ratios , split_names , strategy : Literal ["break" , "assign" ], linkage : Literal ["average" , "single" , "complete" ], i : int ):
67- sorted_clusters = sorted ([dataset .cluster_weights [name ] for name in dataset .cluster_names ], key = lambda x : x [1 ], reverse = True )
67+ sorted_clusters = sorted ([( name , dataset .cluster_weights [name ]) for name in dataset .cluster_names ], key = lambda x : x [1 ], reverse = True )
6868 total_weight = sum (x [1 ] for x in sorted_clusters [i :])
6969 if [x [1 ] / total_weight for x in sorted_clusters [i :len (split_ratios )]] <= sorted (split_ratios , reverse = True ):
7070 return None
@@ -94,9 +94,6 @@ def assign_cluster(dataset: DataSet, cluster_name: Any, split_ratios, split_name
9494 if dataset .cluster_map [n ] == cluster_name :
9595 name_split_map [n ] = split_name
9696 dataset .cluster_names = dataset .cluster_names [:cluster_index ] + dataset .cluster_names [cluster_index + 1 :]
97- del dataset .cluster_weights [cluster_name ]
98- if dataset .cluster_stratification is not None :
99- del dataset .cluster_stratification [cluster_name ]
10097 if dataset .cluster_similarity is not None :
10198 dataset .cluster_similarity = np .delete (dataset .cluster_similarity , cluster_index , axis = 0 )
10299 dataset .cluster_similarity = np .delete (dataset .cluster_similarity , cluster_index , axis = 1 )
@@ -108,9 +105,6 @@ def assign_cluster(dataset: DataSet, cluster_name: Any, split_ratios, split_name
108105 cluster_split_map = {}
109106 name_index = dataset .names .index (cluster_name )
110107 dataset .names = dataset .names [:name_index ] + dataset .names [name_index + 1 :]
111- del dataset .weights [cluster_name ]
112- if dataset .stratification is not None :
113- del dataset .stratification [cluster_name ]
114108 if dataset .similarity is not None :
115109 dataset .similarity = np .delete (dataset .similarity , name_index , axis = 0 )
116110 dataset .similarity = np .delete (dataset .similarity , name_index , axis = 1 )
@@ -179,8 +173,8 @@ def break_cluster(dataset: DataSet, cluster_name: Any, split_ratio: float, linka
179173
180174 if dataset .stratification is not None and len (dataset .classes ) > 1 :
181175 cluster_stratification = defaultdict (lambda : np .zeros (len (dataset .classes )))
182- for key , value in dataset .cluster_map .items ():
183- cluster_stratification [value ] += dataset .stratification [key ]
176+ for name in dataset . names : # key, value in dataset.cluster_map.items():
177+ cluster_stratification [dataset . cluster_map [ name ]] += dataset .stratification [name ]
184178 else :
185179 cluster_stratification = None
186180
0 commit comments