Skip to content

Commit 2f85ee9

Browse files
committed
wip
1 parent d12a2b3 commit 2f85ee9

File tree

6 files changed

+437
-4
lines changed

6 files changed

+437
-4
lines changed

Cargo.lock

Lines changed: 40 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ axum = "0.7"
3030
clap = { version = "4", features = ["derive", "env"] }
3131
base64 = "0.22.1"
3232
bytemuck = { version = "1.17", features = ["derive"] }
33+
bitvec = "1.0.1"
3334
dotenvy = "0.15"
3435
eyre = "0.6"
3536
futures = "0.3.30"

iris-mpc-common/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ aws-sdk-sns = { workspace = true, optional = true }
2626
aws-sdk-sqs = { workspace = true, optional = true }
2727
aws-sdk-s3 = { workspace = true, optional = true }
2828
aws-sdk-secretsmanager = { workspace = true, optional = true }
29+
bitvec.workspace = true
2930
dotenvy.workspace = true
3031
clap.workspace = true
3132
rand.workspace = true

iris-mpc-common/src/iris_db/iris.rs

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
use crate::galois_engine::degree4::GaloisRingIrisCodeShare;
2+
use crate::IRIS_CODE_LENGTH;
3+
use crate::ROTATIONS;
24
use base64::{prelude::BASE64_STANDARD, Engine};
5+
use bitvec::array::BitArray;
6+
use bitvec::slice::BitSlice;
7+
use bitvec::vec::BitVec;
38
use eyre::bail;
49
use eyre::Result;
10+
use itertools::izip;
511
use rand::{
612
distributions::{Bernoulli, Distribution},
713
Rng,
@@ -183,6 +189,121 @@ impl Default for IrisCode {
183189
}
184190
}
185191

192+
pub struct RotExtIrisCodeArray(pub [BitVec; 16]);
193+
194+
impl From<IrisCodeArray> for RotExtIrisCodeArray {
195+
fn from(value: IrisCodeArray) -> Self {
196+
let code = value.bits().collect::<Vec<_>>();
197+
Self(
198+
code.chunks_exact(IrisCode::CODE_COLS * 4)
199+
.map(|chunk| {
200+
let mut extended = [false;
201+
IrisCode::CODE_COLS * 4 + 2 * IrisCode::ROTATIONS_PER_DIRECTION * 4];
202+
let left_len = IrisCode::ROTATIONS_PER_DIRECTION * 4;
203+
let chunk_len = IrisCode::CODE_COLS * 4;
204+
let right_len = IrisCode::ROTATIONS_PER_DIRECTION * 4;
205+
206+
extended[..left_len].copy_from_slice(&chunk[chunk_len - left_len..chunk_len]);
207+
extended[left_len..left_len + chunk_len].copy_from_slice(chunk);
208+
extended[left_len + chunk_len..].copy_from_slice(&chunk[..right_len]);
209+
let extended: BitVec = extended.into_iter().collect();
210+
extended
211+
})
212+
.collect::<Vec<_>>()
213+
.try_into()
214+
.unwrap(),
215+
)
216+
}
217+
}
218+
219+
impl RotExtIrisCodeArray {
220+
pub const CENTER_START: usize = 4 * IrisCode::ROTATIONS_PER_DIRECTION;
221+
pub const ROW_LEN: usize = 800;
222+
223+
fn get_rotation_slice(&self, by: isize) -> [&BitSlice; 16] {
224+
let start = (Self::CENTER_START as isize + 4 * by)
225+
.try_into()
226+
.expect("Invalid rotation delta");
227+
let ret_vec = self
228+
.0
229+
.iter()
230+
.map(|row| &row[start..start + Self::ROW_LEN])
231+
.collect::<Vec<_>>();
232+
ret_vec.try_into().unwrap()
233+
}
234+
}
235+
236+
pub struct RotExtIrisCode {
237+
code: RotExtIrisCodeArray,
238+
mask: RotExtIrisCodeArray,
239+
}
240+
241+
impl From<IrisCode> for RotExtIrisCode {
242+
fn from(value: IrisCode) -> Self {
243+
RotExtIrisCode {
244+
code: value.code.into(),
245+
mask: value.mask.into(),
246+
}
247+
}
248+
}
249+
250+
struct RotExtIrisCodeRef<'a> {
251+
code: [&'a BitSlice; 16],
252+
mask: [&'a BitSlice; 16],
253+
}
254+
255+
impl RotExtIrisCode {
256+
fn get_rotation<'a>(&'a self, by: isize) -> RotExtIrisCodeRef<'a> {
257+
RotExtIrisCodeRef {
258+
code: self.code.get_rotation_slice(by),
259+
mask: self.mask.get_rotation_slice(by),
260+
}
261+
}
262+
}
263+
264+
impl RotExtIrisCodeRef<'_> {
265+
fn get_distance_fraction(&self, other: &RotExtIrisCodeRef) -> (u16, u16) {
266+
izip!(self.code, self.mask, other.code, other.mask).fold(
267+
(0, 0),
268+
|(num, denom), (lhs_code, lhs_mask, rhs_code, rhs_mask)| {
269+
let combined_mask = lhs_mask.to_owned() & rhs_mask;
270+
let combined_mask_len = combined_mask.count_ones();
271+
272+
let combined_code = (lhs_code.to_owned() ^ rhs_code) & combined_mask;
273+
let code_distance = combined_code.count_ones();
274+
275+
(
276+
num + (code_distance as u16),
277+
denom + (combined_mask_len as u16),
278+
)
279+
},
280+
)
281+
}
282+
}
283+
284+
pub fn fraction_less_than(dist_1: &(u16, u16), dist_2: &(u16, u16)) -> bool {
285+
let (a, b) = *dist_1; // a/b
286+
let (c, d) = *dist_2; // c/d
287+
((a as u32) * (d as u32)) < ((b as u32) * (c as u32))
288+
}
289+
290+
impl RotExtIrisCode {
291+
pub fn min_fhe(&self, other: &RotExtIrisCode) -> (u16, u16) {
292+
let other_center = other.get_rotation(0);
293+
294+
let mut ret = (0, 1);
295+
for rot in -15..16 {
296+
let self_rotated = self.get_rotation(rot);
297+
let dist = self_rotated.get_distance_fraction(&other_center);
298+
if fraction_less_than(&dist, &ret) {
299+
ret = dist;
300+
}
301+
}
302+
303+
ret
304+
}
305+
}
306+
186307
impl IrisCode {
187308
pub const IRIS_CODE_SIZE: usize = IrisCodeArray::IRIS_CODE_SIZE;
188309
pub const CODE_COLS: usize = 200;

iris-mpc-cpu/bin/generate_ideal_neighborhoods.rs

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@ use std::{
55
};
66

77
use clap::{Parser, ValueEnum};
8-
use iris_mpc_common::iris_db::iris::IrisCode;
8+
use iris_mpc_common::iris_db::iris::{IrisCode, RotExtIrisCode};
99
use iris_mpc_cpu::{
10-
hawkers::naive_knn_plaintext::{naive_knn, KNNResult},
10+
hawkers::naive_knn_plaintext::{
11+
naive_knn, naive_knn_min_fhd, naive_knn_min_fhd1, naive_knn_min_fhd2, KNNResult,
12+
},
1113
py_bindings::{limited_iterator, plaintext_store::Base64IrisCode},
1214
};
1315
use metrics::IntoF64;
@@ -23,6 +25,12 @@ enum IrisSelection {
2325
Odd,
2426
}
2527

28+
#[derive(Clone, Debug, ValueEnum, Copy, Serialize, Deserialize, PartialEq)]
29+
enum DistanceUsed {
30+
Normal,
31+
MinFHD,
32+
}
33+
2634
/// A struct to hold the metadata stored in the first line of the results file.
2735
#[derive(Serialize, Deserialize, PartialEq, Debug)]
2836
struct ResultsHeader {
@@ -57,6 +65,10 @@ struct Args {
5765
/// Selection of irises to process
5866
#[arg(long, value_enum, default_value_t = IrisSelection::All)]
5967
irises_selection: IrisSelection,
68+
69+
/// Selection of irises to process
70+
#[arg(long, value_enum, default_value_t = DistanceUsed::Normal)]
71+
distance_used: DistanceUsed,
6072
}
6173
#[tokio::main]
6274
async fn main() {
@@ -202,9 +214,42 @@ async fn main() {
202214
println!("Starting work at serial id: {}", start);
203215
let mut evaluated_pairs = 0usize;
204216

217+
let rot_ext_irises = irises
218+
.iter()
219+
.map(|iris| RotExtIrisCode::from(iris.clone()))
220+
.collect::<Vec<_>>();
221+
222+
let irises_with_rotations = irises
223+
.iter()
224+
.map(|iris| iris.all_rotations().try_into().unwrap())
225+
.collect::<Vec<_>>();
226+
227+
let self_rots: Vec<[_; 31]> = irises_with_rotations
228+
.iter()
229+
.map(|rotations: &[IrisCode; 31]| {
230+
rotations
231+
.iter()
232+
.map(|rotation| rotation.get_distance_fraction(&rotations[15]))
233+
.collect::<Vec<_>>()
234+
.try_into()
235+
.unwrap()
236+
})
237+
.collect::<Vec<_>>();
238+
205239
while start < num_irises {
206240
let end = (start + chunk_size).min(num_irises);
207-
let results = naive_knn(&irises, args.k, start, end, &pool);
241+
let results = match args.distance_used {
242+
DistanceUsed::Normal => naive_knn(&irises, args.k, start, end, &pool),
243+
DistanceUsed::MinFHD => naive_knn_min_fhd2(
244+
&irises_with_rotations,
245+
&irises,
246+
&self_rots,
247+
args.k,
248+
start,
249+
end,
250+
&pool,
251+
),
252+
};
208253
evaluated_pairs += (end - start) * num_irises;
209254

210255
let mut file = OpenOptions::new()

0 commit comments

Comments
 (0)