@@ -3,10 +3,17 @@ use crate::{
3
3
hawkers:: shared_irises:: { SharedIrises , SharedIrisesRef } ,
4
4
hnsw:: { vector_store:: VectorStoreMut , VectorStore } ,
5
5
protocol:: {
6
- ops:: { batch_signed_lift_vec, cross_compare, galois_ring_to_rep3, lte_threshold_and_open} ,
6
+ ops:: {
7
+ batch_signed_lift_vec, conditionally_swap_distances,
8
+ conditionally_swap_distances_plain_ids, cross_compare, galois_ring_to_rep3,
9
+ lte_threshold_and_open, oblivious_cross_compare,
10
+ } ,
7
11
shared_iris:: { ArcIris , GaloisRingSharedIris } ,
8
12
} ,
9
- shares:: share:: { DistanceShare , Share } ,
13
+ shares:: {
14
+ bit:: Bit ,
15
+ share:: { DistanceShare , Share } ,
16
+ } ,
10
17
} ;
11
18
use eyre:: Result ;
12
19
use iris_mpc_common:: vector_id:: VectorId ;
@@ -55,6 +62,7 @@ impl Aby3Query {
55
62
}
56
63
57
64
pub type Aby3VectorRef = <Aby3Store as VectorStore >:: VectorRef ;
65
+ pub type Aby3DistanceRef = <Aby3Store as VectorStore >:: DistanceRef ;
58
66
59
67
pub type Aby3SharedIrises = SharedIrises < ArcIris > ;
60
68
pub type Aby3SharedIrisesRef = SharedIrisesRef < ArcIris > ;
@@ -122,6 +130,48 @@ impl Aby3Store {
122
130
pub async fn checksum ( & self ) -> u64 {
123
131
self . storage . checksum ( ) . await
124
132
}
133
+
134
+ /// Obliviously swaps the elements in `list` at the given `indices` according to the `swap_bits`.
135
+ /// If bit is 0, the elements are swapped, otherwise they are left unchanged.
136
+ /// Note that unchanged elements of the list are propagated as secret-shares.
137
+ pub async fn oblivious_swap_batch_plain_ids (
138
+ & mut self ,
139
+ swap_bits : Vec < Share < Bit > > ,
140
+ list : & [ ( u32 , Aby3DistanceRef ) ] ,
141
+ indices : & [ ( usize , usize ) ] ,
142
+ ) -> Result < Vec < ( Share < u32 > , Aby3DistanceRef ) > > {
143
+ if list. is_empty ( ) {
144
+ return Ok ( vec ! [ ] ) ;
145
+ }
146
+
147
+ conditionally_swap_distances_plain_ids ( & mut self . session , swap_bits, list, indices) . await
148
+ }
149
+
150
+ /// Obliviously compares pairs of distances in batch and returns a secret shared bit a < b for each pair.
151
+ pub async fn oblivious_less_than_batch (
152
+ & mut self ,
153
+ distances : & [ ( Aby3DistanceRef , Aby3DistanceRef ) ] ,
154
+ ) -> Result < Vec < Share < Bit > > > {
155
+ if distances. is_empty ( ) {
156
+ return Ok ( vec ! [ ] ) ;
157
+ }
158
+ oblivious_cross_compare ( & mut self . session , distances) . await
159
+ }
160
+
161
+ /// Obliviously swaps the elements in `list` at the given `indices` according to the `swap_bits`.
162
+ /// If bit is 0, the elements are swapped, otherwise they are left unchanged.
163
+ pub async fn oblivious_swap_batch (
164
+ & mut self ,
165
+ swap_bits : Vec < Share < Bit > > ,
166
+ list : & [ ( Share < u32 > , Aby3DistanceRef ) ] ,
167
+ indices : & [ ( usize , usize ) ] ,
168
+ ) -> Result < Vec < ( Share < u32 > , Aby3DistanceRef ) > > {
169
+ if list. is_empty ( ) {
170
+ return Ok ( vec ! [ ] ) ;
171
+ }
172
+
173
+ conditionally_swap_distances ( & mut self . session , swap_bits, list, indices) . await
174
+ }
125
175
}
126
176
127
177
impl VectorStore for Aby3Store {
@@ -251,7 +301,7 @@ mod tests {
251
301
252
302
use super :: * ;
253
303
use crate :: {
254
- execution:: hawk_main:: scheduler:: parallelize,
304
+ execution:: { hawk_main:: scheduler:: parallelize, session :: SessionHandles } ,
255
305
hawkers:: {
256
306
aby3:: test_utils:: {
257
307
eval_vector_distance, get_owner_index, lazy_random_setup,
@@ -527,6 +577,91 @@ mod tests {
527
577
Ok ( ( ) )
528
578
}
529
579
580
+ #[ tokio:: test( flavor = "multi_thread" ) ]
581
+ #[ traced_test]
582
+ async fn test_oblivious_swap ( ) -> Result < ( ) > {
583
+ let list_len = 6_u32 ;
584
+ let plain_list = ( 0 ..list_len)
585
+ . map ( |i| ( VectorId :: from_0_index ( i) , ( i, i) ) )
586
+ . collect_vec ( ) ;
587
+ let swap_bits_for_plain = vec ! [ true , false ] ;
588
+ let indices_for_plain = vec ! [ ( 0 , 1 ) , ( 4 , 5 ) ] ;
589
+ let swap_bits_for_secret = vec ! [ true , false , false ] ;
590
+ let indices_for_secret = vec ! [ ( 1 , 2 ) , ( 0 , 4 ) , ( 3 , 5 ) ] ;
591
+
592
+ let mut local_stores = setup_local_store_aby3_players ( NetworkType :: Local ) . await ?;
593
+ let mut jobs = JoinSet :: new ( ) ;
594
+ for store in local_stores. iter_mut ( ) {
595
+ let store = store. clone ( ) ;
596
+ let swap_bits_for_plain = swap_bits_for_plain. clone ( ) ;
597
+ let swap_bits_for_secret = swap_bits_for_secret. clone ( ) ;
598
+ let plain_list = plain_list. clone ( ) ;
599
+ let indices_for_plain = indices_for_plain. clone ( ) ;
600
+ let indices_for_secret = indices_for_secret. clone ( ) ;
601
+ jobs. spawn ( async move {
602
+ let mut store_lock = store. lock ( ) . await ;
603
+ let role = store_lock. session . own_role ( ) ;
604
+ let swap_bits1 = swap_bits_for_plain
605
+ . iter ( )
606
+ . map ( |b| Share :: from_const ( Bit :: new ( * b) , role) )
607
+ . collect_vec ( ) ;
608
+ let swap_bits2 = swap_bits_for_secret
609
+ . iter ( )
610
+ . map ( |b| Share :: from_const ( Bit :: new ( * b) , role) )
611
+ . collect_vec ( ) ;
612
+ let list = plain_list
613
+ . iter ( )
614
+ . map ( |( v, d) | {
615
+ (
616
+ v. index ( ) ,
617
+ DistanceShare :: new (
618
+ Share :: from_const ( d. 0 , role) ,
619
+ Share :: from_const ( d. 1 , role) ,
620
+ ) ,
621
+ )
622
+ } )
623
+ . collect_vec ( ) ;
624
+ let tmp_list = store_lock
625
+ . oblivious_swap_batch_plain_ids ( swap_bits1, & list, & indices_for_plain)
626
+ . await ?;
627
+ store_lock
628
+ . oblivious_swap_batch ( swap_bits2, & tmp_list, & indices_for_secret)
629
+ . await
630
+ } ) ;
631
+ }
632
+ let res = jobs
633
+ . join_all ( )
634
+ . await
635
+ . into_iter ( )
636
+ . collect :: < Result < Vec < _ > > > ( ) ?;
637
+ let mut expected_list = plain_list. clone ( ) ;
638
+ expected_list. swap ( 4 , 5 ) ;
639
+ expected_list. swap ( 0 , 4 ) ;
640
+ expected_list. swap ( 3 , 5 ) ;
641
+
642
+ for ( i, exp) in expected_list. iter ( ) . enumerate ( ) {
643
+ let id = ( res[ 0 ] [ i] . clone ( ) . 0 + & res[ 1 ] [ i] . 0 + & res[ 2 ] [ i] . 0 )
644
+ . get_a ( )
645
+ . convert ( ) ;
646
+ assert_eq ! ( id, exp. 0 . index( ) ) ;
647
+
648
+ let distance = {
649
+ let code_dot =
650
+ ( res[ 0 ] [ i] . clone ( ) . 1 . code_dot + & res[ 1 ] [ i] . 1 . code_dot + & res[ 2 ] [ i] . 1 . code_dot )
651
+ . get_a ( )
652
+ . convert ( ) ;
653
+ let mask_dot =
654
+ ( res[ 0 ] [ i] . clone ( ) . 1 . mask_dot + & res[ 1 ] [ i] . 1 . mask_dot + & res[ 2 ] [ i] . 1 . mask_dot )
655
+ . get_a ( )
656
+ . convert ( ) ;
657
+ ( code_dot, mask_dot)
658
+ } ;
659
+ assert_eq ! ( distance, exp. 1 ) ;
660
+ }
661
+
662
+ Ok ( ( ) )
663
+ }
664
+
530
665
#[ tokio:: test( flavor = "multi_thread" ) ]
531
666
#[ traced_test]
532
667
async fn test_gr_aby3_store_plaintext_batch ( ) -> Result < ( ) > {
0 commit comments