Skip to content

Commit becdd19

Browse files
Minor mistakes from the previous fix corrected
1 parent f1d5fe9 commit becdd19

File tree

1 file changed

+4
-10
lines changed

1 file changed

+4
-10
lines changed

datasail/solver/overflow.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def check_dataset(
5252

5353

5454
def check_points(dataset, split_ratios, split_names, i: int):
55-
sorted_points = sorted([dataset.weights[name] for name in dataset.names], key=lambda x: x[1], reverse=True)
55+
sorted_points = sorted([(name, dataset.weights[name]) for name in dataset.names], key=lambda x: x[1], reverse=True)
5656
total_weight = sum(x[1] for x in sorted_points[i:])
5757
if [x[1] / total_weight for x in sorted_points[i:len(split_ratios)]] <= sorted(split_ratios, reverse=True):
5858
return None
@@ -64,7 +64,7 @@ def check_points(dataset, split_ratios, split_names, i: int):
6464

6565

6666
def check_clusters(dataset, split_ratios, split_names, strategy: Literal["break", "assign"], linkage: Literal["average", "single", "complete"], i: int):
67-
sorted_clusters = sorted([dataset.cluster_weights[name] for name in dataset.cluster_names], key=lambda x: x[1], reverse=True)
67+
sorted_clusters = sorted([(name, dataset.cluster_weights[name]) for name in dataset.cluster_names], key=lambda x: x[1], reverse=True)
6868
total_weight = sum(x[1] for x in sorted_clusters[i:])
6969
if [x[1] / total_weight for x in sorted_clusters[i:len(split_ratios)]] <= sorted(split_ratios, reverse=True):
7070
return None
@@ -94,9 +94,6 @@ def assign_cluster(dataset: DataSet, cluster_name: Any, split_ratios, split_name
9494
if dataset.cluster_map[n] == cluster_name:
9595
name_split_map[n] = split_name
9696
dataset.cluster_names = dataset.cluster_names[:cluster_index] + dataset.cluster_names[cluster_index + 1:]
97-
del dataset.cluster_weights[cluster_name]
98-
if dataset.cluster_stratification is not None:
99-
del dataset.cluster_stratification[cluster_name]
10097
if dataset.cluster_similarity is not None:
10198
dataset.cluster_similarity = np.delete(dataset.cluster_similarity, cluster_index, axis=0)
10299
dataset.cluster_similarity = np.delete(dataset.cluster_similarity, cluster_index, axis=1)
@@ -108,9 +105,6 @@ def assign_cluster(dataset: DataSet, cluster_name: Any, split_ratios, split_name
108105
cluster_split_map = {}
109106
name_index = dataset.names.index(cluster_name)
110107
dataset.names = dataset.names[:name_index] + dataset.names[name_index + 1:]
111-
del dataset.weights[cluster_name]
112-
if dataset.stratification is not None:
113-
del dataset.stratification[cluster_name]
114108
if dataset.similarity is not None:
115109
dataset.similarity = np.delete(dataset.similarity, name_index, axis=0)
116110
dataset.similarity = np.delete(dataset.similarity, name_index, axis=1)
@@ -179,8 +173,8 @@ def break_cluster(dataset: DataSet, cluster_name: Any, split_ratio: float, linka
179173

180174
if dataset.stratification is not None and len(dataset.classes) > 1:
181175
cluster_stratification = defaultdict(lambda: np.zeros(len(dataset.classes)))
182-
for key, value in dataset.cluster_map.items():
183-
cluster_stratification[value] += dataset.stratification[key]
176+
for name in dataset.names: # key, value in dataset.cluster_map.items():
177+
cluster_stratification[dataset.cluster_map[name]] += dataset.stratification[name]
184178
else:
185179
cluster_stratification = None
186180

0 commit comments

Comments (0)