Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions iris-mpc-cpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ num-traits.workspace = true
prost = "0.13"
rand.workspace = true
rand_distr = "0.4.3"
rayon.workspace = true
rstest = "0.23.0"
serde.workspace = true
serde_json.workspace = true
Expand Down Expand Up @@ -128,3 +129,7 @@ path = "bin/graph_mem_cli.rs"
[[bin]]
name = "local_hnsw"
path = "bin/local_hnsw.rs"

[[bin]]
name = "generate_ideal_neighborhoods"
path = "bin/generate_ideal_neighborhoods.rs"
65 changes: 65 additions & 0 deletions iris-mpc-cpu/bin/generate_ideal_neighborhoods.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
use std::{fs::File, io::BufReader, path::PathBuf};

use clap::Parser;
use iris_mpc_common::iris_db::iris::IrisCode;
use iris_mpc_common::vector_id::SerialId;
use iris_mpc_cpu::{
hawkers::naive_knn_plaintext::naive_knn,
py_bindings::{limited_iterator, plaintext_store::Base64IrisCode},
};
use metrics::IntoF64;
use serde_json::Deserializer;
use std::time::Instant;

/// An iris code bundled with the serial id it is stored under.
///
/// NOTE(review): nothing in the visible part of this binary constructs or reads this type;
/// confirm whether it is still needed.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct IrisCodeWithSerialId {
/// The iris code itself.
pub iris_code: IrisCode,
/// Serial id associated with `iris_code`.
pub serial_id: SerialId,
}

// Command-line options for this binary, parsed with clap's derive API.
// NOTE(review): the `///` comments on the fields are presumably used by clap as `--help`
// text, so rewording them would change user-visible output — confirm before editing them.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
// `main` reads up to twice this many records from the input file and splits them
// alternately into two lists, so each list holds at most `num_irises` entries.
/// Number of irises to process
#[arg(long, default_value_t = 1000)]
num_irises: usize,

// Passed unchanged to `naive_knn` as its `num_threads` argument.
/// Number of threads to use
#[arg(long, default_value_t = 1)]
num_threads: usize,
}
/// Loads iris codes from `iris-mpc-cpu/data/store.ndjson`, runs [`naive_knn`] on every
/// even-indexed record, and prints the elapsed time divided by the square of the number
/// of irises that were actually processed.
///
/// Records in the file are consumed alternately into two lists (index parity 0 and 1);
/// at most `2 * --num-irises` records are read. Only the parity-0 list is benchmarked.
///
/// # Panics
///
/// Panics if the input file cannot be opened or a record fails to deserialize.
#[tokio::main]
async fn main() {
    let args = Args::parse();
    let n_existing_irises = 0;
    let num_irises = args.num_irises;

    let path_to_iris_codes = PathBuf::from("iris-mpc-cpu/data/store.ndjson");

    let file = File::open(&path_to_iris_codes)
        .expect("failed to open iris-mpc-cpu/data/store.ndjson");
    let reader = BufReader::new(file);

    let stream = Deserializer::from_reader(reader)
        .into_iter::<Base64IrisCode>()
        .skip(2 * n_existing_irises);
    let stream = limited_iterator(stream, Some(num_irises * 2));

    // Consecutive records alternate between the two lists.
    let mut irises: [Vec<IrisCode>; 2] = [Vec::new(), Vec::new()];
    for (idx, json_pt) in stream.enumerate() {
        let record = json_pt.expect("failed to deserialize iris record from store.ndjson");
        irises[idx % 2].push((&record).into());
    }

    // Move the benchmarked list out of the array instead of cloning it, so that no copy of
    // the data set is made (and none is counted in the measured time).
    let [query_irises, _unused_irises] = irises;

    // Normalise by the number of irises actually loaded; the file may contain fewer
    // records than were requested on the command line.
    let num_loaded = query_irises.len();
    if num_loaded == 0 {
        eprintln!("no irises were loaded; nothing to measure");
        return;
    }

    let start = Instant::now();
    naive_knn(query_irises, args.num_threads);
    let duration = start.elapsed();
    println!(
        "naive_knn took {:?} (per number of pairs)",
        duration.into_f64() / (num_loaded as f64) / (num_loaded as f64)
    );
}
2 changes: 2 additions & 0 deletions iris-mpc-cpu/src/hawkers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,5 @@ pub mod plaintext_store;
pub mod shared_irises;

pub mod build_plaintext;

pub mod naive_knn_plaintext;
34 changes: 34 additions & 0 deletions iris-mpc-cpu/src/hawkers/naive_knn_plaintext.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use iris_mpc_common::iris_db::iris::IrisCode;
use rayon::{
iter::{IntoParallelIterator, ParallelIterator},
ThreadPoolBuilder,
};

pub fn naive_knn(irises: Vec<IrisCode>, num_threads: usize) {
let k = 320;
let n = irises.len();

let pool = ThreadPoolBuilder::new()
.num_threads(num_threads)
.build()
.unwrap();

let _results = pool.install(|| {
(0..n)
.collect::<Vec<_>>()
.into_par_iter()
.map(|i| {
let current_iris = &irises[i];
let mut distances = irises
.iter()
.enumerate()
.map(|(j, other_iris)| (j, current_iris.get_distance(other_iris)))
.collect::<Vec<_>>();
distances.select_nth_unstable_by(k - 1, |(_, d1), (_, d2)| d1.total_cmp(d2));
distances.truncate(k);
distances.sort_by(|(_, d1), (_, d2)| d1.total_cmp(d2));
distances
})
.collect::<Vec<_>>()
});
}
Loading