@@ -48,6 +48,34 @@ typedef struct {
4848 uint32_t lwe_dimension ;
4949} CudaRadixCiphertextFFI ;
5050
// Struct added by this diff. Every member is declared const, so an instance is
// fully initialised at construction and only read afterwards. The scalar
// division declarations changed further down in this diff take a
// `const CudaScalarDivisorFFI *` where they previously listed individual
// flags, shift amounts and decomposed-scalar pointers.
// NOTE(review): the "chosen multiplier" / shift_pre / shift_post vocabulary
// looks like division by an invariant integer via multiplication - confirm
// against the implementation before relying on that reading.
51+ typedef struct {
// Two pointers to read-only uint64_t data about the chosen multiplier.
// NOTE(review): presumably these replace the former `has_at_least_one_set`
// and `decomposed_scalar` parameters; array lengths and host/device residency
// are not visible here - confirm.
52+ uint64_t const * chosen_multiplier_has_at_least_one_set ;
53+ uint64_t const * decomposed_chosen_multiplier ;
54+
// Scalar counts, bit widths and shift amounts. Note that shift_pre is 64-bit
// while shift_post is 32-bit, matching the types of the removed `shift_pre` /
// `shift_post` parameters in the unsigned division signatures.
// NOTE(review): presumably num_scalars is the element count of
// decomposed_chosen_multiplier - confirm.
55+ uint32_t const num_scalars ;
56+ uint32_t const active_bits ;
57+ uint64_t const shift_pre ;
58+ uint32_t const shift_post ;
59+ uint32_t const ilog2_chosen_multiplier ;
60+ uint32_t const chosen_multiplier_num_bits ;
61+
// Precomputed boolean properties of the chosen multiplier.
62+ bool const is_chosen_multiplier_zero ;
63+ bool const is_abs_chosen_multiplier_one ;
64+ bool const is_chosen_multiplier_negative ;
65+ bool const is_chosen_multiplier_pow2 ;
66+ bool const chosen_multiplier_has_more_bits_than_numerator ;
67+ // if signed: test if chosen_multiplier >= 2^{num_bits - 1}
68+ bool const is_chosen_multiplier_geq_two_pow_numerator ;
69+
// The same kind of precomputed properties, this time for the divisor itself.
// NOTE(review): ilog2_divisor is presumably only meaningful when
// is_divisor_pow2 is set - confirm.
70+ uint32_t const ilog2_divisor ;
71+
72+ bool const is_divisor_zero ;
73+ bool const is_abs_divisor_one ;
74+ bool const is_divisor_negative ;
75+ bool const is_divisor_pow2 ;
76+ bool const divisor_has_more_bits_than_numerator ;
77+ } CudaScalarDivisorFFI ;
78+
5179uint64_t scratch_cuda_apply_univariate_lut_kb_64 (
5280 void * const * streams , uint32_t const * gpu_indexes , uint32_t gpu_count ,
5381 int8_t * * mem_ptr , void const * input_lut , uint32_t lwe_dimension ,
@@ -600,19 +628,15 @@ uint64_t scratch_cuda_integer_unsigned_scalar_div_radix_kb_64(
600628 uint32_t lwe_dimension , uint32_t ks_level , uint32_t ks_base_log ,
601629 uint32_t pbs_level , uint32_t pbs_base_log , uint32_t grouping_factor ,
602630 uint32_t num_blocks , uint32_t message_modulus , uint32_t carry_modulus ,
603- PBS_TYPE pbs_type , bool allocate_gpu_memory , bool is_divisor_power_of_two ,
604- bool log2_divisor_exceeds_threshold , bool multiplier_exceeds_threshold ,
605- uint32_t num_scalar_bits , uint32_t ilog2_divisor , bool allocate_ms_array );
631+ PBS_TYPE pbs_type , const CudaScalarDivisorFFI * scalar_divisor_ffi ,
632+ bool allocate_gpu_memory , bool allocate_ms_array );
606633
// Declaration of cuda_integer_unsigned_scalar_div_radix_kb_64 as rewritten by
// this diff ('-' lines: removed signature, '+' lines: new signature).
// New parameter list: streams, gpu_indexes, gpu_count, numerator_ct, mem_ptr,
// bsks, ksks, ms_noise_reduction_key, scalar_divisor_ffi.
// The removed decomposed_scalar / has_at_least_one_set pointers, num_scalars,
// the three threshold/power-of-two flags, ilog2_divisor, shift_pre, shift_post
// and rhs no longer appear; a single `const CudaScalarDivisorFFI *` is passed.
// NOTE(review): the key arguments changed order (old: ksks ... bsks, new:
// bsks, ksks). Both are `void * const *`, so a caller that still passes them in
// the old order compiles without any diagnostic - check every call site.
607634void cuda_integer_unsigned_scalar_div_radix_kb_64 (
608635 void * const * streams , uint32_t const * gpu_indexes , uint32_t gpu_count ,
609- CudaRadixCiphertextFFI * numerator_ct , int8_t * mem_ptr , void * const * ksks ,
610- uint64_t const * decomposed_scalar , uint64_t const * has_at_least_one_set ,
636+ CudaRadixCiphertextFFI * numerator_ct , int8_t * mem_ptr , void * const * bsks ,
637+ void * const * ksks ,
611638 const CudaModulusSwitchNoiseReductionKeyFFI * ms_noise_reduction_key ,
612- void * const * bsks , uint32_t num_scalars , bool multiplier_exceeds_threshold ,
613- bool is_divisor_power_of_two , bool log2_divisor_exceeds_threshold ,
614- uint32_t ilog2_divisor , uint64_t shift_pre , uint32_t shift_post ,
615- uint64_t rhs );
639+ const CudaScalarDivisorFFI * scalar_divisor_ffi );
616640
617641void cleanup_cuda_integer_unsigned_scalar_div_radix_kb_64 (
618642 void * const * streams , uint32_t const * gpu_indexes , uint32_t gpu_count ,
@@ -644,23 +668,16 @@ uint64_t scratch_cuda_integer_signed_scalar_div_radix_kb_64(
644668 int8_t * * mem_ptr , uint32_t glwe_dimension , uint32_t polynomial_size ,
645669 uint32_t lwe_dimension , uint32_t ks_level , uint32_t ks_base_log ,
646670 uint32_t pbs_level , uint32_t pbs_base_log , uint32_t grouping_factor ,
647- uint32_t num_blocks , uint32_t num_scalar_bits , uint32_t message_modulus ,
648- uint32_t carry_modulus , PBS_TYPE pbs_type , bool allocate_gpu_memory ,
649- bool is_absolute_divisor_one , bool is_divisor_negative ,
650- bool l_exceed_threshold , bool is_power_of_two , bool multiplier_is_small ,
651- bool allocate_ms_array );
671+ uint32_t num_blocks , uint32_t message_modulus , uint32_t carry_modulus ,
672+ PBS_TYPE pbs_type , const CudaScalarDivisorFFI * scalar_divisor_ffi ,
673+ bool allocate_gpu_memory , bool allocate_ms_array );
652674
// Declaration of cuda_integer_signed_scalar_div_radix_kb_64 as rewritten by
// this diff ('-' lines: removed signature, '+' lines: new signature).
// New parameter list: streams, gpu_indexes, gpu_count, numerator_ct, mem_ptr,
// bsks, ksks, ms_noise_reduction_key, scalar_divisor_ffi, numerator_bits.
// The removed boolean flags, l, shift_post, rhs_shift, num_scalars and the two
// uint64_t pointers no longer appear; numerator_bits is the only scalar
// argument still passed next to the `const CudaScalarDivisorFFI *`.
// NOTE(review): ksks and bsks swapped positions (old: ksks, bsks; new: bsks,
// ksks). Both are `void * const *`, so stale call sites still compile - check
// every caller.
653675void cuda_integer_signed_scalar_div_radix_kb_64 (
654676 void * const * streams , uint32_t const * gpu_indexes , uint32_t gpu_count ,
655- CudaRadixCiphertextFFI * numerator_ct , int8_t * mem_ptr , void * const * ksks ,
656- void * const * bsks ,
677+ CudaRadixCiphertextFFI * numerator_ct , int8_t * mem_ptr , void * const * bsks ,
678+ void * const * ksks ,
657679 const CudaModulusSwitchNoiseReductionKeyFFI * ms_noise_reduction_key ,
658- bool is_absolute_divisor_one , bool is_divisor_negative ,
659- bool l_exceed_threshold , bool is_power_of_two , bool multiplier_is_small ,
660- uint32_t l , uint32_t shift_post , bool is_rhs_power_of_two , bool is_rhs_zero ,
661- bool is_rhs_one , uint32_t rhs_shift , uint32_t numerator_bits ,
662- uint32_t num_scalars , uint64_t const * decomposed_scalar ,
663- uint64_t const * has_at_least_one_set );
680+ const CudaScalarDivisorFFI * scalar_divisor_ffi , uint32_t numerator_bits );
664681
665682void cleanup_cuda_integer_signed_scalar_div_radix_kb_64 (
666683 void * const * streams , uint32_t const * gpu_indexes , uint32_t gpu_count ,
@@ -672,24 +689,18 @@ uint64_t scratch_integer_unsigned_scalar_div_rem_radix_kb_64(
672689 uint32_t lwe_dimension , uint32_t ks_level , uint32_t ks_base_log ,
673690 uint32_t pbs_level , uint32_t pbs_base_log , uint32_t grouping_factor ,
674691 uint32_t num_blocks , uint32_t message_modulus , uint32_t carry_modulus ,
675- PBS_TYPE pbs_type , bool allocate_gpu_memory , bool is_divisor_power_of_two ,
676- bool log2_divisor_exceeds_threshold , bool multiplier_exceeds_threshold ,
677- uint32_t num_scalar_bits_for_div , uint32_t num_scalar_bits_for_mul ,
678- uint32_t ilog2_divisor , uint64_t divisor , bool allocate_ms_array );
692+ PBS_TYPE pbs_type , const CudaScalarDivisorFFI * scalar_divisor_ffi ,
693+ uint32_t const active_bits_divisor , bool allocate_gpu_memory ,
694+ bool allocate_ms_array );
679695
// Declaration of cuda_integer_unsigned_scalar_div_rem_radix_kb_64 as rewritten
// by this diff ('-' lines: removed signature, '+' lines: new signature).
// New parameter list: streams, gpu_indexes, gpu_count, quotient_ct,
// remainder_ct, mem_ptr, bsks, ksks, ms_noise_reduction_key,
// scalar_divisor_ffi, divisor_has_at_least_one_set, decomposed_divisor,
// num_scalars_divisor, clear_blocks, h_clear_blocks, num_clear_blocks.
// The removed *_for_div / *_for_mul pointer and count pairs, the flags,
// ilog2_divisor, divisor, shift_pre, shift_post and rhs no longer appear.
// NOTE(review): presumably the former *_for_div data now travels inside
// CudaScalarDivisorFFI and the former *_for_mul data became the three
// divisor_* / decomposed_divisor arguments - confirm against the
// implementation.
// NOTE(review): ksks and bsks swapped positions (old: ksks, bsks; new: bsks,
// ksks). Both are `void * const *`, so stale call sites still compile - check
// every caller.
680696void cuda_integer_unsigned_scalar_div_rem_radix_kb_64 (
681697 void * const * streams , uint32_t const * gpu_indexes , uint32_t gpu_count ,
682698 CudaRadixCiphertextFFI * quotient_ct , CudaRadixCiphertextFFI * remainder_ct ,
683- int8_t * mem_ptr , void * const * ksks , void * const * bsks ,
684- uint64_t const * decomposed_scalar_for_div ,
685- uint64_t const * decomposed_scalar_for_mul ,
686- uint64_t const * has_at_least_one_set_for_div ,
687- uint64_t const * has_at_least_one_set_for_mul ,
699+ int8_t * mem_ptr , void * const * bsks , void * const * ksks ,
688700 const CudaModulusSwitchNoiseReductionKeyFFI * ms_noise_reduction_key ,
689- uint32_t num_scalars_for_div , uint32_t num_scalars_for_mul ,
690- bool multiplier_exceeds_threshold , bool is_divisor_power_of_two ,
691- bool log2_divisor_exceeds_threshold , uint32_t ilog2_divisor ,
692- uint64_t divisor , uint64_t shift_pre , uint32_t shift_post , uint64_t rhs ,
701+ const CudaScalarDivisorFFI * scalar_divisor_ffi ,
702+ uint64_t const * divisor_has_at_least_one_set ,
703+ uint64_t const * decomposed_divisor , uint32_t const num_scalars_divisor ,
693704 void const * clear_blocks , void const * h_clear_blocks ,
694705 uint32_t num_clear_blocks );
695706
@@ -703,27 +714,19 @@ uint64_t scratch_integer_signed_scalar_div_rem_radix_kb_64(
703714 uint32_t lwe_dimension , uint32_t ks_level , uint32_t ks_base_log ,
704715 uint32_t pbs_level , uint32_t pbs_base_log , uint32_t grouping_factor ,
705716 uint32_t num_blocks , uint32_t message_modulus , uint32_t carry_modulus ,
706- PBS_TYPE pbs_type , bool allocate_gpu_memory ,
707- uint32_t num_scalar_bits_for_div , uint32_t num_scalar_bits_for_mul ,
708- bool is_absolute_divisor_one , bool is_divisor_negative ,
709- bool l_exceed_threshold , bool is_absolute_divisor_power_of_two ,
710- bool is_divisor_zero , bool multiplier_is_small , bool allocate_ms_array );
717+ PBS_TYPE pbs_type , const CudaScalarDivisorFFI * scalar_divisor_ffi ,
718+ uint32_t const active_bits_divisor , bool allocate_gpu_memory ,
719+ bool allocate_ms_array );
711720
// Declaration of cuda_integer_signed_scalar_div_rem_radix_kb_64 as rewritten
// by this diff ('-' lines: removed signature, '+' lines: new signature).
// New parameter list: streams, gpu_indexes, gpu_count, quotient_ct,
// remainder_ct, mem_ptr, bsks, ksks, ms_noise_reduction_key,
// scalar_divisor_ffi, divisor_has_at_least_one_set, decomposed_divisor,
// num_scalars_divisor, numerator_bits.
// The removed boolean flags, l, shift_post, rhs_shift, divisor_shift and the
// *_for_div / *_for_mul pointer and count pairs no longer appear;
// numerator_bits is kept as a separate argument.
// NOTE(review): presumably the former *_for_div data now travels inside
// CudaScalarDivisorFFI and the former *_for_mul data became the three
// divisor_* / decomposed_divisor arguments - confirm against the
// implementation.
// NOTE(review): ksks and bsks swapped positions (old: ksks, bsks; new: bsks,
// ksks). Both are `void * const *`, so stale call sites still compile - check
// every caller.
712721void cuda_integer_signed_scalar_div_rem_radix_kb_64 (
713722 void * const * streams , uint32_t const * gpu_indexes , uint32_t gpu_count ,
714723 CudaRadixCiphertextFFI * quotient_ct , CudaRadixCiphertextFFI * remainder_ct ,
715- int8_t * mem_ptr , void * const * ksks , void * const * bsks ,
724+ int8_t * mem_ptr , void * const * bsks , void * const * ksks ,
716725 CudaModulusSwitchNoiseReductionKeyFFI const * ms_noise_reduction_key ,
717- bool is_absolute_divisor_one , bool is_divisor_negative ,
718- bool is_divisor_zero , bool l_exceed_threshold ,
719- bool is_absolute_divisor_power_of_two , bool multiplier_is_small , uint32_t l ,
720- uint32_t shift_post , bool is_rhs_power_of_two , bool is_rhs_zero ,
721- bool is_rhs_one , uint32_t rhs_shift , uint32_t divisor_shift ,
722- uint32_t numerator_bits , uint32_t num_scalars_for_div ,
723- uint32_t num_scalars_for_mul , uint64_t const * decomposed_scalar_for_div ,
724- uint64_t const * decomposed_scalar_for_mul ,
725- uint64_t const * has_at_least_one_set_for_div ,
726- uint64_t const * has_at_least_one_set_for_mul );
726+ const CudaScalarDivisorFFI * scalar_divisor_ffi ,
727+ uint64_t const * divisor_has_at_least_one_set ,
728+ uint64_t const * decomposed_divisor , uint32_t const num_scalars_divisor ,
729+ uint32_t numerator_bits );
727730
728731void cleanup_cuda_integer_signed_scalar_div_rem_radix_kb_64 (
729732 void * const * streams , uint32_t const * gpu_indexes , uint32_t gpu_count ,
0 commit comments