5 changes: 5 additions & 0 deletions iris-mpc-cpu/Cargo.toml
@@ -33,6 +33,7 @@ num-traits.workspace = true
prost = "0.13"
rand.workspace = true
rand_distr = "0.4.3"
rayon.workspace = true
rstest = "0.23.0"
serde.workspace = true
serde_json.workspace = true
@@ -128,3 +129,7 @@ path = "bin/graph_mem_cli.rs"
[[bin]]
name = "local_hnsw"
path = "bin/local_hnsw.rs"

[[bin]]
name = "generate_ideal_neighborhoods"
path = "bin/generate_ideal_neighborhoods.rs"
228 changes: 228 additions & 0 deletions iris-mpc-cpu/bin/generate_ideal_neighborhoods.rs
@@ -0,0 +1,228 @@
use std::{
fs::{File, OpenOptions},
io::{BufRead, BufReader, Write},
path::PathBuf,
};

use clap::{Parser, ValueEnum};
use iris_mpc_common::iris_db::iris::IrisCode;
use iris_mpc_cpu::{
hawkers::naive_knn_plaintext::{naive_knn, KNNResult},
py_bindings::{limited_iterator, plaintext_store::Base64IrisCode},
};
use rayon::ThreadPoolBuilder;
use serde::{Deserialize, Serialize};
use serde_json::Deserializer;
use std::time::Instant;

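/// Which entries of the input NDJSON stream to use: all of them, or every second entry,
/// starting from the first record (`Even`, i.e. even 0-based positions) or from the second
/// (`Odd`). See the limit/skip/step selection logic in `main`.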
#[derive(Clone, Debug, ValueEnum, Copy, Serialize, Deserialize, PartialEq)]
enum IrisSelection {
All,
Even,
Odd,
}

/// A struct to hold the metadata stored in the first line of the results file.
#[derive(Serialize, Deserialize, PartialEq, Debug)]
struct ResultsHeader {
iris_selection: IrisSelection,
num_irises: usize,
k: usize,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Number of irises to process
#[arg(long, default_value_t = 1000)]
num_irises: usize,

/// Number of threads to use
#[arg(long, default_value_t = 1)]
num_threads: usize,

/// Path to the iris codes file
#[arg(long, default_value = "iris-mpc-cpu/data/store.ndjson")]
path_to_iris_codes: PathBuf,

/// The k for k-NN
#[arg(long)]
k: usize,

/// Path to the results file
#[arg(long)]
results_file: PathBuf,

/// Selection of irises to process
#[arg(long, value_enum, default_value_t = IrisSelection::All)]
irises_selection: IrisSelection,
}

#[tokio::main]
async fn main() {
let args = Args::parse();

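// If the results file already exists, validate its header against the CLI arguments and load
// the results written so far, so the run can resume where it left off; otherwise create the
// file and write a fresh header line.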
let (num_already_processed, nodes) = match File::open(&args.results_file) {
Ok(file) => {
let reader = BufReader::new(file);
let mut lines = reader.lines();

// 1. Read and validate the header line
let header_line = match lines.next() {
Some(Ok(line)) => line,
Some(Err(e)) => {
eprintln!("Error: Could not read header from results file: {}", e);
std::process::exit(1);
}
None => {
eprintln!(
"Error: Results file '{}' is empty. Please fix or delete it.",
args.results_file.display()
);
std::process::exit(1);
}
};

let file_header: ResultsHeader = match serde_json::from_str(&header_line) {
Ok(h) => h,
Err(e) => {
eprintln!("Error: Could not parse header in results file: {}", e);
eprintln!(
" -> Please fix or delete the file '{}' and restart.",
args.results_file.display()
);
std::process::exit(1);
}
};

// 2. Check for configuration mismatches with a single comparison
let expected_header = ResultsHeader {
iris_selection: args.irises_selection,
num_irises: args.num_irises,
k: args.k,
};

if file_header != expected_header {
eprintln!("Error: Mismatch in results file configuration.");
eprintln!(" -> Expected parameters: {:?}", expected_header);
eprintln!(" -> Parameters found in file: {:?}", file_header);
eprintln!(" -> Please use a different results file or adjust the command-line arguments to match.");
std::process::exit(1);
}

// 3. Process the rest of the lines as KNN results
let results: Result<Vec<KNNResult>, _> = lines
.map(|line_result| {
let line = line_result.map_err(|e| e.to_string())?;
serde_json::from_str::<KNNResult>(&line).map_err(|e| e.to_string())
})
.collect();

let deserialized_results = match results {
Ok(res) => res,
Err(e) => {
eprintln!("Error: Failed to deserialize a result from the file.");
eprintln!("It may be corrupted from an abrupt shutdown.");
eprintln!(" -> Error details: {}", e);
eprintln!(
" -> Please fix or delete the file '{}' and restart.",
args.results_file.display()
);
std::process::exit(1);
}
};

let nodes: Vec<usize> = deserialized_results
.into_iter()
.map(|result| result.node)
.collect();
(nodes.len(), nodes)
}
Err(_) => {
// File doesn't exist, create it and write the header.
let mut file = File::create(&args.results_file).expect("Unable to create results file");
let header = ResultsHeader {
iris_selection: args.irises_selection,
num_irises: args.num_irises,
k: args.k,
};
let header_str =
serde_json::to_string(&header).expect("Failed to serialize ResultsHeader");
writeln!(file, "{}", header_str).expect("Failed to write header to new results file");
(0, Vec::new())
}
};

if num_already_processed > 0 {
let expected_nodes: Vec<usize> = (1..num_already_processed + 1).collect();
if nodes != expected_nodes {
eprintln!(
"Error: The result nodes in the file are not a contiguous sequence from 1 to N."
);
eprintln!(
" -> Please fix or delete the file '{}' and restart.",
args.results_file.display()
);
std::process::exit(1);
}
}

let num_irises = args.num_irises;
let path_to_iris_codes = args.path_to_iris_codes;

assert!(num_irises > args.k);

let file = File::open(path_to_iris_codes.as_path()).unwrap();
let reader = BufReader::new(file);

let stream = Deserializer::from_reader(reader).into_iter::<Base64IrisCode>();
let mut irises: Vec<IrisCode> = Vec::with_capacity(num_irises);

let (limit, skip, step) = match args.irises_selection {
IrisSelection::All => (num_irises, 0, 1),
IrisSelection::Even => (num_irises * 2, 0, 2),
IrisSelection::Odd => (num_irises * 2, 1, 2),
};

let stream_iterator = limited_iterator(stream, Some(limit))
.skip(skip)
.step_by(step)
.map(|json_pt| (&json_pt.unwrap()).into());
irises.extend(stream_iterator);
assert!(irises.len() == num_irises);

let start_t = Instant::now();
let pool = ThreadPoolBuilder::new()
.num_threads(args.num_threads)
.build()
.unwrap();

let mut start = num_already_processed + 1;
let chunk_size = 1000;
println!("Starting work at serial id: {}", start);
let mut evaluated_pairs = 0usize;

// Node ids are 1-based and run through num_irises inclusive, so the (exclusive) chunk end
// must be capped at num_irises + 1 rather than num_irises.
while start <= num_irises {
let end = (start + chunk_size).min(num_irises + 1);
let results = naive_knn(&irises, args.k, start, end, &pool);
evaluated_pairs += (end - start) * num_irises;

let mut file = OpenOptions::new()
.append(true)
.open(&args.results_file)
.expect("Unable to open results file for appending");

println!("Appending iris results from {} to {}", start, end);
for result in &results {
let json_line = serde_json::to_string(result).expect("Failed to serialize KNNResult");
writeln!(file, "{}", json_line).expect("Failed to write to results file");
}

start = end;
}
let duration = start_t.elapsed();
println!(
"naive_knn took {:?} in total, {:e} seconds per evaluated pair",
duration,
duration.as_secs_f64() / (evaluated_pairs as f64)
);
}
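For reference, the binary can be invoked along the lines of: cargo run --release --bin generate_ideal_neighborhoods -- --k 3 --num-irises 1000 --results-file results.ndjson. The results file it maintains is NDJSON: a single ResultsHeader line followed by one KNNResult line per processed node, in ascending node order. A sketch of what it might contain (the neighbor ids below are purely illustrative):

{"iris_selection":"All","num_irises":1000,"k":3}
{"node":1,"neighbors":[412,87,903]}
{"node":2,"neighbors":[15,731,248]}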
2 changes: 2 additions & 0 deletions iris-mpc-cpu/src/hawkers/mod.rs
@@ -21,3 +21,5 @@ pub mod plaintext_store;
pub mod shared_irises;

pub mod build_plaintext;

pub mod naive_knn_plaintext;
47 changes: 47 additions & 0 deletions iris-mpc-cpu/src/hawkers/naive_knn_plaintext.rs
@@ -0,0 +1,47 @@
use iris_mpc_common::iris_db::iris::IrisCode;
use rayon::{
iter::{IntoParallelIterator, ParallelIterator},
ThreadPool,
};
use serde::{Deserialize, Serialize};

use crate::hawkers::plaintext_store::fraction_ordering;

#[derive(Serialize, Deserialize)]
pub struct KNNResult {
pub node: usize,
neighbors: Vec<usize>,
}

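/// Computes, for each 1-based node id in `start..end`, the ids of its `k` nearest neighbors
/// among all `irises` (the node itself is excluded), by exhaustive pairwise comparison.
/// Distances are the (numerator, denominator) fractions returned by `get_distance_fraction`,
/// compared via `fraction_ordering`; neighbors are returned closest-first. The per-node work
/// is distributed over the provided rayon thread pool.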
pub fn naive_knn(
irises: &[IrisCode],
k: usize,
start: usize,
end: usize,
pool: &ThreadPool,
) -> Vec<KNNResult> {
pool.install(|| {
(start..end)
.collect::<Vec<_>>()
.into_par_iter()
.map(|i| {
let current_iris = &irises[i - 1];
let mut neighbors = irises
.iter()
.enumerate()
.flat_map(|(j, other_iris)| {
(i != j + 1)
.then_some((j + 1, current_iris.get_distance_fraction(other_iris)))
})
.collect::<Vec<_>>();
neighbors
.select_nth_unstable_by(k - 1, |lhs, rhs| fraction_ordering(&lhs.1, &rhs.1));
let mut neighbors = neighbors.drain(0..k).collect::<Vec<_>>();
neighbors.shrink_to_fit(); // defensive: the freshly collected Vec should already have capacity k
neighbors.sort_by(|lhs, rhs| fraction_ordering(&lhs.1, &rhs.1));
let neighbors = neighbors.into_iter().map(|(i, _)| i).collect::<Vec<_>>();
KNNResult { node: i, neighbors }
})
.collect::<Vec<_>>()
})
}
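A minimal usage sketch, under stated assumptions: IrisCode::random_rng is assumed to be available for generating test data (as used elsewhere in this crate's tests); node ids are 1-based and the end bound is exclusive, so covering all irises means passing irises.len() + 1.

use iris_mpc_common::iris_db::iris::IrisCode;
use iris_mpc_cpu::hawkers::naive_knn_plaintext::naive_knn;
use rand::SeedableRng;
use rayon::ThreadPoolBuilder;

fn main() {
    let mut rng = rand::rngs::StdRng::seed_from_u64(0);
    // Assumed helper: IrisCode::random_rng, as used in this crate's tests.
    let irises: Vec<IrisCode> = (0..100).map(|_| IrisCode::random_rng(&mut rng)).collect();
    let pool = ThreadPoolBuilder::new().num_threads(4).build().unwrap();
    // Node ids are 1-based and the end bound is exclusive, so this covers all 100 irises.
    let results = naive_knn(&irises, 10, 1, irises.len() + 1, &pool);
    for result in &results {
        println!("{}", serde_json::to_string(result).expect("serializable"));
    }
}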
8 changes: 7 additions & 1 deletion iris-mpc-cpu/src/hawkers/plaintext_store.rs
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::{cmp::Ordering, sync::Arc};

use crate::{
hawkers::shared_irises::{SharedIrises, SharedIrisesRef},
@@ -124,6 +124,12 @@ fn fraction_less_than(dist_1: &(u16, u16), dist_2: &(u16, u16)) -> bool {
(a as u32) * (d as u32) < (b as u32) * (c as u32)
}

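/// Orders two distance fractions a/b and c/d (assuming positive denominators) by
/// cross-multiplication: a/b < c/d exactly when a*d < b*c, and likewise for equality.
/// Widening to u32 keeps the u16 * u16 products from overflowing.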
pub fn fraction_ordering(dist_1: &(u16, u16), dist_2: &(u16, u16)) -> Ordering {
let (a, b) = *dist_1; // a/b
let (c, d) = *dist_2; // c/d
((a as u32) * (d as u32)).cmp(&((b as u32) * (c as u32)))
}

impl VectorStore for PlaintextStore {
type QueryRef = Arc<IrisCode>;
type VectorRef = VectorId;
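As a quick sanity check of the cross-multiplication ordering, a sketch of a unit test one could add alongside fraction_ordering:

use std::cmp::Ordering;
use iris_mpc_cpu::hawkers::plaintext_store::fraction_ordering;

#[test]
fn fraction_ordering_cross_multiplies() {
    // 1/2 < 2/3 because 1*3 < 2*2.
    assert_eq!(fraction_ordering(&(1, 2), &(2, 3)), Ordering::Less);
    // 2/4 and 1/2 represent the same value: 2*2 == 4*1.
    assert_eq!(fraction_ordering(&(2, 4), &(1, 2)), Ordering::Equal);
    // 3/4 > 1/2 because 3*2 > 4*1.
    assert_eq!(fraction_ordering(&(3, 4), &(1, 2)), Ordering::Greater);
}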