Skip to content

Commit a5c876f

Browse files
enzodimariaagnesLeroy
authored andcommitted
refactor(gpu): creating CudaScalarDivisorFFI for storing decomposed scalars and their metadata
1 parent 2d8ea2d commit a5c876f

File tree

8 files changed

+1174
-1668
lines changed

8 files changed

+1174
-1668
lines changed

backends/tfhe-cuda-backend/cuda/include/integer/integer.h

Lines changed: 54 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,34 @@ typedef struct {
4848
uint32_t lwe_dimension;
4949
} CudaRadixCiphertextFFI;
5050

51+
typedef struct {
52+
uint64_t const *chosen_multiplier_has_at_least_one_set;
53+
uint64_t const *decomposed_chosen_multiplier;
54+
55+
uint32_t const num_scalars;
56+
uint32_t const active_bits;
57+
uint64_t const shift_pre;
58+
uint32_t const shift_post;
59+
uint32_t const ilog2_chosen_multiplier;
60+
uint32_t const chosen_multiplier_num_bits;
61+
62+
bool const is_chosen_multiplier_zero;
63+
bool const is_abs_chosen_multiplier_one;
64+
bool const is_chosen_multiplier_negative;
65+
bool const is_chosen_multiplier_pow2;
66+
bool const chosen_multiplier_has_more_bits_than_numerator;
67+
// if signed: test if chosen_multiplier >= 2^{num_bits - 1}
68+
bool const is_chosen_multiplier_geq_two_pow_numerator;
69+
70+
uint32_t const ilog2_divisor;
71+
72+
bool const is_divisor_zero;
73+
bool const is_abs_divisor_one;
74+
bool const is_divisor_negative;
75+
bool const is_divisor_pow2;
76+
bool const divisor_has_more_bits_than_numerator;
77+
} CudaScalarDivisorFFI;
78+
5179
uint64_t scratch_cuda_apply_univariate_lut_kb_64(
5280
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
5381
int8_t **mem_ptr, void const *input_lut, uint32_t lwe_dimension,
@@ -600,19 +628,15 @@ uint64_t scratch_cuda_integer_unsigned_scalar_div_radix_kb_64(
600628
uint32_t lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
601629
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
602630
uint32_t num_blocks, uint32_t message_modulus, uint32_t carry_modulus,
603-
PBS_TYPE pbs_type, bool allocate_gpu_memory, bool is_divisor_power_of_two,
604-
bool log2_divisor_exceeds_threshold, bool multiplier_exceeds_threshold,
605-
uint32_t num_scalar_bits, uint32_t ilog2_divisor, bool allocate_ms_array);
631+
PBS_TYPE pbs_type, const CudaScalarDivisorFFI *scalar_divisor_ffi,
632+
bool allocate_gpu_memory, bool allocate_ms_array);
606633

607634
void cuda_integer_unsigned_scalar_div_radix_kb_64(
608635
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
609-
CudaRadixCiphertextFFI *numerator_ct, int8_t *mem_ptr, void *const *ksks,
610-
uint64_t const *decomposed_scalar, uint64_t const *has_at_least_one_set,
636+
CudaRadixCiphertextFFI *numerator_ct, int8_t *mem_ptr, void *const *bsks,
637+
void *const *ksks,
611638
const CudaModulusSwitchNoiseReductionKeyFFI *ms_noise_reduction_key,
612-
void *const *bsks, uint32_t num_scalars, bool multiplier_exceeds_threshold,
613-
bool is_divisor_power_of_two, bool log2_divisor_exceeds_threshold,
614-
uint32_t ilog2_divisor, uint64_t shift_pre, uint32_t shift_post,
615-
uint64_t rhs);
639+
const CudaScalarDivisorFFI *scalar_divisor_ffi);
616640

617641
void cleanup_cuda_integer_unsigned_scalar_div_radix_kb_64(
618642
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
@@ -644,23 +668,16 @@ uint64_t scratch_cuda_integer_signed_scalar_div_radix_kb_64(
644668
int8_t **mem_ptr, uint32_t glwe_dimension, uint32_t polynomial_size,
645669
uint32_t lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
646670
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
647-
uint32_t num_blocks, uint32_t num_scalar_bits, uint32_t message_modulus,
648-
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
649-
bool is_absolute_divisor_one, bool is_divisor_negative,
650-
bool l_exceed_threshold, bool is_power_of_two, bool multiplier_is_small,
651-
bool allocate_ms_array);
671+
uint32_t num_blocks, uint32_t message_modulus, uint32_t carry_modulus,
672+
PBS_TYPE pbs_type, const CudaScalarDivisorFFI *scalar_divisor_ffi,
673+
bool allocate_gpu_memory, bool allocate_ms_array);
652674

653675
void cuda_integer_signed_scalar_div_radix_kb_64(
654676
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
655-
CudaRadixCiphertextFFI *numerator_ct, int8_t *mem_ptr, void *const *ksks,
656-
void *const *bsks,
677+
CudaRadixCiphertextFFI *numerator_ct, int8_t *mem_ptr, void *const *bsks,
678+
void *const *ksks,
657679
const CudaModulusSwitchNoiseReductionKeyFFI *ms_noise_reduction_key,
658-
bool is_absolute_divisor_one, bool is_divisor_negative,
659-
bool l_exceed_threshold, bool is_power_of_two, bool multiplier_is_small,
660-
uint32_t l, uint32_t shift_post, bool is_rhs_power_of_two, bool is_rhs_zero,
661-
bool is_rhs_one, uint32_t rhs_shift, uint32_t numerator_bits,
662-
uint32_t num_scalars, uint64_t const *decomposed_scalar,
663-
uint64_t const *has_at_least_one_set);
680+
const CudaScalarDivisorFFI *scalar_divisor_ffi, uint32_t numerator_bits);
664681

665682
void cleanup_cuda_integer_signed_scalar_div_radix_kb_64(
666683
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
@@ -672,24 +689,18 @@ uint64_t scratch_integer_unsigned_scalar_div_rem_radix_kb_64(
672689
uint32_t lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
673690
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
674691
uint32_t num_blocks, uint32_t message_modulus, uint32_t carry_modulus,
675-
PBS_TYPE pbs_type, bool allocate_gpu_memory, bool is_divisor_power_of_two,
676-
bool log2_divisor_exceeds_threshold, bool multiplier_exceeds_threshold,
677-
uint32_t num_scalar_bits_for_div, uint32_t num_scalar_bits_for_mul,
678-
uint32_t ilog2_divisor, uint64_t divisor, bool allocate_ms_array);
692+
PBS_TYPE pbs_type, const CudaScalarDivisorFFI *scalar_divisor_ffi,
693+
uint32_t const active_bits_divisor, bool allocate_gpu_memory,
694+
bool allocate_ms_array);
679695

680696
void cuda_integer_unsigned_scalar_div_rem_radix_kb_64(
681697
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
682698
CudaRadixCiphertextFFI *quotient_ct, CudaRadixCiphertextFFI *remainder_ct,
683-
int8_t *mem_ptr, void *const *ksks, void *const *bsks,
684-
uint64_t const *decomposed_scalar_for_div,
685-
uint64_t const *decomposed_scalar_for_mul,
686-
uint64_t const *has_at_least_one_set_for_div,
687-
uint64_t const *has_at_least_one_set_for_mul,
699+
int8_t *mem_ptr, void *const *bsks, void *const *ksks,
688700
const CudaModulusSwitchNoiseReductionKeyFFI *ms_noise_reduction_key,
689-
uint32_t num_scalars_for_div, uint32_t num_scalars_for_mul,
690-
bool multiplier_exceeds_threshold, bool is_divisor_power_of_two,
691-
bool log2_divisor_exceeds_threshold, uint32_t ilog2_divisor,
692-
uint64_t divisor, uint64_t shift_pre, uint32_t shift_post, uint64_t rhs,
701+
const CudaScalarDivisorFFI *scalar_divisor_ffi,
702+
uint64_t const *divisor_has_at_least_one_set,
703+
uint64_t const *decomposed_divisor, uint32_t const num_scalars_divisor,
693704
void const *clear_blocks, void const *h_clear_blocks,
694705
uint32_t num_clear_blocks);
695706

@@ -703,27 +714,19 @@ uint64_t scratch_integer_signed_scalar_div_rem_radix_kb_64(
703714
uint32_t lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
704715
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
705716
uint32_t num_blocks, uint32_t message_modulus, uint32_t carry_modulus,
706-
PBS_TYPE pbs_type, bool allocate_gpu_memory,
707-
uint32_t num_scalar_bits_for_div, uint32_t num_scalar_bits_for_mul,
708-
bool is_absolute_divisor_one, bool is_divisor_negative,
709-
bool l_exceed_threshold, bool is_absolute_divisor_power_of_two,
710-
bool is_divisor_zero, bool multiplier_is_small, bool allocate_ms_array);
717+
PBS_TYPE pbs_type, const CudaScalarDivisorFFI *scalar_divisor_ffi,
718+
uint32_t const active_bits_divisor, bool allocate_gpu_memory,
719+
bool allocate_ms_array);
711720

712721
void cuda_integer_signed_scalar_div_rem_radix_kb_64(
713722
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
714723
CudaRadixCiphertextFFI *quotient_ct, CudaRadixCiphertextFFI *remainder_ct,
715-
int8_t *mem_ptr, void *const *ksks, void *const *bsks,
724+
int8_t *mem_ptr, void *const *bsks, void *const *ksks,
716725
CudaModulusSwitchNoiseReductionKeyFFI const *ms_noise_reduction_key,
717-
bool is_absolute_divisor_one, bool is_divisor_negative,
718-
bool is_divisor_zero, bool l_exceed_threshold,
719-
bool is_absolute_divisor_power_of_two, bool multiplier_is_small, uint32_t l,
720-
uint32_t shift_post, bool is_rhs_power_of_two, bool is_rhs_zero,
721-
bool is_rhs_one, uint32_t rhs_shift, uint32_t divisor_shift,
722-
uint32_t numerator_bits, uint32_t num_scalars_for_div,
723-
uint32_t num_scalars_for_mul, uint64_t const *decomposed_scalar_for_div,
724-
uint64_t const *decomposed_scalar_for_mul,
725-
uint64_t const *has_at_least_one_set_for_div,
726-
uint64_t const *has_at_least_one_set_for_mul);
726+
const CudaScalarDivisorFFI *scalar_divisor_ffi,
727+
uint64_t const *divisor_has_at_least_one_set,
728+
uint64_t const *decomposed_divisor, uint32_t const num_scalars_divisor,
729+
uint32_t numerator_bits);
727730

728731
void cleanup_cuda_integer_signed_scalar_div_rem_radix_kb_64(
729732
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,

0 commit comments

Comments
 (0)