Skip to content

Commit 1e9d165

Browse files
author
Ananya Sutradhar
committed
minor fixes in ilp pipeline
1 parent fb1ade5 commit 1e9d165

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

scripts/ml_ilp/ilp.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def direct_ratio_method(distances, matches, eps=1e-4):
4343
for q in range(Q):
4444
d = distances[q]
4545
m = matches[q]
46-
pos_idx = np.where(m == 1)[1]
46+
pos_idx = np.where(m == 1)[0]
4747
neg_idx = np.where(m == 0)[0]
4848
for i in pos_idx:
4949
for j in neg_idx:
@@ -105,7 +105,7 @@ def lp_soft_method(distances, matches, eps=1e-4, method ='lp'):
105105
pos = np.where(mvals == 1)[0]
106106
neg = np.where((mvals == 0) | (mvals == 0.5))[0]
107107
for i in pos:
108-
neg_sample = np.random.choice(neg, size=min(10, len(neg)), replace=False)
108+
neg_sample = np.random.choice(neg, size=min(100, len(neg)), replace=False)
109109
for j in neg_sample:
110110
if d[i] < d[j]:
111111
continue
@@ -116,7 +116,7 @@ def lp_soft_method(distances, matches, eps=1e-4, method ='lp'):
116116
print(f"Total equations: {len(slacks)}")
117117
prob += pulp.lpSum(slacks)
118118
print("Solving LP...")
119-
prob.solve(pulp.PULP_CBC_CMD(msg=False))
119+
prob.solve(pulp.PULP_CBC_CMD(msg=True))
120120
slack_vals = [v.value() for v in slacks]
121121
violations = sum(1 for v in slack_vals if v > 1e-6)
122122
return w_d.value(), w_m.value(), len(slacks), violations
@@ -126,7 +126,7 @@ def main():
126126
parser = argparse.ArgumentParser(description='Learn weights for vector ranking')
127127
parser.add_argument('unfiltered_ground_truth', help='Unfiltered Ground truth file (binary format)')
128128
parser.add_argument('filtered_ground_truth', help='Filtered Ground truth file (binary format)')
129-
parser.add_argument('filter_matches', help='Filter match file (binary match scores)')
129+
parser.add_argument('filter_match_scores', help='Filter match file (binary match scores)')
130130
parser.add_argument('--method', choices=['ratio', 'lp', 'pulp'], default='ratio')
131131
parser.add_argument('--eps', type=float, default=1e-4)
132132
parser.add_argument('--plot', action='store_true')
@@ -138,13 +138,13 @@ def main():
138138
print("Done reading ground truth file")
139139

140140
# Read the filter match file
141-
filter_matches = np.loadtxt(args.filter_matches, dtype=np.float32)
142-
print(f"Filter matches shape: {filter_matches.shape}")
141+
filter_match_scores = np.loadtxt(args.filter_match_scores, dtype=np.float32)
142+
print(f"Filter matches shape: {filter_match_scores.shape}")
143143
print("Done reading filter match file")
144144

145145
# Validate shapes
146-
if ground_truth_indices.shape != filter_matches.shape:
147-
print(f"Shape mismatch: {ground_truth_indices.shape} vs {filter_matches.shape}")
146+
if ground_truth_indices.shape != filter_match_scores.shape:
147+
print(f"Shape mismatch: {ground_truth_indices.shape} vs {filter_match_scores.shape}")
148148
sys.exit(1)
149149

150150

@@ -154,13 +154,13 @@ def main():
154154
# Read unfiltered ground truth (already read as ground_truth_distances)
155155
# Read filtered match scores (assume first 100 rows from filtered, rest from unfiltered)
156156
shape = filtered_indices.shape # or ground_truth_distances.shape
157-
filter_matches_all = np.ones(shape, dtype=np.int32)
157+
filter_matches_all = np.ones(shape, dtype=np.int32) # All filtered matches are considered valid for filtered gt
158158
num_filtered = filtered_indices.shape[0]
159159
print(f"Number of filtered queries: {num_filtered}")
160160

161161
# Concatenate: first 100 from filtered, rest from unfiltered
162162
distances = np.concatenate([filtered_distances, ground_truth_distances], axis=1)
163-
matches = np.concatenate([filter_matches_all, filter_matches_all], axis=1)
163+
matches = np.concatenate([filter_matches_all, filter_match_scores], axis=1)
164164

165165
print(f"Distances shape: {distances.shape}")
166166
print(f"Matches shape: {matches.shape}")
@@ -184,10 +184,10 @@ def main():
184184
print(f"Max-normalized distances: {distances[0][:5]}")
185185

186186
if args.method == 'ratio':
187-
w_d, w_m, total_pairs, _ = direct_ratio_method(distances, filter_matches, args.eps)
187+
w_d, w_m, total_pairs, _ = direct_ratio_method(distances, filter_match_scores, args.eps)
188188
violations = 0
189189
else:
190-
w_d, w_m, total_pairs, violations = lp_soft_method(distances, filter_matches, args.eps, args.method)
190+
w_d, w_m, total_pairs, violations = lp_soft_method(distances, filter_match_scores, args.eps, args.method)
191191

192192
print(f"Method: {args.method}")
193193
print(f"w_d = {w_d:.6f}, w_m = {w_m:.6f}")

scripts/ml_ilp/ilp_pipeline.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,9 @@ def main():
171171
print("STEP 2: Calculating unfiltered ground truth")
172172

173173
unfiltered_gt_path = os.path.join(args.output_dir, f"unfiltered_groundtruth_{args.base_size}_{args.query_size}_train.bin")
174-
unfiltered_match_scores_path = os.path.join(args.output_dir, f"unfiltered_match_scores_{args.base_size}_{args.query_size}_test.txt")
175-
176-
if (check_file_exists(unfiltered_gt_path, "Unfiltered ground truth") and
174+
unfiltered_match_scores_path = os.path.join(args.output_dir, f"unfiltered_match_scores_{args.base_size}_{args.query_size}_train.txt")
175+
176+
if (check_file_exists(unfiltered_gt_path, "Unfiltered ground truth") and
177177
check_file_exists(unfiltered_match_scores_path, "Unfiltered match scores")):
178178
print("✓ Unfiltered ground truth already computed, skipping...")
179179
else:

0 commit comments

Comments
 (0)