Skip to content

Commit 7e6cc73

Browse files
author
Ananya Sutradhar
committed
updated ground truth and ILP to accept relational filters
1 parent 48640d8 commit 7e6cc73

File tree

3 files changed

+219
-27
lines changed

3 files changed

+219
-27
lines changed

apps/utils/compute_filtered_groundtruth.cpp

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,26 @@ void print_query_stats(std::vector<std::pair<uint32_t, uint32_t>> &v)
460460
return;
461461
}
462462

463+
// Add this struct and helper at the top of the file
464+
struct RelationalFilter {
465+
std::string field;
466+
std::string op;
467+
std::string value;
468+
};
469+
470+
inline bool is_relational(const std::string& label) {
471+
return label.find('<') != std::string::npos || label.find('>') != std::string::npos;
472+
}
473+
474+
inline bool eval_relational(const std::string& base_val, const std::string& op, const std::string& query_val) {
475+
float b = std::stof(base_val), q = std::stof(query_val);
476+
if (op == "<") return b < q;
477+
if (op == "<=") return b <= q;
478+
if (op == ">") return b > q;
479+
if (op == ">=") return b >= q;
480+
return false;
481+
}
482+
463483
// template<typename A, typename B>
464484
// add UNIVERSAL LABEL SUPPORT
465485
int identify_matching_points(const std::string &base, const size_t start_id, const std::string &query,
@@ -493,19 +513,52 @@ int identify_matching_points(const std::string &base, const size_t start_id, con
493513
for (uint32_t k = 0; k < query_labels[i].size(); k++)
494514
{
495515
bool or_pass = false;
496-
for (uint32_t l = 0; l < query_labels[i][k].size(); l++)
497-
{
498-
if (base_labels[j].find(query_labels[i][k][l]) != base_labels[j].end())
516+
for (uint32_t l = 0; l < query_labels[i][k].size(); l++)
499517
{
500-
or_pass = true;
518+
const std::string& qlabel = query_labels[i][k][l];
519+
if (!is_relational(qlabel)) {
520+
// Old flow: treat as set
521+
if (base_labels[j].find(qlabel) != base_labels[j].end()) {
522+
or_pass = true;
523+
break;
524+
}
525+
} else {
526+
// New flow: relational filter
527+
// Parse field, op, value from qlabel, e.g. "year<2020"
528+
size_t pos = qlabel.find_first_of("<>");
529+
std::string field = qlabel.substr(0, pos);
530+
std::string op = qlabel.substr(pos, (qlabel[pos+1] == '=') ? 2 : 1);
531+
std::string value = qlabel.substr(pos + op.size());
532+
// // Find base value for this field
533+
auto it = std::find_if(base_labels[j].begin(), base_labels[j].end(),
534+
[&](const std::string& s) { return s.find(field + "=") == 0; });
535+
// if (it != base_labels[j].end()) {
536+
// std::string base_val = it->substr(field.size() + 1);
537+
// if (eval_relational(base_val, op, value)) {
538+
// or_pass = true;
539+
// break;
540+
// }
541+
// }
542+
if (it != base_labels[j].end()) {
543+
std::string base_val = it->substr(field.size() + 1);
544+
bool match = eval_relational(base_val, op, value);
545+
// #pragma omp critical
546+
// {
547+
// std::cout << "Query: " << qlabel << ", Base: " << *it << ", Parsed: " << base_val
548+
// << ", Match: " << match << std::endl;
549+
// }
550+
if (match) {
551+
or_pass = true;
552+
break;
553+
}
554+
}
555+
}
556+
}
557+
if (!or_pass) {
558+
pass = false;
501559
break;
502560
}
503561
}
504-
if (or_pass == false) {
505-
pass = false;
506-
break;
507-
}
508-
}
509562
}
510563
if (pass)
511564
{

apps/utils/compute_groundtruth.cpp

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -494,11 +494,14 @@ int aux_main(const std::string &base_file, const std::string &query_file, const
494494

495495
int *closest_points = new int[nqueries * k];
496496
float *dist_closest_points = new float[nqueries * k];
497-
std::vector<std::vector<float>> match_scores(nqueries, std::vector<float>(k, 0));
498497

499498
std::vector<std::vector<std::pair<uint32_t, float>>> results =
500499
processUnfilteredParts<T>(base_file, nqueries, npoints, dim, k, query_data, metric, location_to_tag);
501500

501+
std::vector<std::vector<float>> jaccard_scores(nqueries, std::vector<float>(k, 0));
502+
std::vector<std::vector<float>> relational_scores(nqueries, std::vector<float>(k, 0));
503+
504+
502505
for (size_t i = 0; i < nqueries; i++)
503506
{
504507
std::vector<std::pair<uint32_t, float>> &cur_res = results[i];
@@ -524,29 +527,61 @@ int aux_main(const std::string &base_file, const std::string &query_file, const
524527
dist_closest_points[i * k + j] = iter.second;
525528

526529
// Calculate match score for this vector
530+
// Jaccard score (normal filters)
531+
float jaccard_similarity = 0.0f;
532+
// Relational score (relational filters)
533+
float rel_score = 0.0f;
534+
527535
if (!base_labels.empty() && !query_labels.empty() && iter.first < base_labels.size())
528536
{
529537
const auto &query_label_predicates = query_labels[i];
530538
const auto &base_label_set = base_labels[iter.first];
531539

532-
533-
// calculate jaccard distance between query and base labels
540+
// Jaccard
534541
std::set<std::string> intersection;
542+
int normal_total = 0;
543+
for (const auto &clause : query_label_predicates)
544+
{
545+
for (const auto &label : clause)
546+
{
547+
size_t pos = label.find_first_of("<>");
548+
if (pos == std::string::npos) { // normal filter
549+
normal_total++;
550+
if (base_label_set.find(label) != base_label_set.end())
551+
intersection.insert(label);
552+
}
553+
}
554+
}
555+
jaccard_similarity = (normal_total > 0) ? (float)intersection.size() / normal_total : 0.0f;
556+
557+
// Relational
535558
for (const auto &clause : query_label_predicates)
536559
{
537560
for (const auto &label : clause)
538561
{
539-
if (base_label_set.find(label) != base_label_set.end())
540-
{
541-
intersection.insert(label);
562+
size_t pos = label.find_first_of("<>");
563+
if (pos != std::string::npos) { // relational filter
564+
std::string field = label.substr(0, pos);
565+
std::string op = label.substr(pos, (label[pos+1] == '=') ? 2 : 1);
566+
std::string value = label.substr(pos + op.size());
567+
for (const auto &base_label : base_label_set)
568+
{
569+
if (base_label.find(field + "=") == 0)
570+
{
571+
float query_val = std::stof(value);
572+
float base_val = std::stof(base_label.substr(field.size() + 1));
573+
rel_score = std::abs(query_val - base_val) / query_val;
574+
break;
575+
}
576+
}
542577
}
543578
}
544579
}
545-
546-
float jaccard_distance = (float)intersection.size() / (float)query_label_predicates.size();
547-
match_scores[i][j] = jaccard_distance;
548580
}
549581

582+
jaccard_scores[i][j] = jaccard_similarity;
583+
relational_scores[i][j] = rel_score;
584+
550585
++j;
551586
}
552587
if (j < k)
@@ -564,11 +599,23 @@ int aux_main(const std::string &base_file, const std::string &query_file, const
564599
std::cerr << "Failed to open match score file: " << match_score_file << std::endl;
565600
return -1;
566601
}
602+
// First part: Jaccard scores
603+
for (size_t i = 0; i < nqueries; i++)
604+
{
605+
for (size_t j = 0; j < k; j++)
606+
{
607+
match_score_writer << jaccard_scores[i][j];
608+
if (j < k - 1)
609+
match_score_writer << " ";
610+
}
611+
match_score_writer << "\n";
612+
}
613+
// Second part: Relational scores
567614
for (size_t i = 0; i < nqueries; i++)
568615
{
569616
for (size_t j = 0; j < k; j++)
570617
{
571-
match_score_writer << match_scores[i][j];
618+
match_score_writer << relational_scores[i][j];
572619
if (j < k - 1)
573620
match_score_writer << " ";
574621
}

scripts/ml_ilp/ilp.py

Lines changed: 100 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,15 @@ def read_ground_truth(file_path):
3636

3737
return indices, distances
3838

39+
def read_match_scores(match_score_file, Q, N):
40+
scores = np.loadtxt(match_score_file)
41+
if scores.shape[0] == 2 * Q:
42+
jaccard_scores = scores[:Q, :]
43+
relational_scores = scores[Q:, :]
44+
else:
45+
raise ValueError("Unexpected match score file shape")
46+
return jaccard_scores, relational_scores
47+
3948
def direct_ratio_method(distances, matches, eps=1e-4):
4049
Q, N = distances.shape
4150
max_diff = 0.0
@@ -158,15 +167,63 @@ def lp_soft_method_without_slack(distances, matches, eps=1e-4, method ='lp_wo_sl
158167
print("eps:", eps)
159168
prob.solve(pulp.PULP_CBC_CMD(msg=False))
160169
return w_m.value(), num_equations
161-
170+
171+
172+
def lp_soft_method_pulp_with_relational(distances, jaccard_scores, relational_scores, eps=1e-4):
173+
print(f"Distances shape: {distances.shape}")
174+
Q, N = distances.shape
175+
print("using PuLP with relational filter weight")
176+
prob = pulp.LpProblem('VectorRanking', pulp.LpMinimize)
177+
w_d = 1
178+
w_m = pulp.LpVariable('w_m', lowBound=0)
179+
w_r = pulp.LpVariable('w_r', lowBound=0)
180+
slacks = []
181+
182+
for q in tqdm(range(Q), desc="Building PuLP constraints"):
183+
d = distances[q]
184+
jac = jaccard_scores[q]
185+
rel = relational_scores[q]
186+
# Positive: jaccard==1 and (relational==0 if relational filter exists)
187+
has_relational = np.any(rel != 0)
188+
if has_relational:
189+
pos = np.where((jac == 1) & (rel == 0))[0]
190+
neg = np.where(~((jac == 1) & (rel == 0)))[0]
191+
else:
192+
pos = np.where(jac == 1)[0]
193+
neg = np.where(jac < 1)[0]
194+
neg_sample_size = min(1, len(neg))
195+
196+
for i in pos:
197+
neg_sample = np.random.choice(neg, size=neg_sample_size, replace=False)
198+
for j in neg_sample:
199+
if d[i] < d[j]:
200+
continue
201+
s = pulp.LpVariable(f's_{q}_{i}_{j}', lowBound=0)
202+
slacks.append(s)
203+
prob += (
204+
w_d * d[i] + w_m * (1 - jac[i]) + w_r * rel[i] + eps
205+
<= w_d * d[j] + w_m * (1 - jac[j]) + w_r * rel[j] + s
206+
)
207+
print(f"Total equations: {len(slacks)}")
208+
alpha = 500
209+
if len(slacks) > 0:
210+
avg_slack = pulp.lpSum(slacks) / len(slacks)
211+
else:
212+
avg_slack = 0
213+
prob += w_m + w_r + alpha * avg_slack
214+
print("Solving LP...")
215+
prob.solve(pulp.PULP_CBC_CMD(msg=False))
216+
slack_vals = [v.value() for v in slacks]
217+
violations = sum(1 for v in slack_vals if v > 1e-6)
218+
return w_d, w_m.value(), w_r.value(), len(slacks), violations
162219

163220

164221
def main():
165222
parser = argparse.ArgumentParser(description='Learn weights for vector ranking')
166223
parser.add_argument('unfiltered_ground_truth', help='Unfiltered Ground truth file (binary format)')
167224
parser.add_argument('filtered_ground_truth', help='Filtered Ground truth file (binary format)')
168225
parser.add_argument('unfiltered_match_scores', help='Filter match file (match scores)')
169-
parser.add_argument('--method', choices=['ratio', 'gekko', 'pulp', 'pulp_wo_slack'], default='ratio')
226+
parser.add_argument('--method', choices=['ratio', 'gekko', 'pulp', 'pulp_wo_slack', 'pulp_w_relational'], default='ratio')
170227
parser.add_argument('--eps', type=float, default=1e-4)
171228
parser.add_argument('--plot', action='store_true')
172229
args = parser.parse_args()
@@ -180,10 +237,10 @@ def main():
180237
print(f"Filter matches shape: {unfiltered_match_scores.shape}")
181238
print("Done reading filter match file")
182239

183-
# Validate shapes
184-
if unfiltered_gt_indices.shape != unfiltered_match_scores.shape:
185-
print(f"Shape mismatch: {unfiltered_gt_indices.shape} vs {unfiltered_match_scores.shape}")
186-
sys.exit(1)
240+
# # Validate shapes
241+
# if unfiltered_gt_indices.shape != unfiltered_match_scores.shape:
242+
# print(f"Shape mismatch: {unfiltered_gt_indices.shape} vs {unfiltered_match_scores.shape}")
243+
# sys.exit(1)
187244

188245

189246
# Concatenate filtered and unfiltered ground truth distances and match scores
@@ -196,10 +253,10 @@ def main():
196253
print(f"Number of filtered queries: {num_filtered}")
197254

198255
distances = np.concatenate([filtered_gt_distances, unfiltered_gt_distances], axis=1)
199-
matches = np.concatenate([filtered_match_score, unfiltered_match_scores], axis=1)
256+
# matches = np.concatenate([filtered_match_score, unfiltered_match_scores], axis=1)
200257

201258
print(f"Distances shape: {distances.shape}")
202-
print(f"Matches shape: {matches.shape}")
259+
# print(f"Matches shape: {matches.shape}")
203260

204261
print(f"Distances: {distances[0][:5]}")
205262
# distances_scaled = distances / distances.max()
@@ -215,6 +272,41 @@ def main():
215272
w_d, w_m, total_pairs, violations = lp_soft_method_gekko(distances, unfiltered_match_scores, args.eps)
216273
if args.method == 'pulp':
217274
w_d, w_m, total_pairs, violations = lp_soft_method_pulp(distances, unfiltered_match_scores, args.eps)
275+
if args.method == 'pulp_w_relational':
276+
unfiltered_jaccard_scores, unfiltered_relational_scores = read_match_scores(args.unfiltered_match_scores, *distances.shape)
277+
print(f"Unfiltered Jaccard scores shape: {unfiltered_jaccard_scores.shape}")
278+
print(f"Unfiltered Relational scores shape: {unfiltered_relational_scores.shape}")
279+
max_rel = np.max(unfiltered_relational_scores, axis=1, keepdims=True)
280+
max_rel[max_rel == 0] = 1.0
281+
unfiltered_relational_scores = unfiltered_relational_scores / max_rel
282+
283+
print(f"Relational scores: {unfiltered_relational_scores[0][:5]}")
284+
285+
filtered_jaccard_scores = np.ones_like(filtered_gt_distances, dtype=np.float32)
286+
filtered_relational_scores = np.zeros_like(filtered_gt_distances, dtype=np.float32)
287+
288+
jaccard_scores = np.concatenate([filtered_jaccard_scores, unfiltered_jaccard_scores], axis=1)
289+
relational_scores = np.concatenate([filtered_relational_scores, unfiltered_relational_scores], axis=1)
290+
relational_scores = relational_scores.astype(np.float32)
291+
292+
print(f"Jaccard scores shape: {jaccard_scores.shape}")
293+
print(f"Relational scores shape: {relational_scores.shape}")
294+
295+
print(f"Relational scores: {relational_scores[0]}")
296+
297+
# # take the first 100 queries
298+
# jaccard_scores = jaccard_scores[:1000]
299+
# relational_scores = relational_scores[:1000]
300+
# distances = distances[:1000]
301+
print(f"Filtered Jaccard scores shape: {jaccard_scores.shape}")
302+
print(f"Filtered Relational scores shape: {relational_scores.shape}")
303+
304+
305+
306+
w_d, w_m, w_r, total_pairs, violations = lp_soft_method_pulp_with_relational(
307+
distances, jaccard_scores, relational_scores, args.eps
308+
)
309+
print(f"Relational weight: {w_r:.6f}")
218310

219311
print(f"Method: {args.method}")
220312
print(f"Total pairs evaluated: {total_pairs}")

0 commit comments

Comments
 (0)