Skip to content

Commit 6b40814

Browse files
committed
wip
1 parent 021238b commit 6b40814

File tree

1 file changed

+34
-161
lines changed

1 file changed

+34
-161
lines changed

iris-mpc-cpu/src/hawkers/naive_knn_plaintext.rs

Lines changed: 34 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,6 @@ pub fn naive_knn_min_fhd1(
6161
.collect::<Vec<_>>()
6262
.into_par_iter()
6363
.map(|i| {
64-
// for i // this part parallel
65-
// for j
66-
// all rotations of i vs. center of j
6764
let current_iris = &irises[i - 1];
6865
let mut neighbors = centers
6966
.iter()
@@ -95,180 +92,56 @@ pub fn naive_knn_min_fhd1(
9592
.collect::<Vec<_>>()
9693
})
9794
}
98-
fn fraction_difference_abs(a: (u16, u16), b: (u16, u16)) -> (u32, u32) {
99-
let (num1, den1) = (a.0 as u32, a.1 as u32);
100-
let (num2, den2) = (b.0 as u32, b.1 as u32);
101-
102-
// Compute a/b - c/d = (ad - cb) / bd
103-
let numerator_a = num1 * den2;
104-
let numerator_b = num2 * den1;
105-
let denominator = den1 * den2;
106-
107-
if numerator_a >= numerator_b {
108-
(numerator_a - numerator_b, denominator)
109-
} else {
110-
(numerator_b - numerator_a, denominator)
111-
}
112-
}
113-
114-
/// Returns true if (num1/den1) <= (num2/den2), where all arguments are u32.
115-
fn is_u32_fraction_leq_u32(num1: u32, den1: u32, num2: u32, den2: u32) -> bool {
116-
(num1 as u64) * (den2 as u64) <= (num2 as u64) * (den1 as u64)
117-
}
11895

11996
pub fn naive_knn_min_fhd2(
12097
irises: &[[IrisCode; 31]],
12198
centers: &[IrisCode],
122-
self_rots: &[[(u16, u16); 31]], // Pre-computed rotation profiles
12399
k: usize,
124100
start: usize,
125101
end: usize,
126102
pool: &ThreadPool,
127103
) -> Vec<KNNResult> {
128-
// B is the batch size for candidates. A larger B reduces the overhead of
129-
// select_nth_unstable but uses more temporary memory. Must be >= k.
130-
let batch_size = k;
131-
assert!(batch_size >= k, "Batch size must be at least k");
104+
let a_size = 2;
105+
let b_size = 64;
106+
107+
let n = centers.len();
108+
let len = end - start;
132109

133110
pool.install(|| {
134-
(start..end)
111+
(0..(len + a_size - 1) / a_size)
135112
.collect::<Vec<_>>()
136113
.into_par_iter()
137-
.map(|i| {
138-
let current_iris_rots = &irises[i - 1];
139-
let mut pruned = 0;
140-
// --- State for the optimization ---
141-
let mut best_neighbors: Vec<(usize, (u16, u16))> = vec![(0, (1, 1)); k];
142-
let mut candidates: Vec<(usize, (u16, u16))> = Vec::with_capacity(batch_size);
143-
// The threshold is the distance of the k-th best neighbor found so far
144-
let mut threshold = (1, 1);
145-
146-
for (j, other_iris) in centers.iter().enumerate() {
147-
// --- Stage 1: Fast Pruning ---
148-
149-
// a) Calculate the cheap, un-rotated distance (d0)
150-
// We assume the first rotation is the un-rotated "base" iris
151-
let d0 = centers[i - 1].get_distance_fraction(other_iris);
152-
153-
// b) Check if pruning is possible. If d0 is already worse than our threshold,
154-
// we can invest in calculating the lower bound to see if we can skip.
155-
if !fraction_less_than(&d0, &threshold) {
156-
// c) Calculate the provable lower bound
157-
let lower_bound = self_rots[j]
158-
.iter()
159-
.skip(12)
160-
.take(3)
161-
.map(|d_rot| fraction_difference_abs(d0, *d_rot)) // |d0 - d_rot(y,s)|
162-
.sorted_by(|&lhs, &rhs| {
163-
let (num1, den1) = lhs;
164-
let (num2, den2) = rhs;
165-
((num1 as u64) * (den2 as u64))
166-
.cmp(&((num2 as u64) * (den1 as u64)))
114+
.for_each(|a| {
115+
let left = start + a * a_size;
116+
let right = (start + (a + 1) * a_size).min(end);
117+
for b in 0..(n + b_size - 1) / b_size {
118+
for i in left..right {
119+
let current_iris = &irises[i];
120+
let dists = (b * b_size..((b + 1) * b_size).min(n))
121+
.map(|j| {
122+
(
123+
j,
124+
current_iris
125+
.iter()
126+
.map(|current_rot| {
127+
current_rot.get_distance_fraction(&centers[j])
128+
})
129+
.min()
130+
.unwrap(),
131+
)
167132
})
168-
.next()
169-
.unwrap();
170-
171-
// d) The PRUNING STEP: If the best this iris can possibly be is still
172-
// worse than our current k-th neighbor, skip it entirely.
173-
174-
dbg!(format!(
175-
"threshold: {:.6}, lower_bound: {:.6}",
176-
threshold.0 as f64 / threshold.1 as f64,
177-
lower_bound.0 as f64 / lower_bound.1 as f64
178-
));
179-
if is_u32_fraction_leq_u32(
180-
threshold.0.into(),
181-
threshold.1.into(),
182-
lower_bound.0,
183-
lower_bound.1,
184-
) {
185-
pruned += 1;
186-
continue;
187-
}
188-
}
189-
190-
// --- Stage 2: Exact Calculation ---
191-
// If not pruned, compute the full, expensive minimum distance
192-
let min_distance = current_iris_rots
193-
.iter()
194-
.map(|current_rot| current_rot.get_distance_fraction(other_iris))
195-
.min()
196-
.unwrap();
197-
198-
// Add to the batch of candidates if it's potentially better than our threshold
199-
if fraction_less_than(&min_distance, &threshold) {
200-
candidates.push((j, min_distance));
133+
.collect::<Vec<_>>();
201134
}
202-
203-
// --- Stage 3: Batch Processing ---
204-
if candidates.len() >= batch_size {
205-
// Combine the current best with the new candidates
206-
best_neighbors.append(&mut candidates); // Drains candidates
207-
208-
// Find the new top k from the combined list
209-
best_neighbors
210-
.select_nth_unstable_by(k - 1, |a, b| fraction_ordering(&a.1, &b.1));
211-
best_neighbors.truncate(k);
212-
213-
// Update the threshold to the distance of the new k-th neighbor (the worst of the best)
214-
threshold = best_neighbors.last().unwrap().1;
215-
//dbg!(threshold);
216-
}
217-
}
218-
219-
// --- Finalization ---
220-
// Process any remaining candidates in the last partial batch
221-
if !candidates.is_empty() {
222-
best_neighbors.append(&mut candidates);
223-
best_neighbors
224-
.select_nth_unstable_by(k - 1, |a, b| fraction_ordering(&a.1, &b.1));
225-
best_neighbors.truncate(k);
226135
}
227136

228-
// Sort the final list of k neighbors by distance
229-
best_neighbors.sort_unstable_by(|a, b| fraction_ordering(&a.1, &b.1));
230-
231-
// Extract just the indices for the final result
232-
let final_neighbor_indices =
233-
best_neighbors.into_iter().map(|(idx, _)| idx).collect();
234-
dbg!({ pruned });
235-
KNNResult {
236-
node: i,
237-
neighbors: final_neighbor_indices,
238-
}
239-
})
240-
.collect::<Vec<_>>()
241-
})
242-
}
243-
244-
pub fn naive_knn_min_fhd(
245-
irises: &[RotExtIrisCode],
246-
k: usize,
247-
start: usize,
248-
end: usize,
249-
pool: &ThreadPool,
250-
) -> Vec<KNNResult> {
251-
pool.install(|| {
252-
(start..end)
253-
.collect::<Vec<_>>()
254-
.into_par_iter()
255-
.map(|i| {
256-
let current_iris = &irises[i - 1];
257-
let mut neighbors = irises
258-
.iter()
259-
.enumerate()
260-
.flat_map(|(j, other_iris)| {
261-
(i != j + 1).then_some((j + 1, current_iris.min_fhe(other_iris)))
262-
})
263-
.collect::<Vec<_>>();
264-
neighbors
265-
.select_nth_unstable_by(k - 1, |lhs, rhs| fraction_ordering(&lhs.1, &rhs.1));
266-
let mut neighbors = neighbors.drain(0..k).collect::<Vec<_>>();
267-
neighbors.shrink_to_fit(); // just to make sure
268-
neighbors.sort_by(|lhs, rhs| fraction_ordering(&lhs.1, &rhs.1));
269-
let neighbors = neighbors.into_iter().map(|(i, _)| i).collect::<Vec<_>>();
270-
KNNResult { node: i, neighbors }
271-
})
272-
.collect::<Vec<_>>()
137+
// neighbors
138+
// .select_nth_unstable_by(k - 1, |lhs, rhs| fraction_ordering(&lhs.1, &rhs.1));
139+
// let mut neighbors = neighbors.drain(0..k).collect::<Vec<_>>();
140+
// neighbors.shrink_to_fit(); // just to make sure
141+
// neighbors.sort_by(|lhs, rhs| fraction_ordering(&lhs.1, &rhs.1));
142+
// let neighbors = neighbors.into_iter().map(|(i, _)| i).collect::<Vec<_>>();
143+
// KNNResult { node: i, neighbors }
144+
// })
145+
// .collect::<Vec<_>>()
273146
})
274147
}

0 commit comments

Comments
 (0)