-
Notifications
You must be signed in to change notification settings - Fork 20
[POP-2951] Slow-but-perfect KNN in plaintext #1676
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+295
−1
Merged
Changes from 7 commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
b03a716
draft naive-knn
mcalancea b32449d
Release memory correctly add checkpoints and safeguards
mcalancea 9ea42f6
clippy
mcalancea 575b15f
Merge branch 'main' into mihai/ideal-neighborhoods-gen
mcalancea 39625c8
fixes
mcalancea a0073da
change floats to fractions
mcalancea efaf4a9
dev: deploy main
naure 3bcafc9
dev: Increase timeout to 30min to support batch_size=32
naure 7878a72
Merge branch 'main' into mihai/ideal-neighborhoods-gen
mcalancea 01d5824
Merge branch 'main' into dev
mcalancea d12a2b3
Merge branch 'dev' into mihai/ideal-neighborhoods-gen
mcalancea 55adf81
Merge branch 'main' into dev
mcalancea 9844546
[POP-2929] add graceful shutdown to the networking stack (#1685)
sdwoodbury 1cf5cbd
u32 instead of usize and stream.take
mcalancea e160325
Merge main -> dev (#1705)
naure bb22b9a
Merge branch 'dev' into mihai/ideal-neighborhoods-gen
mcalancea d90538b
clippy
mcalancea 9472803
Merge branch 'dev' into mihai/ideal-neighborhoods-gen
bgillesp File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,228 @@ | ||
use std::{ | ||
fs::{File, OpenOptions}, | ||
io::{BufRead, BufReader, Write}, | ||
path::PathBuf, | ||
}; | ||
|
||
use clap::{Parser, ValueEnum}; | ||
use iris_mpc_common::iris_db::iris::IrisCode; | ||
use iris_mpc_cpu::{ | ||
hawkers::naive_knn_plaintext::{naive_knn, KNNResult}, | ||
py_bindings::{limited_iterator, plaintext_store::Base64IrisCode}, | ||
}; | ||
use metrics::IntoF64; | ||
use rayon::ThreadPoolBuilder; | ||
use serde::{Deserialize, Serialize}; | ||
use serde_json::Deserializer; | ||
use std::time::Instant; | ||
|
||
#[derive(Clone, Debug, ValueEnum, Copy, Serialize, Deserialize, PartialEq)] | ||
enum IrisSelection { | ||
All, | ||
Even, | ||
Odd, | ||
} | ||
|
||
/// A struct to hold the metadata stored in the first line of the results file. | ||
#[derive(Serialize, Deserialize, PartialEq, Debug)] | ||
struct ResultsHeader { | ||
iris_selection: IrisSelection, | ||
num_irises: usize, | ||
k: usize, | ||
} | ||
|
||
#[derive(Parser, Debug)] | ||
#[command(author, version, about, long_about = None)] | ||
struct Args { | ||
/// Number of irises to process | ||
#[arg(long, default_value_t = 1000)] | ||
num_irises: usize, | ||
|
||
/// Number of threads to use | ||
#[arg(long, default_value_t = 1)] | ||
num_threads: usize, | ||
|
||
/// Path to the iris codes file | ||
#[arg(long, default_value = "iris-mpc-cpu/data/store.ndjson")] | ||
path_to_iris_codes: PathBuf, | ||
|
||
/// The k for k-NN | ||
#[arg(long)] | ||
k: usize, | ||
bgillesp marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
/// Path to the results file | ||
#[arg(long)] | ||
results_file: PathBuf, | ||
bgillesp marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
/// Selection of irises to process | ||
#[arg(long, value_enum, default_value_t = IrisSelection::All)] | ||
irises_selection: IrisSelection, | ||
} | ||
#[tokio::main] | ||
async fn main() { | ||
let args = Args::parse(); | ||
|
||
let (num_already_processed, nodes) = match File::open(&args.results_file) { | ||
Ok(file) => { | ||
let reader = BufReader::new(file); | ||
let mut lines = reader.lines(); | ||
|
||
// 1. Read and validate the header line | ||
let header_line = match lines.next() { | ||
Some(Ok(line)) => line, | ||
Some(Err(e)) => { | ||
eprintln!("Error: Could not read header from results file: {}", e); | ||
std::process::exit(1); | ||
} | ||
None => { | ||
eprintln!( | ||
"Error: Results file '{}' is empty. Please fix or delete it.", | ||
args.results_file.display() | ||
); | ||
std::process::exit(1); | ||
} | ||
}; | ||
|
||
let file_header: ResultsHeader = match serde_json::from_str(&header_line) { | ||
Ok(h) => h, | ||
Err(e) => { | ||
eprintln!("Error: Could not parse header in results file: {}", e); | ||
eprintln!( | ||
" -> Please fix or delete the file '{}' and restart.", | ||
args.results_file.display() | ||
); | ||
std::process::exit(1); | ||
} | ||
}; | ||
|
||
// 2. Check for configuration mismatches with a single comparison | ||
let expected_header = ResultsHeader { | ||
iris_selection: args.irises_selection, | ||
num_irises: args.num_irises, | ||
k: args.k, | ||
}; | ||
|
||
if file_header != expected_header { | ||
eprintln!("Error: Mismatch in results file configuration."); | ||
eprintln!(" -> Expected parameters: {:?}", expected_header); | ||
eprintln!(" -> Parameters found in file: {:?}", file_header); | ||
eprintln!(" -> Please use a different results file or adjust the command-line arguments to match."); | ||
std::process::exit(1); | ||
} | ||
|
||
// 3. Process the rest of the lines as KNN results | ||
let results: Result<Vec<KNNResult>, _> = lines | ||
.map(|line_result| { | ||
let line = line_result.map_err(|e| e.to_string())?; | ||
serde_json::from_str::<KNNResult>(&line).map_err(|e| e.to_string()) | ||
}) | ||
.collect(); | ||
|
||
let deserialized_results = match results { | ||
Ok(res) => res, | ||
Err(e) => { | ||
eprintln!("Error: Failed to deserialize a result from the file."); | ||
eprintln!("It may be corrupted from an abrupt shutdown."); | ||
eprintln!(" -> Error details: {}", e); | ||
eprintln!( | ||
" -> Please fix or delete the file '{}' and restart.", | ||
args.results_file.display() | ||
); | ||
std::process::exit(1); | ||
} | ||
}; | ||
|
||
let nodes: Vec<usize> = deserialized_results | ||
.into_iter() | ||
.map(|result| result.node) | ||
.collect(); | ||
(nodes.len(), nodes) | ||
} | ||
Err(_) => { | ||
// File doesn't exist, create it and write the header. | ||
let mut file = File::create(&args.results_file).expect("Unable to create results file"); | ||
let header = ResultsHeader { | ||
iris_selection: args.irises_selection, | ||
num_irises: args.num_irises, | ||
k: args.k, | ||
}; | ||
let header_str = | ||
serde_json::to_string(&header).expect("Failed to serialize ResultsHeader"); | ||
writeln!(file, "{}", header_str).expect("Failed to write header to new results file"); | ||
(0, Vec::new()) | ||
} | ||
}; | ||
|
||
if num_already_processed > 0 { | ||
let expected_nodes: Vec<usize> = (1..num_already_processed + 1).collect(); | ||
if nodes != expected_nodes { | ||
eprintln!( | ||
"Error: The result nodes in the file are not a contiguous sequence from 1 to N." | ||
); | ||
eprintln!( | ||
" -> Please fix or delete the file '{}' and restart.", | ||
args.results_file.display() | ||
); | ||
std::process::exit(1); | ||
} | ||
} | ||
|
||
let num_irises = args.num_irises; | ||
let path_to_iris_codes = args.path_to_iris_codes; | ||
|
||
assert!(num_irises > args.k); | ||
|
||
let file = File::open(path_to_iris_codes.as_path()).unwrap(); | ||
let reader = BufReader::new(file); | ||
|
||
let stream = Deserializer::from_reader(reader).into_iter::<Base64IrisCode>(); | ||
let mut irises: Vec<IrisCode> = Vec::with_capacity(num_irises); | ||
|
||
let (limit, skip, step) = match args.irises_selection { | ||
IrisSelection::All => (num_irises, 0, 1), | ||
IrisSelection::Even => (num_irises * 2, 0, 2), | ||
IrisSelection::Odd => (num_irises * 2, 1, 2), | ||
}; | ||
|
||
let stream_iterator = limited_iterator(stream, Some(limit)) | ||
mcalancea marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
.skip(skip) | ||
.step_by(step) | ||
.map(|json_pt| (&json_pt.unwrap()).into()); | ||
irises.extend(stream_iterator); | ||
assert!(irises.len() == num_irises); | ||
|
||
let start_t = Instant::now(); | ||
let pool = ThreadPoolBuilder::new() | ||
.num_threads(args.num_threads) | ||
.build() | ||
.unwrap(); | ||
|
||
let mut start = num_already_processed + 1; | ||
let chunk_size = 1000; | ||
println!("Starting work at serial id: {}", start); | ||
let mut evaluated_pairs = 0usize; | ||
|
||
while start < num_irises { | ||
let end = (start + chunk_size).min(num_irises); | ||
let results = naive_knn(&irises, args.k, start, end, &pool); | ||
evaluated_pairs += (end - start) * num_irises; | ||
|
||
let mut file = OpenOptions::new() | ||
.append(true) | ||
.open(&args.results_file) | ||
.expect("Unable to open results file for appending"); | ||
|
||
println!("Appending iris results from {} to {}", start, end); | ||
for result in &results { | ||
let json_line = serde_json::to_string(result).expect("Failed to serialize KNNResult"); | ||
writeln!(file, "{}", json_line).expect("Failed to write to results file"); | ||
} | ||
|
||
start = end; | ||
} | ||
let duration = start_t.elapsed(); | ||
println!( | ||
"naive_knn took {:?} (per evaluated pair)", | ||
duration.into_f64() / (evaluated_pairs as f64) | ||
); | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,3 +21,5 @@ pub mod plaintext_store; | |
pub mod shared_irises; | ||
|
||
pub mod build_plaintext; | ||
|
||
pub mod naive_knn_plaintext; |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
use iris_mpc_common::iris_db::iris::IrisCode; | ||
use rayon::{ | ||
iter::{IntoParallelIterator, ParallelIterator}, | ||
ThreadPool, | ||
}; | ||
use serde::{Deserialize, Serialize}; | ||
|
||
use crate::hawkers::plaintext_store::fraction_ordering; | ||
|
||
#[derive(Serialize, Deserialize)] | ||
pub struct KNNResult { | ||
pub node: usize, | ||
neighbors: Vec<usize>, | ||
mcalancea marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
} | ||
|
||
pub fn naive_knn( | ||
irises: &[IrisCode], | ||
k: usize, | ||
start: usize, | ||
end: usize, | ||
pool: &ThreadPool, | ||
) -> Vec<KNNResult> { | ||
pool.install(|| { | ||
(start..end) | ||
.collect::<Vec<_>>() | ||
.into_par_iter() | ||
.map(|i| { | ||
let current_iris = &irises[i - 1]; | ||
let mut neighbors = irises | ||
.iter() | ||
.enumerate() | ||
.flat_map(|(j, other_iris)| { | ||
(i != j + 1) | ||
.then_some((j + 1, current_iris.get_distance_fraction(other_iris))) | ||
}) | ||
.collect::<Vec<_>>(); | ||
neighbors | ||
.select_nth_unstable_by(k - 1, |lhs, rhs| fraction_ordering(&lhs.1, &rhs.1)); | ||
let mut neighbors = neighbors.drain(0..k).collect::<Vec<_>>(); | ||
neighbors.shrink_to_fit(); // just to make sure | ||
neighbors.sort_by(|lhs, rhs| fraction_ordering(&lhs.1, &rhs.1)); | ||
let neighbors = neighbors.into_iter().map(|(i, _)| i).collect::<Vec<_>>(); | ||
KNNResult { node: i, neighbors } | ||
}) | ||
.collect::<Vec<_>>() | ||
}) | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.