@@ -43,7 +43,7 @@ def direct_ratio_method(distances, matches, eps=1e-4):
43
43
for q in range (Q ):
44
44
d = distances [q ]
45
45
m = matches [q ]
46
- pos_idx = np .where (m == 1 )[1 ]
46
+ pos_idx = np .where (m == 1 )[0 ]
47
47
neg_idx = np .where (m == 0 )[0 ]
48
48
for i in pos_idx :
49
49
for j in neg_idx :
@@ -105,7 +105,7 @@ def lp_soft_method(distances, matches, eps=1e-4, method ='lp'):
105
105
pos = np .where (mvals == 1 )[0 ]
106
106
neg = np .where ((mvals == 0 ) | (mvals == 0.5 ))[0 ]
107
107
for i in pos :
108
- neg_sample = np .random .choice (neg , size = min (10 , len (neg )), replace = False )
108
+ neg_sample = np .random .choice (neg , size = min (100 , len (neg )), replace = False )
109
109
for j in neg_sample :
110
110
if d [i ] < d [j ]:
111
111
continue
@@ -116,7 +116,7 @@ def lp_soft_method(distances, matches, eps=1e-4, method ='lp'):
116
116
print (f"Total equations: { len (slacks )} " )
117
117
prob += pulp .lpSum (slacks )
118
118
print ("Solving LP..." )
119
- prob .solve (pulp .PULP_CBC_CMD (msg = False ))
119
+ prob .solve (pulp .PULP_CBC_CMD (msg = True ))
120
120
slack_vals = [v .value () for v in slacks ]
121
121
violations = sum (1 for v in slack_vals if v > 1e-6 )
122
122
return w_d .value (), w_m .value (), len (slacks ), violations
@@ -126,7 +126,7 @@ def main():
126
126
parser = argparse .ArgumentParser (description = 'Learn weights for vector ranking' )
127
127
parser .add_argument ('unfiltered_ground_truth' , help = 'Unfiltered Ground truth file (binary format)' )
128
128
parser .add_argument ('filtered_ground_truth' , help = 'Filtered Ground truth file (binary format)' )
129
- parser .add_argument ('filter_matches ' , help = 'Filter match file (binary match scores)' )
129
+ parser .add_argument ('filter_match_scores ' , help = 'Filter match file (binary match scores)' )
130
130
parser .add_argument ('--method' , choices = ['ratio' , 'lp' , 'pulp' ], default = 'ratio' )
131
131
parser .add_argument ('--eps' , type = float , default = 1e-4 )
132
132
parser .add_argument ('--plot' , action = 'store_true' )
@@ -138,13 +138,13 @@ def main():
138
138
print ("Done reading ground truth file" )
139
139
140
140
# Read the filter match file
141
- filter_matches = np .loadtxt (args .filter_matches , dtype = np .float32 )
142
- print (f"Filter matches shape: { filter_matches .shape } " )
141
+ filter_match_scores = np .loadtxt (args .filter_match_scores , dtype = np .float32 )
142
+ print (f"Filter matches shape: { filter_match_scores .shape } " )
143
143
print ("Done reading filter match file" )
144
144
145
145
# Validate shapes
146
- if ground_truth_indices .shape != filter_matches .shape :
147
- print (f"Shape mismatch: { ground_truth_indices .shape } vs { filter_matches .shape } " )
146
+ if ground_truth_indices .shape != filter_match_scores .shape :
147
+ print (f"Shape mismatch: { ground_truth_indices .shape } vs { filter_match_scores .shape } " )
148
148
sys .exit (1 )
149
149
150
150
@@ -154,13 +154,13 @@ def main():
154
154
# Read unfiltered ground truth (already read as ground_truth_distances)
155
155
# Read filtered match scores (assume first 100 rows from filtered, rest from unfiltered)
156
156
shape = filtered_indices .shape # or ground_truth_distances.shape
157
- filter_matches_all = np .ones (shape , dtype = np .int32 )
157
+ filter_matches_all = np .ones (shape , dtype = np .int32 ) # All filtered matches are considered valid for filtered gt
158
158
num_filtered = filtered_indices .shape [0 ]
159
159
print (f"Number of filtered queries: { num_filtered } " )
160
160
161
161
# Concatenate: first 100 from filtered, rest from unfiltered
162
162
distances = np .concatenate ([filtered_distances , ground_truth_distances ], axis = 1 )
163
- matches = np .concatenate ([filter_matches_all , filter_matches_all ], axis = 1 )
163
+ matches = np .concatenate ([filter_matches_all , filter_match_scores ], axis = 1 )
164
164
165
165
print (f"Distances shape: { distances .shape } " )
166
166
print (f"Matches shape: { matches .shape } " )
@@ -184,10 +184,10 @@ def main():
184
184
print (f"Max-normalized distances: { distances [0 ][:5 ]} " )
185
185
186
186
if args .method == 'ratio' :
187
- w_d , w_m , total_pairs , _ = direct_ratio_method (distances , filter_matches , args .eps )
187
+ w_d , w_m , total_pairs , _ = direct_ratio_method (distances , filter_match_scores , args .eps )
188
188
violations = 0
189
189
else :
190
- w_d , w_m , total_pairs , violations = lp_soft_method (distances , filter_matches , args .eps , args .method )
190
+ w_d , w_m , total_pairs , violations = lp_soft_method (distances , filter_match_scores , args .eps , args .method )
191
191
192
192
print (f"Method: { args .method } " )
193
193
print (f"w_d = { w_d :.6f} , w_m = { w_m :.6f} " )
0 commit comments