Skip to content

Commit 8d54954

Browse files
committed
avx experiments
1 parent 78eb079 commit 8d54954

File tree

1 file changed

+65
-0
lines changed

1 file changed

+65
-0
lines changed

iris-mpc-common/src/iris_db/iris.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
use std::arch::x86_64::__m256i;
2+
use std::arch::x86_64::_mm256_add_epi64;
3+
use std::arch::x86_64::_mm256_add_epi8;
4+
use std::arch::x86_64::_mm256_and_si256;
5+
use std::arch::x86_64::_mm256_loadu_si256;
6+
use std::arch::x86_64::_mm256_sad_epu8;
7+
use std::arch::x86_64::_mm256_set1_epi8;
8+
use std::arch::x86_64::_mm256_setr_epi8;
9+
use std::arch::x86_64::_mm256_setzero_si256;
10+
use std::arch::x86_64::_mm256_shuffle_epi8;
11+
use std::arch::x86_64::_mm256_srli_epi16;
12+
use std::arch::x86_64::_mm256_storeu_si256;
13+
use std::arch::x86_64::_mm256_xor_si256;
14+
115
use crate::galois_engine::degree4::GaloisRingIrisCodeShare;
216
use crate::IRIS_CODE_LENGTH;
317
use crate::ROTATIONS;
@@ -400,6 +414,57 @@ impl IrisCode {
400414
(code_distance as u16, combined_mask_len as u16)
401415
}
402416

417+
/// An unsafe worker function to calculate Hamming distance using AVX2.
418+
/// It processes the 200 u64s in the IrisCodeArray in 50 chunks of 256 bits.
419+
///
420+
/// SAFETY: This function MUST only be called after a runtime check confirms
421+
/// that the CPU supports AVX2.
422+
#[target_feature(enable = "avx2")]
423+
pub unsafe fn get_distance_fraction_avx2(&self, other: &Self) -> (u16, u16) {
424+
// Get pointers to the raw u64 arrays.
425+
let self_code_ptr = self.code.0.as_ptr() as *const __m256i;
426+
let other_code_ptr = other.code.0.as_ptr() as *const __m256i;
427+
let self_mask_ptr = self.mask.0.as_ptr() as *const __m256i;
428+
let other_mask_ptr = other.mask.0.as_ptr() as *const __m256i;
429+
430+
let mut total_code_distance: u32 = 0;
431+
let mut total_mask_len: u32 = 0;
432+
433+
// A temporary array to store vector results for scalar popcounting.
434+
let mut temp_storage: [u64; 4] = [0; 4];
435+
let temp_storage_ptr = temp_storage.as_mut_ptr() as *mut __m256i;
436+
437+
// Loop 50 times (200 u64s / 4 u64s per __m256i vector = 50 iterations).
438+
for i in 0..50 {
439+
// Load 256 bits (4 u64s) for each of the four arrays.
440+
let self_code_vec = _mm256_loadu_si256(self_code_ptr.add(i));
441+
let other_code_vec = _mm256_loadu_si256(other_code_ptr.add(i));
442+
let self_mask_vec = _mm256_loadu_si256(self_mask_ptr.add(i));
443+
let other_mask_vec = _mm256_loadu_si256(other_mask_ptr.add(i));
444+
445+
// 1. Get combined_mask = self.mask & other.mask;
446+
let combined_mask_vec = _mm256_and_si256(self_mask_vec, other_mask_vec);
447+
448+
// 2. Get combined_code = (self.code ^ other.code) & combined_mask;
449+
let xor_code_vec = _mm256_xor_si256(self_code_vec, other_code_vec);
450+
let combined_code_vec = _mm256_and_si256(xor_code_vec, combined_mask_vec);
451+
452+
// 3. Store the vector results to memory and use the fast scalar `count_ones` (`popcnt`).
453+
_mm256_storeu_si256(temp_storage_ptr, combined_mask_vec);
454+
total_mask_len += temp_storage[0].count_ones()
455+
+ temp_storage[1].count_ones()
456+
+ temp_storage[2].count_ones()
457+
+ temp_storage[3].count_ones();
458+
459+
_mm256_storeu_si256(temp_storage_ptr, combined_code_vec);
460+
total_code_distance += temp_storage[0].count_ones()
461+
+ temp_storage[1].count_ones()
462+
+ temp_storage[2].count_ones()
463+
+ temp_storage[3].count_ones();
464+
}
465+
466+
(total_code_distance as u16, total_mask_len as u16)
467+
}
403468
/// Return the fractional Hamming distance between two iris codes, represented
404469
/// as the `i16` dot product of associated masked-bit vectors and the `u16` size
405470
/// of the common unmasked region.

0 commit comments

Comments
 (0)