Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions iris-mpc-cpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ num-traits.workspace = true
prost = "0.13"
rand.workspace = true
rand_distr = "0.4.3"
rayon.workspace = true
rstest = "0.23.0"
serde.workspace = true
serde_json.workspace = true
Expand Down Expand Up @@ -124,3 +125,7 @@ path = "bin/init_test_dbs.rs"
[[bin]]
name = "local_hnsw"
path = "bin/local_hnsw.rs"

[[bin]]
name = "generate_ideal_neighborhoods"
path = "bin/generate_ideal_neighborhoods.rs"
229 changes: 229 additions & 0 deletions iris-mpc-cpu/bin/generate_ideal_neighborhoods.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
use std::{
fs::{File, OpenOptions},
io::{BufRead, BufReader, Write},
path::PathBuf,
};

use clap::{Parser, ValueEnum};
use iris_mpc_common::{iris_db::iris::IrisCode, IrisSerialId};
use iris_mpc_cpu::{
hawkers::naive_knn_plaintext::{naive_knn, KNNResult},
py_bindings::plaintext_store::Base64IrisCode,
};
use metrics::IntoF64;
use rayon::ThreadPoolBuilder;
use serde::{Deserialize, Serialize};
use serde_json::Deserializer;
use std::time::Instant;

// Which subset of irises from the input NDJSON stream to process.
// NOTE: `//` comments are used deliberately — `///` doc comments on a
// `ValueEnum` would change the clap-generated help text.
#[derive(Clone, Debug, ValueEnum, Copy, Serialize, Deserialize, PartialEq)]
enum IrisSelection {
    // Take every record in order.
    All,
    // Take records at even stream positions (0-based): 0, 2, 4, …
    Even,
    // Take records at odd stream positions (0-based): 1, 3, 5, …
    Odd,
}

/// A struct to hold the metadata stored in the first line of the results file.
///
/// On restart the header parsed from an existing file is compared (via
/// `PartialEq`) against one built from the current CLI arguments, so a run can
/// only resume a file produced with identical parameters.
#[derive(Serialize, Deserialize, PartialEq, Debug)]
struct ResultsHeader {
    // Which subset of irises was selected.
    iris_selection: IrisSelection,
    // Total number of irises the run covers.
    num_irises: usize,
    // The k used for k-NN.
    k: usize,
}

// Command-line arguments for the ideal-neighborhood generator.
// NOTE: the `///` doc comments below are clap help text — kept verbatim so the
// CLI help output is unchanged.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Number of irises to process
    #[arg(long, default_value_t = 1000)]
    num_irises: usize,

    /// Number of threads to use
    // Sizes the rayon thread pool used by naive_knn.
    #[arg(long, default_value_t = 1)]
    num_threads: usize,

    /// Path to the iris codes file
    // Expected format: NDJSON, one Base64IrisCode per line.
    #[arg(long, default_value = "iris-mpc-cpu/data/store.ndjson")]
    path_to_iris_codes: PathBuf,

    /// The k for k-NN
    // Must be strictly less than num_irises (asserted in main).
    #[arg(long)]
    k: usize,

    /// Path to the results file
    // Created if missing; appended to (resume) if it already exists.
    #[arg(long)]
    results_file: PathBuf,

    /// Selection of irises to process
    #[arg(long, value_enum, default_value_t = IrisSelection::All)]
    irises_selection: IrisSelection,
}
// Entry point: exhaustively computes the k nearest neighbors for every iris
// (serial ids 1..=num_irises) and appends the results, one JSON line per iris,
// to the results file. Supports resuming an interrupted run.
//
// Note: the function contains no `.await`, so the previous
// `#[tokio::main] async fn main` wrapper was unnecessary and has been dropped.
fn main() {
    let args = Args::parse();

    // Header we expect in an existing results file, or write into a new one.
    let expected_header = ResultsHeader {
        iris_selection: args.irises_selection,
        num_irises: args.num_irises,
        k: args.k,
    };

    let (num_already_processed, nodes) = match File::open(&args.results_file) {
        Ok(file) => read_existing_results(file, &expected_header, &args.results_file),
        // Only a genuinely missing file means "fresh start". Previously ANY
        // open error (e.g. a permission problem) fell through to
        // File::create, which could truncate a file we failed to read.
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
            let mut file =
                File::create(&args.results_file).expect("Unable to create results file");
            let header_str = serde_json::to_string(&expected_header)
                .expect("Failed to serialize ResultsHeader");
            writeln!(file, "{}", header_str).expect("Failed to write header to new results file");
            (0, Vec::new())
        }
        Err(e) => {
            eprintln!(
                "Error: Could not open results file '{}': {}",
                args.results_file.display(),
                e
            );
            std::process::exit(1);
        }
    };

    // Resuming is only safe if the stored results form the contiguous
    // serial-id sequence 1..=N, so we can continue right after N.
    if num_already_processed > 0 {
        let expected_nodes: Vec<IrisSerialId> = (1..=num_already_processed as u32).collect();
        if nodes != expected_nodes {
            eprintln!(
                "Error: The result nodes in the file are not a contiguous sequence from 1 to N."
            );
            eprintln!(
                " -> Please fix or delete the file '{}' and restart.",
                args.results_file.display()
            );
            std::process::exit(1);
        }
    }

    let num_irises = args.num_irises;
    let path_to_iris_codes = args.path_to_iris_codes;

    // Each iris has num_irises - 1 candidate neighbors, so k must be smaller.
    assert!(num_irises > args.k);

    let file = File::open(path_to_iris_codes.as_path()).expect("Unable to open iris codes file");
    let reader = BufReader::new(file);

    let stream = Deserializer::from_reader(reader).into_iter::<Base64IrisCode>();
    let mut irises: Vec<IrisCode> = Vec::with_capacity(num_irises);

    // `limit` is how many stream records to pull; `skip`/`step` then pick
    // every record (All) or only the even/odd positions, yielding exactly
    // num_irises items in each case.
    let (limit, skip, step) = match args.irises_selection {
        IrisSelection::All => (num_irises, 0, 1),
        IrisSelection::Even => (num_irises * 2, 0, 2),
        IrisSelection::Odd => (num_irises * 2, 1, 2),
    };

    irises.extend(
        stream
            .take(limit)
            .skip(skip)
            .step_by(step)
            .map(|json_pt| (&json_pt.expect("Failed to parse iris code")).into()),
    );
    assert!(
        irises.len() == num_irises,
        "iris codes file contained fewer irises than requested"
    );

    let start_t = Instant::now();
    let pool = ThreadPoolBuilder::new()
        .num_threads(args.num_threads)
        .build()
        .unwrap();

    let mut start = num_already_processed + 1;
    let chunk_size = 1000;
    println!("Starting work at serial id: {}", start);
    let mut evaluated_pairs = 0usize;

    // Serial ids run 1..=num_irises and `end` is exclusive in naive_knn, so
    // the loop must admit start == num_irises and the last chunk must reach
    // num_irises + 1. (The previous `start < num_irises` /
    // `min(.., num_irises)` bounds silently skipped the final iris.)
    while start <= num_irises {
        let end = (start + chunk_size).min(num_irises + 1);
        let results = naive_knn(&irises, args.k, start, end, &pool);
        evaluated_pairs += (end - start) * num_irises;

        // Append after every chunk so an interrupted run loses at most one
        // chunk of work and can resume from the file.
        let mut file = OpenOptions::new()
            .append(true)
            .open(&args.results_file)
            .expect("Unable to open results file for appending");

        println!("Appending iris results from {} to {}", start, end);
        for result in &results {
            let json_line = serde_json::to_string(result).expect("Failed to serialize KNNResult");
            writeln!(file, "{}", json_line).expect("Failed to write to results file");
        }

        start = end;
    }
    let duration = start_t.elapsed();
    // Guard the average: if everything was already processed, evaluated_pairs
    // is 0 and the division would print `inf`.
    if evaluated_pairs > 0 {
        println!(
            "naive_knn took {:?} (per evaluated pair)",
            duration.as_secs_f64() / (evaluated_pairs as f64)
        );
    }
}

/// Reads an existing results file: validates its JSON header against
/// `expected_header` and deserializes every previously written `KNNResult`
/// line. Exits the process with an explanatory message on any inconsistency.
///
/// Returns `(number_of_processed_irises, their_serial_ids_in_file_order)`.
fn read_existing_results(
    file: File,
    expected_header: &ResultsHeader,
    results_file: &std::path::Path,
) -> (usize, Vec<IrisSerialId>) {
    let reader = BufReader::new(file);
    let mut lines = reader.lines();

    // 1. Read and validate the header line.
    let header_line = match lines.next() {
        Some(Ok(line)) => line,
        Some(Err(e)) => {
            eprintln!("Error: Could not read header from results file: {}", e);
            std::process::exit(1);
        }
        None => {
            eprintln!(
                "Error: Results file '{}' is empty. Please fix or delete it.",
                results_file.display()
            );
            std::process::exit(1);
        }
    };

    let file_header: ResultsHeader = match serde_json::from_str(&header_line) {
        Ok(h) => h,
        Err(e) => {
            eprintln!("Error: Could not parse header in results file: {}", e);
            eprintln!(
                " -> Please fix or delete the file '{}' and restart.",
                results_file.display()
            );
            std::process::exit(1);
        }
    };

    // 2. Check for configuration mismatches with a single comparison.
    if &file_header != expected_header {
        eprintln!("Error: Mismatch in results file configuration.");
        eprintln!(" -> Expected parameters: {:?}", expected_header);
        eprintln!(" -> Parameters found in file: {:?}", file_header);
        eprintln!(" -> Please use a different results file or adjust the command-line arguments to match.");
        std::process::exit(1);
    }

    // 3. Every remaining line is one serialized KNNResult; fail on the first
    // malformed line (e.g. a partial write from an abrupt shutdown).
    let results: Result<Vec<KNNResult>, String> = lines
        .map(|line_result| {
            let line = line_result.map_err(|e| e.to_string())?;
            serde_json::from_str::<KNNResult>(&line).map_err(|e| e.to_string())
        })
        .collect();

    let deserialized_results = match results {
        Ok(res) => res,
        Err(e) => {
            eprintln!("Error: Failed to deserialize a result from the file.");
            eprintln!("It may be corrupted from an abrupt shutdown.");
            eprintln!(" -> Error details: {}", e);
            eprintln!(
                " -> Please fix or delete the file '{}' and restart.",
                results_file.display()
            );
            std::process::exit(1);
        }
    };

    let nodes: Vec<IrisSerialId> = deserialized_results
        .into_iter()
        .map(|result| result.node)
        .collect();
    (nodes.len(), nodes)
}
2 changes: 2 additions & 0 deletions iris-mpc-cpu/src/hawkers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,5 @@ pub mod plaintext_store;
pub mod shared_irises;

pub mod build_plaintext;

pub mod naive_knn_plaintext;
52 changes: 52 additions & 0 deletions iris-mpc-cpu/src/hawkers/naive_knn_plaintext.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use iris_mpc_common::{iris_db::iris::IrisCode, IrisSerialId};
use rayon::{
iter::{IntoParallelIterator, ParallelIterator},
ThreadPool,
};
use serde::{Deserialize, Serialize};

use crate::hawkers::plaintext_store::fraction_ordering;

/// The k nearest neighbors of a single iris, as computed by [`naive_knn`].
/// Serialized as one JSON object per line in the results file.
#[derive(Serialize, Deserialize)]
pub struct KNNResult {
    /// 1-based serial id of the query iris.
    pub node: IrisSerialId,
    // 1-based serial ids of the k nearest neighbors, sorted from nearest to
    // farthest; never contains `node` itself.
    neighbors: Vec<IrisSerialId>,
}

pub fn naive_knn(
irises: &[IrisCode],
k: usize,
start: usize,
end: usize,
pool: &ThreadPool,
) -> Vec<KNNResult> {
pool.install(|| {
(start..end)
.collect::<Vec<_>>()
.into_par_iter()
.map(|i| {
let current_iris = &irises[i - 1];
let mut neighbors = irises
.iter()
.enumerate()
.flat_map(|(j, other_iris)| {
(i != j + 1).then_some((
(j + 1) as u32,
current_iris.get_distance_fraction(other_iris),
))
})
.collect::<Vec<_>>();
neighbors
.select_nth_unstable_by(k - 1, |lhs, rhs| fraction_ordering(&lhs.1, &rhs.1));
let mut neighbors = neighbors.drain(0..k).collect::<Vec<_>>();
neighbors.shrink_to_fit(); // just to make sure
neighbors.sort_by(|lhs, rhs| fraction_ordering(&lhs.1, &rhs.1));
let neighbors = neighbors.into_iter().map(|(i, _)| i).collect::<Vec<_>>();
KNNResult {
node: i as u32,
neighbors,
}
})
.collect::<Vec<_>>()
})
}
8 changes: 7 additions & 1 deletion iris-mpc-cpu/src/hawkers/plaintext_store.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::{cmp::Ordering, sync::Arc};

use crate::{
hawkers::shared_irises::{SharedIrises, SharedIrisesRef},
Expand Down Expand Up @@ -124,6 +124,12 @@ fn fraction_less_than(dist_1: &(u16, u16), dist_2: &(u16, u16)) -> bool {
(a as u32) * (d as u32) < (b as u32) * (c as u32)
}

/// Totally orders two distance fractions `dist_1 = a/b` and `dist_2 = c/d`
/// by cross-multiplication, avoiding floating-point division entirely.
///
/// `a/b <=> c/d` reduces to `a*d <=> b*c` (assumes positive denominators —
/// NOTE(review): a zero denominator is not rejected here, same as in
/// `fraction_less_than`). Each factor is a `u16`, so the products cannot
/// overflow the `u32` they are widened into.
pub fn fraction_ordering(dist_1: &(u16, u16), dist_2: &(u16, u16)) -> Ordering {
    let (a, b) = *dist_1;
    let (c, d) = *dist_2;
    let lhs = u32::from(a) * u32::from(d);
    let rhs = u32::from(b) * u32::from(c);
    lhs.cmp(&rhs)
}

impl VectorStore for PlaintextStore {
type QueryRef = Arc<IrisCode>;
type VectorRef = VectorId;
Expand Down
Loading