From 58d3615ed54e6ea0efa97102bd50c87a9aae48b6 Mon Sep 17 00:00:00 2001 From: christopherngutierrez Date: Wed, 19 Feb 2025 10:37:30 -0800 Subject: [PATCH 1/5] refactored context parameters to use krns and current_rns --- kerngen/high_parser/types.py | 14 +++++++++----- kerngen/tests/test_kerngen.py | 18 +++++++++--------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/kerngen/high_parser/types.py b/kerngen/high_parser/types.py index ade46c5..a6aab48 100644 --- a/kerngen/high_parser/types.py +++ b/kerngen/high_parser/types.py @@ -158,7 +158,7 @@ class Context(BaseModel): @classmethod def from_string(cls, line: str): """Construct context from a string""" - scheme, poly_order, max_rns, *optionals = line.split() + scheme, poly_order, key_rns, current_rns, *optionals = line.split() optional_dict = OptionsDictParser.parse(optionals) int_poly_order = int(poly_order) if ( @@ -170,15 +170,19 @@ def from_string(cls, line: str): f"Poly order `{int_poly_order}` must be power of two >= {MIN_POLY_SIZE} and < {MAX_POLY_SIZE}" ) - int_max_rns = int(max_rns) - int_key_rns = int_max_rns - int_key_rns += optional_dict.pop("krns_delta") + int_key_rns = int(key_rns) + int_current_rns = int(current_rns) + + if int_key_rns <= int_current_rns: + raise ValueError( + f"Current RNS must be less than Key RNS: current_rns={current_rns}, key_rns={key_rns}" + ) return cls( scheme=scheme.upper(), poly_order=int_poly_order, - max_rns=int_max_rns, key_rns=int_key_rns, + max_rns=int_current_rns, **optional_dict, ) diff --git a/kerngen/tests/test_kerngen.py b/kerngen/tests/test_kerngen.py index 19dc69c..9aca9b0 100644 --- a/kerngen/tests/test_kerngen.py +++ b/kerngen/tests/test_kerngen.py @@ -47,7 +47,7 @@ def test_op(kerngen_path, gen_op_data): def test_missing_context(kerngen_path): """Test kerngen raises an exception when context is not the first line of input""" - input_string = "ADD a b c\nCONTEXT BGV 16384 4\n" + input_string = "ADD a b c\nCONTEXT BGV 16384 4 3\n" result = execute_process( [kerngen_path], data_in=input_string, @@ -59,7 +59,7 @@ def test_missing_context(kerngen_path): def test_multiple_contexts(kerngen_path): """Test kerngen raises an exception when more than one context is given""" - input_string = "CONTEXT BGV 16384 4\nData a 2\nCONTEXT BGV 16384 4\n" + input_string = "CONTEXT BGV 16384 4 2\nData a 2\nCONTEXT BGV 16384 4 2\n" result = execute_process( [kerngen_path], data_in=input_string, @@ -71,7 +71,7 @@ def test_multiple_contexts(kerngen_path): def test_context_options_without_key(kerngen_path): """Test kerngen raises an exception when more than one context is given""" - input_string = "CONTEXT BGV 16384 4 1\nData a 2\n" + input_string = "CONTEXT BGV 16384 3 2 1\nData a 2\n" result = execute_process( [kerngen_path], data_in=input_string, @@ -86,7 +86,7 @@ def test_context_options_without_key(kerngen_path): def test_context_unsupported_options_variable(kerngen_path): """Test kerngen raises an exception when more than one context is given""" - input_string = "CONTEXT BGV 16384 4 test=3\nData a 2\n" + input_string = "CONTEXT BGV 16384 3 2 test=3\nData a 2\n" result = execute_process( [kerngen_path], data_in=input_string, @@ -99,7 +99,7 @@ def test_context_unsupported_options_variable(kerngen_path): @pytest.mark.parametrize("invalid", [-1, 256, 0.1, "str"]) def test_context_option_invalid_values(kerngen_path, invalid): """Test kerngen raises an exception if value is out of range for correct key""" - input_string = f"CONTEXT BGV 16384 4 krns_delta={invalid}\nData a 2\n" + input_string = f"CONTEXT BGV 16384 3 2 krns_delta={invalid}\nData a 2\n" result = execute_process( [kerngen_path], data_in=input_string, @@ -115,7 +115,7 @@ def test_context_option_invalid_values(kerngen_path, invalid): def test_unrecognised_opname(kerngen_path): """Test kerngen raises an exception when receiving an unrecognised opname""" - input_string = "CONTEXT BGV 16384 4\nOPERATION a b c\n" + input_string = "CONTEXT BGV 16384 3 2\nOPERATION a b c\n" result = execute_process( [kerngen_path], data_in=input_string, @@ -129,7 +129,7 @@ def test_unrecognised_opname(kerngen_path): def test_invalid_scheme(kerngen_path): """Test kerngen raises an exception when receiving an invalid scheme""" - input_string = "CONTEXT SCHEME 16384 4\nADD a b c\n" + input_string = "CONTEXT SCHEME 16384 4 3\nADD a b c\n" result = execute_process( [kerngen_path], data_in=input_string, @@ -142,7 +142,7 @@ def test_invalid_scheme(kerngen_path): @pytest.mark.parametrize("invalid_poly", [16000, 2**12, 2**13, 2**18]) def test_invalid_poly_order(kerngen_path, invalid_poly): """Poly order should be powers of two >= 2^14 and <= 2^17""" - input_string = "CONTEXT BGV " + str(invalid_poly) + " 4\nADD a b c\n" + input_string = "CONTEXT BGV " + str(invalid_poly) + " 4 3\nADD a b c\n" result = execute_process( [kerngen_path], data_in=input_string, @@ -181,7 +181,7 @@ def test_parse_results_multiple_context(): def fixture_gen_op_data(request): """Given an op name, return both the input and expected output strings""" in_lines = ( - "CONTEXT BGV 16384 4", + "CONTEXT BGV 16384 4 3", "Data a 2", "Data b 2", "Data c 2", From c73ea6da5798ab91f64000ae87dbf6d53622641b Mon Sep 17 00:00:00 2001 From: christopherngutierrez Date: Wed, 19 Feb 2025 10:37:30 -0800 Subject: [PATCH 2/5] refactored context parameters to use krns and current_rns --- kerngen/high_parser/types.py | 14 +++++++++----- kerngen/tests/test_kerngen.py | 18 +++++++++--------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/kerngen/high_parser/types.py b/kerngen/high_parser/types.py index ade46c5..a6aab48 100644 --- a/kerngen/high_parser/types.py +++ b/kerngen/high_parser/types.py @@ -158,7 +158,7 @@ class Context(BaseModel): @classmethod def from_string(cls, line: str): """Construct context from a string""" - scheme, poly_order, max_rns, *optionals = line.split() + scheme, poly_order, key_rns, current_rns, *optionals = line.split() optional_dict = OptionsDictParser.parse(optionals) int_poly_order = int(poly_order) if ( @@ -170,15 +170,19 @@ def from_string(cls, line: str): f"Poly order `{int_poly_order}` must be power of two >= {MIN_POLY_SIZE} and < {MAX_POLY_SIZE}" ) - int_max_rns = int(max_rns) - int_key_rns = int_max_rns - int_key_rns += optional_dict.pop("krns_delta") + int_key_rns = int(key_rns) + int_current_rns = int(current_rns) + + if int_key_rns <= int_current_rns: + raise ValueError( + f"Current RNS must be less than Key RNS: current_rns={current_rns}, key_rns={key_rns}" + ) return cls( scheme=scheme.upper(), poly_order=int_poly_order, - max_rns=int_max_rns, key_rns=int_key_rns, + max_rns=int_current_rns, **optional_dict, ) diff --git a/kerngen/tests/test_kerngen.py b/kerngen/tests/test_kerngen.py index 19dc69c..9aca9b0 100644 --- a/kerngen/tests/test_kerngen.py +++ b/kerngen/tests/test_kerngen.py @@ -47,7 +47,7 @@ def test_op(kerngen_path, gen_op_data): def test_missing_context(kerngen_path): """Test kerngen raises an exception when context is not the first line of input""" - input_string = "ADD a b c\nCONTEXT BGV 16384 4\n" + input_string = "ADD a b c\nCONTEXT BGV 16384 4 3\n" result = execute_process( [kerngen_path], data_in=input_string, @@ -59,7 +59,7 @@ def test_missing_context(kerngen_path): def test_multiple_contexts(kerngen_path): """Test kerngen raises an exception when more than one context is given""" - input_string = "CONTEXT BGV 16384 4\nData a 2\nCONTEXT BGV 16384 4\n" + input_string = "CONTEXT BGV 16384 4 2\nData a 2\nCONTEXT BGV 16384 4 2\n" result = execute_process( [kerngen_path], data_in=input_string, @@ -71,7 +71,7 @@ def test_multiple_contexts(kerngen_path): def test_context_options_without_key(kerngen_path): """Test kerngen raises an exception when more than one context is given""" - input_string = "CONTEXT BGV 16384 4 1\nData a 2\n" + input_string = "CONTEXT BGV 16384 3 2 1\nData a 2\n" result = execute_process( [kerngen_path], data_in=input_string, @@ -86,7 +86,7 @@ def test_context_options_without_key(kerngen_path): def test_context_unsupported_options_variable(kerngen_path): """Test kerngen raises an exception when more than one context is given""" - input_string = "CONTEXT BGV 16384 4 test=3\nData a 2\n" + input_string = "CONTEXT BGV 16384 3 2 test=3\nData a 2\n" result = execute_process( [kerngen_path], data_in=input_string, @@ -99,7 +99,7 @@ def test_context_unsupported_options_variable(kerngen_path): @pytest.mark.parametrize("invalid", [-1, 256, 0.1, "str"]) def test_context_option_invalid_values(kerngen_path, invalid): """Test kerngen raises an exception if value is out of range for correct key""" - input_string = f"CONTEXT BGV 16384 4 krns_delta={invalid}\nData a 2\n" + input_string = f"CONTEXT BGV 16384 3 2 krns_delta={invalid}\nData a 2\n" result = execute_process( [kerngen_path], data_in=input_string, @@ -115,7 +115,7 @@ def test_context_option_invalid_values(kerngen_path, invalid): def test_unrecognised_opname(kerngen_path): """Test kerngen raises an exception when receiving an unrecognised opname""" - input_string = "CONTEXT BGV 16384 4\nOPERATION a b c\n" + input_string = "CONTEXT BGV 16384 3 2\nOPERATION a b c\n" result = execute_process( [kerngen_path], data_in=input_string, @@ -129,7 +129,7 @@ def test_unrecognised_opname(kerngen_path): def test_invalid_scheme(kerngen_path): """Test kerngen raises an exception when receiving an invalid scheme""" - input_string = "CONTEXT SCHEME 16384 4\nADD a b c\n" + input_string = "CONTEXT SCHEME 16384 4 3\nADD a b c\n" result = execute_process( [kerngen_path], data_in=input_string, @@ -142,7 +142,7 @@ def test_invalid_scheme(kerngen_path): @pytest.mark.parametrize("invalid_poly", [16000, 2**12, 2**13, 2**18]) def test_invalid_poly_order(kerngen_path, invalid_poly): """Poly order should be powers of two >= 2^14 and <= 2^17""" - input_string = "CONTEXT BGV " + str(invalid_poly) + " 4\nADD a b c\n" + input_string = "CONTEXT BGV " + str(invalid_poly) + " 4 3\nADD a b c\n" result = execute_process( [kerngen_path], data_in=input_string, @@ -181,7 +181,7 @@ def test_parse_results_multiple_context(): def fixture_gen_op_data(request): """Given an op name, return both the input and expected output strings""" in_lines = ( - "CONTEXT BGV 16384 4", + "CONTEXT BGV 16384 4 3", "Data a 2", "Data b 2", "Data c 2", From 7db5a1746f30d630ff8deaac4a25b5bde1c660a1 Mon Sep 17 00:00:00 2001 From: christopherngutierrez Date: Wed, 19 Feb 2025 15:05:57 -0800 Subject: [PATCH 3/5] removed rns_delta and updated test cases --- kerngen/high_parser/options_handler.py | 6 ++---- kerngen/tests/test_kerngen.py | 6 +++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/kerngen/high_parser/options_handler.py b/kerngen/high_parser/options_handler.py index af62753..fe5e27b 100644 --- a/kerngen/high_parser/options_handler.py +++ b/kerngen/high_parser/options_handler.py @@ -63,11 +63,9 @@ def __init__(self, int_min: int, int_max: int, default: int | None): class OptionsDictFactory(ABC): """Abstract class that creates OptionsDict objects""" - MAX_KRNS_DELTA = 128 MAX_DIGIT = 3 - MIN_KRNS_DELTA = MIN_DIGIT = 0 + MIN_DIGIT = 0 options = { - "krns_delta": OptionsIntBounds(MIN_KRNS_DELTA, MAX_KRNS_DELTA, 0), "num_digits": OptionsIntBounds(MIN_DIGIT, MAX_DIGIT, None), } @@ -134,6 +132,6 @@ def parse(options: list[str]): ).op_value except ValueError as err: raise ValueError( - f"Options must be key/value pairs (e.g. krns_delta=1, num_digits=3): '{option}'" + f"Options must be key/value pairs (e.g. num_digits=3): '{option}'" ) from err return output_dict diff --git a/kerngen/tests/test_kerngen.py b/kerngen/tests/test_kerngen.py index 9aca9b0..4abba3c 100644 --- a/kerngen/tests/test_kerngen.py +++ b/kerngen/tests/test_kerngen.py @@ -78,7 +78,7 @@ def test_context_options_without_key(kerngen_path): ) assert not result.stdout assert ( - "ValueError: Options must be key/value pairs (e.g. krns_delta=1, num_digits=3): '1'" + "ValueError: Options must be key/value pairs (e.g. num_digits=3): '1'" in result.stderr ) assert result.returncode != 0 @@ -99,14 +99,14 @@ def test_context_unsupported_options_variable(kerngen_path): @pytest.mark.parametrize("invalid", [-1, 256, 0.1, "str"]) def test_context_option_invalid_values(kerngen_path, invalid): """Test kerngen raises an exception if value is out of range for correct key""" - input_string = f"CONTEXT BGV 16384 3 2 krns_delta={invalid}\nData a 2\n" + input_string = f"CONTEXT BGV 16384 3 2 num_digits={invalid}\nData a 2\n" result = execute_process( [kerngen_path], data_in=input_string, ) assert not result.stdout assert ( - f"ValueError: Options must be key/value pairs (e.g. krns_delta=1, num_digits=3): 'krns_delta={invalid}'" + f"ValueError: Options must be key/value pairs (e.g. num_digits=3): 'num_digits={invalid}'" in result.stderr ) assert result.returncode != 0 From ed8f73d147ec93c13ecc9c644fad0730f9dd624e Mon Sep 17 00:00:00 2001 From: christopherngutierrez Date: Tue, 25 Feb 2025 13:52:54 -0800 Subject: [PATCH 4/5] modified to enable relin after mod, currently fails testing --- kerngen/high_parser/types.py | 12 +++++++++--- kerngen/pisa_generators/basic.py | 4 +++- kerngen/pisa_generators/decomp.py | 2 +- kerngen/pisa_generators/mod.py | 2 ++ kerngen/pisa_generators/relin.py | 8 ++++++++ 5 files changed, 23 insertions(+), 5 deletions(-) diff --git a/kerngen/high_parser/types.py b/kerngen/high_parser/types.py index a6aab48..d618741 100644 --- a/kerngen/high_parser/types.py +++ b/kerngen/high_parser/types.py @@ -148,13 +148,17 @@ class EmptyLine(BaseModel): class Context(BaseModel): """Class representing a given context of the scheme""" + # required context params scheme: str poly_order: int # the N - max_rns: int + key_rns: int + current_rns: int # optional vars for context - key_rns: int | None num_digits: int | None + # calculated based on required params + max_rns: int + @classmethod def from_string(cls, line: str): """Construct context from a string""" @@ -172,6 +176,7 @@ def from_string(cls, line: str): int_key_rns = int(key_rns) int_current_rns = int(current_rns) + int_max_rns = int_key_rns - 1 if int_key_rns <= int_current_rns: raise ValueError( @@ -182,7 +187,8 @@ def from_string(cls, line: str): scheme=scheme.upper(), poly_order=int_poly_order, key_rns=int_key_rns, - max_rns=int_current_rns, + current_rns=int_current_rns, + max_rns=int_max_rns, **optional_dict, ) diff --git a/kerngen/pisa_generators/basic.py b/kerngen/pisa_generators/basic.py index 42e702b..0fdd138 100644 --- a/kerngen/pisa_generators/basic.py +++ b/kerngen/pisa_generators/basic.py @@ -352,6 +352,7 @@ def partial_op( last_q: int, ): """ "A helper function to perform partial operation, such as add/sub on last half (input1) to all of input0""" + return [ op( context.label, @@ -366,7 +367,8 @@ def partial_op( ) for part, q, unit in it.product( range(polys.input_remaining_rns.parts), - range(polys.input_remaining_rns.rns), + range(polys.input_remaining_rns.start_rns, polys.input_remaining_rns.rns), + # range(polys.input_remaining_rns.rns), range(context.units), ) ] diff --git a/kerngen/pisa_generators/decomp.py b/kerngen/pisa_generators/decomp.py index 691f75c..f826248 100644 --- a/kerngen/pisa_generators/decomp.py +++ b/kerngen/pisa_generators/decomp.py @@ -33,7 +33,7 @@ def to_pisa(self) -> list[PIsaOp]: r2 = Immediate(name="R2", rns=self.context.key_rns) ls: list[pisa_op] = [] - for input_rns_index in range(self.input0.rns): + for input_rns_index in range(self.input0.start_rns, self.input0.rns): ls.extend( pisa_op.Muli( self.context.label, diff --git a/kerngen/pisa_generators/mod.py b/kerngen/pisa_generators/mod.py index 6c7b920..6d10e25 100644 --- a/kerngen/pisa_generators/mod.py +++ b/kerngen/pisa_generators/mod.py @@ -39,6 +39,8 @@ def to_pisa(self) -> list[PIsaOp]: """Return the p-isa code to perform an mod switch down""" # Immediates last_q = self.input0.rns - 1 + self.input0.start_rns = (self.context.key_rns - 1) - self.context.current_rns + it = Immediate(name="it" + self.var_suffix) t = Immediate(name="t", rns=last_q) one, r2, iq = common_immediates( diff --git a/kerngen/pisa_generators/relin.py b/kerngen/pisa_generators/relin.py index 2244349..6ad684d 100644 --- a/kerngen/pisa_generators/relin.py +++ b/kerngen/pisa_generators/relin.py @@ -24,20 +24,28 @@ def to_pisa(self) -> list[PIsaOp]: supports number of digits equal to the RNS size""" self.output.parts = 2 self.input0.parts = 3 + self.input0.start_rns = (self.context.key_rns - 1) - self.context.current_rns + # self.input0.rns = self.context.current_rns relin_key = KeyPolys( "rlk", parts=2, rns=self.context.key_rns, digits=self.input0.rns ) mul_by_rlk = Polys("c2_rlk", parts=2, rns=self.context.key_rns) + mul_by_rlk_modded_down = Polys.from_polys(mul_by_rlk) mul_by_rlk_modded_down.rns = self.input0.rns + # mul_by_rlk_modded_down.start_rns = (self.context.key_rns - 1) - self.context.current_rns + input_last_part, last_coeff, upto_last_coeffs = extract_last_part_polys( self.input0, self.context.key_rns ) + # input_last_part.start_rns = (self.context.key_rns - 1) - self.context.max_rns + # input_last_part.rns = self.context.current_rns add_original = Polys.from_polys(mul_by_rlk_modded_down) add_original.name = self.input0.name + # add_original.start_rns = (self.context.key_rns - 1) - self.context.max_rns return mixed_to_pisa_ops( Comment("Start of relin kernel"), From 5d450463df2e491b6628ea60a3d73f29f3f4c7a2 Mon Sep 17 00:00:00 2001 From: christopherngutierrez Date: Wed, 5 Mar 2025 16:28:41 -0800 Subject: [PATCH 5/5] Fixed issue with rotate kernel, removed dead code. --- kerngen/pisa_generators/basic.py | 1 - kerngen/pisa_generators/relin.py | 6 ------ kerngen/pisa_generators/rotate.py | 3 ++- kerngen/tests/test_kerngen.py | 8 ++++++-- 4 files changed, 8 insertions(+), 10 deletions(-) diff --git a/kerngen/pisa_generators/basic.py b/kerngen/pisa_generators/basic.py index 0fdd138..3439941 100644 --- a/kerngen/pisa_generators/basic.py +++ b/kerngen/pisa_generators/basic.py @@ -368,7 +368,6 @@ def partial_op( for part, q, unit in it.product( range(polys.input_remaining_rns.parts), range(polys.input_remaining_rns.start_rns, polys.input_remaining_rns.rns), - # range(polys.input_remaining_rns.rns), range(context.units), ) ] diff --git a/kerngen/pisa_generators/relin.py b/kerngen/pisa_generators/relin.py index 6ad684d..0cae139 100644 --- a/kerngen/pisa_generators/relin.py +++ b/kerngen/pisa_generators/relin.py @@ -24,8 +24,6 @@ def to_pisa(self) -> list[PIsaOp]: supports number of digits equal to the RNS size""" self.output.parts = 2 self.input0.parts = 3 - self.input0.start_rns = (self.context.key_rns - 1) - self.context.current_rns - # self.input0.rns = self.context.current_rns relin_key = KeyPolys( "rlk", parts=2, rns=self.context.key_rns, digits=self.input0.rns @@ -35,17 +33,13 @@ def to_pisa(self) -> list[PIsaOp]: mul_by_rlk_modded_down = Polys.from_polys(mul_by_rlk) mul_by_rlk_modded_down.rns = self.input0.rns - # mul_by_rlk_modded_down.start_rns = (self.context.key_rns - 1) - self.context.current_rns input_last_part, last_coeff, upto_last_coeffs = extract_last_part_polys( self.input0, self.context.key_rns ) - # input_last_part.start_rns = (self.context.key_rns - 1) - self.context.max_rns - # input_last_part.rns = self.context.current_rns add_original = Polys.from_polys(mul_by_rlk_modded_down) add_original.name = self.input0.name - # add_original.start_rns = (self.context.key_rns - 1) - self.context.max_rns return mixed_to_pisa_ops( Comment("Start of relin kernel"), diff --git a/kerngen/pisa_generators/rotate.py b/kerngen/pisa_generators/rotate.py index b125a23..078f7fd 100644 --- a/kerngen/pisa_generators/rotate.py +++ b/kerngen/pisa_generators/rotate.py @@ -28,6 +28,7 @@ def to_pisa(self) -> list[PIsaOp]: supports number of digits equal to the RNS size""" self.output.parts = 2 self.input0.parts = 2 + relin_key = KeyPolys( "gk", parts=2, rns=self.context.key_rns, digits=self.input0.rns ) @@ -60,7 +61,7 @@ def to_pisa(self) -> list[PIsaOp]: Comment("Multiply by rotate key"), KeyMul(self.context, mul_by_rlk, upto_last_coeffs, relin_key, 1), Comment("Mod switch down to Q"), - Mod(self.context, mul_by_rlk_modded_down, mul_by_rlk), + Mod(self.context, mul_by_rlk_modded_down, mul_by_rlk, Mod.MOD_P), INTT(self.context, cd, start_input), NTT(self.context, cd, cd), Add(self.context, self.output, cd, first_part_rlk), diff --git a/kerngen/tests/test_kerngen.py b/kerngen/tests/test_kerngen.py index 4abba3c..bd401f7 100644 --- a/kerngen/tests/test_kerngen.py +++ b/kerngen/tests/test_kerngen.py @@ -168,8 +168,12 @@ def test_parse_results_multiple_context(): with pytest.raises(LookupError) as e: parse_results = ParseResults( [ - Context(scheme="BGV", poly_order=16384, max_rns=1), - Context(scheme="CKKS", poly_order=16384, max_rns=1), + Context( + scheme="BGV", poly_order=16384, key_rns=2, current_rns=1, max_rns=1 + ), + Context( + scheme="CKKS", poly_order=16384, key_rns=2, current_rns=1, max_rns=1 + ), ], {}, )