@@ -36,6 +36,15 @@ def read_ground_truth(file_path):
36
36
37
37
return indices , distances
38
38
39
def read_match_scores(match_score_file, Q, N):
    """Load stacked Jaccard and relational match scores from a text file.

    The file must contain exactly ``2 * Q`` rows: the first ``Q`` rows are
    returned as the Jaccard scores and the remaining ``Q`` rows as the
    relational scores.

    Args:
        match_score_file: Path or file-like object accepted by ``np.loadtxt``.
        Q: Number of queries (half the number of rows in the file).
        N: Currently unused; kept so existing callers keep working.

    Returns:
        Tuple ``(jaccard_scores, relational_scores)`` of 2-D arrays with
        ``Q`` rows each.

    Raises:
        ValueError: If the file does not contain exactly ``2 * Q`` rows.
    """
    # ndmin=2 stops loadtxt from squeezing a single-column file down to 1-D,
    # which would make the 2-D slicing below fail with an IndexError.
    scores = np.loadtxt(match_score_file, ndmin=2)
    if scores.shape[0] != 2 * Q:
        raise ValueError(
            f"Unexpected match score file shape: got {scores.shape}, "
            f"expected {2 * Q} rows"
        )
    jaccard_scores = scores[:Q, :]
    relational_scores = scores[Q:, :]
    return jaccard_scores, relational_scores
47
+
39
48
def direct_ratio_method (distances , matches , eps = 1e-4 ):
40
49
Q , N = distances .shape
41
50
max_diff = 0.0
@@ -158,15 +167,63 @@ def lp_soft_method_without_slack(distances, matches, eps=1e-4, method ='lp_wo_sl
158
167
print ("eps:" , eps )
159
168
prob .solve (pulp .PULP_CBC_CMD (msg = False ))
160
169
return w_m .value (), num_equations
161
-
170
+
171
+
172
def lp_soft_method_pulp_with_relational(distances, jaccard_scores, relational_scores, eps=1e-4,
                                        alpha=500, neg_samples_per_pos=1):
    """Learn match and relational weights with a soft-margin LP solved by PuLP/CBC.

    For every query row, items with ``jaccard == 1`` (and ``relational == 0``
    when the row has any non-zero relational score) are treated as positives
    and the remaining items as negatives. Each positive is paired with
    randomly sampled negatives; for pairs where the positive is not already
    strictly closer, a slack-relaxed constraint

        d[i] + w_m*(1 - jac[i]) + w_r*rel[i] + eps
            <= d[j] + w_m*(1 - jac[j]) + w_r*rel[j] + s

    is added. The objective minimizes ``w_m + w_r + alpha * mean(slack)``.

    Args:
        distances: 2-D array; one row per query (assumed to align
            column-wise with the two score arrays -- confirm with caller).
        jaccard_scores: 2-D array indexed like ``distances``.
        relational_scores: 2-D array indexed like ``distances``.
        eps: Margin required between a positive and a negative.
        alpha: Penalty multiplier applied to the average slack.
        neg_samples_per_pos: Maximum number of negatives sampled (without
            replacement) for each positive.

    Returns:
        Tuple ``(w_d, w_m, w_r, num_constraints, violations)`` where ``w_d``
        is the fixed distance weight 1 and ``violations`` counts slack
        variables whose solved value exceeds 1e-6.
    """
    print(f"Distances shape: {distances.shape}")
    Q, _ = distances.shape
    print("using PuLP with relational filter weight")
    prob = pulp.LpProblem('VectorRanking', pulp.LpMinimize)
    w_d = 1  # distance weight is held constant; only w_m and w_r are learned
    w_m = pulp.LpVariable('w_m', lowBound=0)
    w_r = pulp.LpVariable('w_r', lowBound=0)
    slacks = []

    for q in tqdm(range(Q), desc="Building PuLP constraints"):
        d = distances[q]
        jac = jaccard_scores[q]
        rel = relational_scores[q]
        # Positive: jaccard==1 and (relational==0 if relational filter exists)
        has_relational = np.any(rel != 0)
        if has_relational:
            is_pos = (jac == 1) & (rel == 0)
            pos = np.where(is_pos)[0]
            neg = np.where(~is_pos)[0]
        else:
            pos = np.where(jac == 1)[0]
            neg = np.where(jac < 1)[0]

        # Nothing to pair up for this query.
        if len(pos) == 0 or len(neg) == 0:
            continue
        neg_sample_size = min(neg_samples_per_pos, len(neg))

        for i in pos:
            neg_sample = np.random.choice(neg, size=neg_sample_size, replace=False)
            for j in neg_sample:
                # Positive is already strictly closer: no constraint is added.
                if d[i] < d[j]:
                    continue
                s = pulp.LpVariable(f's_{q}_{i}_{j}', lowBound=0)
                slacks.append(s)
                prob += (
                    w_d * d[i] + w_m * (1 - jac[i]) + w_r * rel[i] + eps
                    <= w_d * d[j] + w_m * (1 - jac[j]) + w_r * rel[j] + s
                )

    print(f"Total equations: {len(slacks)}")
    avg_slack = pulp.lpSum(slacks) / len(slacks) if slacks else 0
    prob += w_m + w_r + alpha * avg_slack
    print("Solving LP...")
    prob.solve(pulp.PULP_CBC_CMD(msg=False))
    if prob.status != pulp.LpStatusOptimal:
        # Surface solver trouble instead of failing later on missing values.
        print(f"Warning: LP solver status is {pulp.LpStatus[prob.status]}")

    # A variable without a solved value (None) is not counted as a violation.
    violations = sum(1 for s in slacks if (s.value() or 0.0) > 1e-6)
    return w_d, w_m.value(), w_r.value(), len(slacks), violations
162
219
163
220
164
221
def main ():
165
222
parser = argparse .ArgumentParser (description = 'Learn weights for vector ranking' )
166
223
parser .add_argument ('unfiltered_ground_truth' , help = 'Unfiltered Ground truth file (binary format)' )
167
224
parser .add_argument ('filtered_ground_truth' , help = 'Filtered Ground truth file (binary format)' )
168
225
parser .add_argument ('unfiltered_match_scores' , help = 'Filter match file (match scores)' )
169
- parser .add_argument ('--method' , choices = ['ratio' , 'gekko' , 'pulp' , 'pulp_wo_slack' ], default = 'ratio' )
226
+ parser .add_argument ('--method' , choices = ['ratio' , 'gekko' , 'pulp' , 'pulp_wo_slack' , 'pulp_w_relational' ], default = 'ratio' )
170
227
parser .add_argument ('--eps' , type = float , default = 1e-4 )
171
228
parser .add_argument ('--plot' , action = 'store_true' )
172
229
args = parser .parse_args ()
@@ -180,10 +237,10 @@ def main():
180
237
print (f"Filter matches shape: { unfiltered_match_scores .shape } " )
181
238
print ("Done reading filter match file" )
182
239
183
- # Validate shapes
184
- if unfiltered_gt_indices .shape != unfiltered_match_scores .shape :
185
- print (f"Shape mismatch: { unfiltered_gt_indices .shape } vs { unfiltered_match_scores .shape } " )
186
- sys .exit (1 )
240
+ # # Validate shapes
241
+ # if unfiltered_gt_indices.shape != unfiltered_match_scores.shape:
242
+ # print(f"Shape mismatch: {unfiltered_gt_indices.shape} vs {unfiltered_match_scores.shape}")
243
+ # sys.exit(1)
187
244
188
245
189
246
# Concatenate filtered and unfiltered ground truth distances and match scores
@@ -196,10 +253,10 @@ def main():
196
253
print (f"Number of filtered queries: { num_filtered } " )
197
254
198
255
distances = np .concatenate ([filtered_gt_distances , unfiltered_gt_distances ], axis = 1 )
199
- matches = np .concatenate ([filtered_match_score , unfiltered_match_scores ], axis = 1 )
256
+ # matches = np.concatenate([filtered_match_score, unfiltered_match_scores], axis=1)
200
257
201
258
print (f"Distances shape: { distances .shape } " )
202
- print (f"Matches shape: { matches .shape } " )
259
+ # print(f"Matches shape: {matches.shape}")
203
260
204
261
print (f"Distances: { distances [0 ][:5 ]} " )
205
262
# distances_scaled = distances / distances.max()
@@ -215,6 +272,41 @@ def main():
215
272
w_d , w_m , total_pairs , violations = lp_soft_method_gekko (distances , unfiltered_match_scores , args .eps )
216
273
if args .method == 'pulp' :
217
274
w_d , w_m , total_pairs , violations = lp_soft_method_pulp (distances , unfiltered_match_scores , args .eps )
275
+ if args .method == 'pulp_w_relational' :
276
+ unfiltered_jaccard_scores , unfiltered_relational_scores = read_match_scores (args .unfiltered_match_scores , * distances .shape )
277
+ print (f"Unfiltered Jaccard scores shape: { unfiltered_jaccard_scores .shape } " )
278
+ print (f"Unfiltered Relational scores shape: { unfiltered_relational_scores .shape } " )
279
+ max_rel = np .max (unfiltered_relational_scores , axis = 1 , keepdims = True )
280
+ max_rel [max_rel == 0 ] = 1.0
281
+ unfiltered_relational_scores = unfiltered_relational_scores / max_rel
282
+
283
+ print (f"Relational scores: { unfiltered_relational_scores [0 ][:5 ]} " )
284
+
285
+ filtered_jaccard_scores = np .ones_like (filtered_gt_distances , dtype = np .float32 )
286
+ filtered_relational_scores = np .zeros_like (filtered_gt_distances , dtype = np .float32 )
287
+
288
+ jaccard_scores = np .concatenate ([filtered_jaccard_scores , unfiltered_jaccard_scores ], axis = 1 )
289
+ relational_scores = np .concatenate ([filtered_relational_scores , unfiltered_relational_scores ], axis = 1 )
290
+ relational_scores = relational_scores .astype (np .float32 )
291
+
292
+ print (f"Jaccard scores shape: { jaccard_scores .shape } " )
293
+ print (f"Relational scores shape: { relational_scores .shape } " )
294
+
295
+ print (f"Relational scores: { relational_scores [0 ]} " )
296
+
297
+ # # take the first 100 queries
298
+ # jaccard_scores = jaccard_scores[:1000]
299
+ # relational_scores = relational_scores[:1000]
300
+ # distances = distances[:1000]
301
+ print (f"Filtered Jaccard scores shape: { jaccard_scores .shape } " )
302
+ print (f"Filtered Relational scores shape: { relational_scores .shape } " )
303
+
304
+
305
+
306
+ w_d , w_m , w_r , total_pairs , violations = lp_soft_method_pulp_with_relational (
307
+ distances , jaccard_scores , relational_scores , args .eps
308
+ )
309
+ print (f"Relational weight: { w_r :.6f} " )
218
310
219
311
print (f"Method: { args .method } " )
220
312
print (f"Total pairs evaluated: { total_pairs } " )
0 commit comments