|
| 1 | +use std::arch::x86_64::__m256i; |
| 2 | +use std::arch::x86_64::_mm256_add_epi64; |
| 3 | +use std::arch::x86_64::_mm256_add_epi8; |
| 4 | +use std::arch::x86_64::_mm256_and_si256; |
| 5 | +use std::arch::x86_64::_mm256_loadu_si256; |
| 6 | +use std::arch::x86_64::_mm256_sad_epu8; |
| 7 | +use std::arch::x86_64::_mm256_set1_epi8; |
| 8 | +use std::arch::x86_64::_mm256_setr_epi8; |
| 9 | +use std::arch::x86_64::_mm256_setzero_si256; |
| 10 | +use std::arch::x86_64::_mm256_shuffle_epi8; |
| 11 | +use std::arch::x86_64::_mm256_srli_epi16; |
| 12 | +use std::arch::x86_64::_mm256_storeu_si256; |
| 13 | +use std::arch::x86_64::_mm256_xor_si256; |
| 14 | + |
1 | 15 | use crate::galois_engine::degree4::GaloisRingIrisCodeShare;
|
2 | 16 | use crate::IRIS_CODE_LENGTH;
|
3 | 17 | use crate::ROTATIONS;
|
@@ -400,6 +414,57 @@ impl IrisCode {
|
400 | 414 | (code_distance as u16, combined_mask_len as u16)
|
401 | 415 | }
|
402 | 416 |
|
| 417 | + /// An unsafe worker function to calculate Hamming distance using AVX2. |
| 418 | + /// It processes the 200 u64s in the IrisCodeArray in 50 chunks of 256 bits. |
| 419 | + /// |
| 420 | + /// SAFETY: This function MUST only be called after a runtime check confirms |
| 421 | + /// that the CPU supports AVX2. |
| 422 | + #[target_feature(enable = "avx2")] |
| 423 | + pub unsafe fn get_distance_fraction_avx2(&self, other: &Self) -> (u16, u16) { |
| 424 | + // Get pointers to the raw u64 arrays. |
| 425 | + let self_code_ptr = self.code.0.as_ptr() as *const __m256i; |
| 426 | + let other_code_ptr = other.code.0.as_ptr() as *const __m256i; |
| 427 | + let self_mask_ptr = self.mask.0.as_ptr() as *const __m256i; |
| 428 | + let other_mask_ptr = other.mask.0.as_ptr() as *const __m256i; |
| 429 | + |
| 430 | + let mut total_code_distance: u32 = 0; |
| 431 | + let mut total_mask_len: u32 = 0; |
| 432 | + |
| 433 | + // A temporary array to store vector results for scalar popcounting. |
| 434 | + let mut temp_storage: [u64; 4] = [0; 4]; |
| 435 | + let temp_storage_ptr = temp_storage.as_mut_ptr() as *mut __m256i; |
| 436 | + |
| 437 | + // Loop 50 times (200 u64s / 4 u64s per __m256i vector = 50 iterations). |
| 438 | + for i in 0..50 { |
| 439 | + // Load 256 bits (4 u64s) for each of the four arrays. |
| 440 | + let self_code_vec = _mm256_loadu_si256(self_code_ptr.add(i)); |
| 441 | + let other_code_vec = _mm256_loadu_si256(other_code_ptr.add(i)); |
| 442 | + let self_mask_vec = _mm256_loadu_si256(self_mask_ptr.add(i)); |
| 443 | + let other_mask_vec = _mm256_loadu_si256(other_mask_ptr.add(i)); |
| 444 | + |
| 445 | + // 1. Get combined_mask = self.mask & other.mask; |
| 446 | + let combined_mask_vec = _mm256_and_si256(self_mask_vec, other_mask_vec); |
| 447 | + |
| 448 | + // 2. Get combined_code = (self.code ^ other.code) & combined_mask; |
| 449 | + let xor_code_vec = _mm256_xor_si256(self_code_vec, other_code_vec); |
| 450 | + let combined_code_vec = _mm256_and_si256(xor_code_vec, combined_mask_vec); |
| 451 | + |
| 452 | + // 3. Store the vector results to memory and use the fast scalar `count_ones` (`popcnt`). |
| 453 | + _mm256_storeu_si256(temp_storage_ptr, combined_mask_vec); |
| 454 | + total_mask_len += temp_storage[0].count_ones() |
| 455 | + + temp_storage[1].count_ones() |
| 456 | + + temp_storage[2].count_ones() |
| 457 | + + temp_storage[3].count_ones(); |
| 458 | + |
| 459 | + _mm256_storeu_si256(temp_storage_ptr, combined_code_vec); |
| 460 | + total_code_distance += temp_storage[0].count_ones() |
| 461 | + + temp_storage[1].count_ones() |
| 462 | + + temp_storage[2].count_ones() |
| 463 | + + temp_storage[3].count_ones(); |
| 464 | + } |
| 465 | + |
| 466 | + (total_code_distance as u16, total_mask_len as u16) |
| 467 | + } |
403 | 468 | /// Return the fractional Hamming distance between two iris codes, represented
|
404 | 469 | /// as the `i16` dot product of associated masked-bit vectors and the `u16` size
|
405 | 470 | /// of the common unmasked region.
|
|
0 commit comments