Skip to content

Commit 31b47b2

Browse files
Context Parameters Refactor - Key/Current RNS (#64)
* refactored context parameters to use krns and current_rns
1 parent 643c020 commit 31b47b2

File tree

8 files changed

+45
-27
lines changed

8 files changed

+45
-27
lines changed

kerngen/high_parser/options_handler.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,9 @@ def __init__(self, int_min: int, int_max: int, default: int | None):
6363
class OptionsDictFactory(ABC):
6464
"""Abstract class that creates OptionsDict objects"""
6565

66-
MAX_KRNS_DELTA = 128
6766
MAX_DIGIT = 3
68-
MIN_KRNS_DELTA = MIN_DIGIT = 0
67+
MIN_DIGIT = 0
6968
options = {
70-
"krns_delta": OptionsIntBounds(MIN_KRNS_DELTA, MAX_KRNS_DELTA, 0),
7169
"num_digits": OptionsIntBounds(MIN_DIGIT, MAX_DIGIT, None),
7270
}
7371

@@ -134,6 +132,6 @@ def parse(options: list[str]):
134132
).op_value
135133
except ValueError as err:
136134
raise ValueError(
137-
f"Options must be key/value pairs (e.g. krns_delta=1, num_digits=3): '{option}'"
135+
f"Options must be key/value pairs (e.g. num_digits=3): '{option}'"
138136
) from err
139137
return output_dict

kerngen/high_parser/types.py

+17-7
Original file line numberDiff line numberDiff line change
@@ -148,17 +148,21 @@ class EmptyLine(BaseModel):
148148
class Context(BaseModel):
149149
"""Class representing a given context of the scheme"""
150150

151+
# required context params
151152
scheme: str
152153
poly_order: int # the N
153-
max_rns: int
154+
key_rns: int
155+
current_rns: int
154156
# optional vars for context
155-
key_rns: int | None
156157
num_digits: int | None
157158

159+
# calculated based on required params
160+
max_rns: int
161+
158162
@classmethod
159163
def from_string(cls, line: str):
160164
"""Construct context from a string"""
161-
scheme, poly_order, max_rns, *optionals = line.split()
165+
scheme, poly_order, key_rns, current_rns, *optionals = line.split()
162166
optional_dict = OptionsDictParser.parse(optionals)
163167
int_poly_order = int(poly_order)
164168
if (
@@ -170,15 +174,21 @@ def from_string(cls, line: str):
170174
f"Poly order `{int_poly_order}` must be power of two >= {MIN_POLY_SIZE} and < {MAX_POLY_SIZE}"
171175
)
172176

173-
int_max_rns = int(max_rns)
174-
int_key_rns = int_max_rns
175-
int_key_rns += optional_dict.pop("krns_delta")
177+
int_key_rns = int(key_rns)
178+
int_current_rns = int(current_rns)
179+
int_max_rns = int_key_rns - 1
180+
181+
if int_key_rns <= int_current_rns:
182+
raise ValueError(
183+
f"Current RNS must be less than Key RNS: current_rns={current_rns}, key_rns={key_rns}"
184+
)
176185

177186
return cls(
178187
scheme=scheme.upper(),
179188
poly_order=int_poly_order,
180-
max_rns=int_max_rns,
181189
key_rns=int_key_rns,
190+
current_rns=int_current_rns,
191+
max_rns=int_max_rns,
182192
**optional_dict,
183193
)
184194

kerngen/pisa_generators/basic.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,7 @@ def partial_op(
352352
last_q: int,
353353
):
354354
""" "A helper function to perform partial operation, such as add/sub on last half (input1) to all of input0"""
355+
355356
return [
356357
op(
357358
context.label,
@@ -366,7 +367,7 @@ def partial_op(
366367
)
367368
for part, q, unit in it.product(
368369
range(polys.input_remaining_rns.parts),
369-
range(polys.input_remaining_rns.rns),
370+
range(polys.input_remaining_rns.start_rns, polys.input_remaining_rns.rns),
370371
range(context.units),
371372
)
372373
]

kerngen/pisa_generators/decomp.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def to_pisa(self) -> list[PIsaOp]:
3333
r2 = Immediate(name="R2", rns=self.context.key_rns)
3434

3535
ls: list[pisa_op] = []
36-
for input_rns_index in range(self.input0.rns):
36+
for input_rns_index in range(self.input0.start_rns, self.input0.rns):
3737
ls.extend(
3838
pisa_op.Muli(
3939
self.context.label,

kerngen/pisa_generators/mod.py

+2
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ def to_pisa(self) -> list[PIsaOp]:
3939
"""Return the p-isa code to perform an mod switch down"""
4040
# Immediates
4141
last_q = self.input0.rns - 1
42+
self.input0.start_rns = (self.context.key_rns - 1) - self.context.current_rns
43+
4244
it = Immediate(name="it" + self.var_suffix)
4345
t = Immediate(name="t", rns=last_q)
4446
one, r2, iq = common_immediates(

kerngen/pisa_generators/relin.py

+2
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@ def to_pisa(self) -> list[PIsaOp]:
3030
)
3131

3232
mul_by_rlk = Polys("c2_rlk", parts=2, rns=self.context.key_rns)
33+
3334
mul_by_rlk_modded_down = Polys.from_polys(mul_by_rlk)
3435
mul_by_rlk_modded_down.rns = self.input0.rns
36+
3537
input_last_part, last_coeff, upto_last_coeffs = extract_last_part_polys(
3638
self.input0, self.context.key_rns
3739
)

kerngen/pisa_generators/rotate.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def to_pisa(self) -> list[PIsaOp]:
2828
supports number of digits equal to the RNS size"""
2929
self.output.parts = 2
3030
self.input0.parts = 2
31+
3132
relin_key = KeyPolys(
3233
"gk", parts=2, rns=self.context.key_rns, digits=self.input0.rns
3334
)
@@ -60,7 +61,7 @@ def to_pisa(self) -> list[PIsaOp]:
6061
Comment("Multiply by rotate key"),
6162
KeyMul(self.context, mul_by_rlk, upto_last_coeffs, relin_key, 1),
6263
Comment("Mod switch down to Q"),
63-
Mod(self.context, mul_by_rlk_modded_down, mul_by_rlk),
64+
Mod(self.context, mul_by_rlk_modded_down, mul_by_rlk, Mod.MOD_P),
6465
INTT(self.context, cd, start_input),
6566
NTT(self.context, cd, cd),
6667
Add(self.context, self.output, cd, first_part_rlk),

kerngen/tests/test_kerngen.py

+17-13
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def test_op(kerngen_path, gen_op_data):
4747
def test_missing_context(kerngen_path):
4848
"""Test kerngen raises an exception when context is not the first line of
4949
input"""
50-
input_string = "ADD a b c\nCONTEXT BGV 16384 4\n"
50+
input_string = "ADD a b c\nCONTEXT BGV 16384 4 3\n"
5151
result = execute_process(
5252
[kerngen_path],
5353
data_in=input_string,
@@ -59,7 +59,7 @@ def test_missing_context(kerngen_path):
5959

6060
def test_multiple_contexts(kerngen_path):
6161
"""Test kerngen raises an exception when more than one context is given"""
62-
input_string = "CONTEXT BGV 16384 4\nData a 2\nCONTEXT BGV 16384 4\n"
62+
input_string = "CONTEXT BGV 16384 4 2\nData a 2\nCONTEXT BGV 16384 4 2\n"
6363
result = execute_process(
6464
[kerngen_path],
6565
data_in=input_string,
@@ -71,22 +71,22 @@ def test_multiple_contexts(kerngen_path):
7171

7272
def test_context_options_without_key(kerngen_path):
7373
"""Test kerngen raises an exception when more than one context is given"""
74-
input_string = "CONTEXT BGV 16384 4 1\nData a 2\n"
74+
input_string = "CONTEXT BGV 16384 3 2 1\nData a 2\n"
7575
result = execute_process(
7676
[kerngen_path],
7777
data_in=input_string,
7878
)
7979
assert not result.stdout
8080
assert (
81-
"ValueError: Options must be key/value pairs (e.g. krns_delta=1, num_digits=3): '1'"
81+
"ValueError: Options must be key/value pairs (e.g. num_digits=3): '1'"
8282
in result.stderr
8383
)
8484
assert result.returncode != 0
8585

8686

8787
def test_context_unsupported_options_variable(kerngen_path):
8888
"""Test kerngen raises an exception when more than one context is given"""
89-
input_string = "CONTEXT BGV 16384 4 test=3\nData a 2\n"
89+
input_string = "CONTEXT BGV 16384 3 2 test=3\nData a 2\n"
9090
result = execute_process(
9191
[kerngen_path],
9292
data_in=input_string,
@@ -99,14 +99,14 @@ def test_context_unsupported_options_variable(kerngen_path):
9999
@pytest.mark.parametrize("invalid", [-1, 256, 0.1, "str"])
100100
def test_context_option_invalid_values(kerngen_path, invalid):
101101
"""Test kerngen raises an exception if value is out of range for correct key"""
102-
input_string = f"CONTEXT BGV 16384 4 krns_delta={invalid}\nData a 2\n"
102+
input_string = f"CONTEXT BGV 16384 3 2 num_digits={invalid}\nData a 2\n"
103103
result = execute_process(
104104
[kerngen_path],
105105
data_in=input_string,
106106
)
107107
assert not result.stdout
108108
assert (
109-
f"ValueError: Options must be key/value pairs (e.g. krns_delta=1, num_digits=3): 'krns_delta={invalid}'"
109+
f"ValueError: Options must be key/value pairs (e.g. num_digits=3): 'num_digits={invalid}'"
110110
in result.stderr
111111
)
112112
assert result.returncode != 0
@@ -115,7 +115,7 @@ def test_context_option_invalid_values(kerngen_path, invalid):
115115
def test_unrecognised_opname(kerngen_path):
116116
"""Test kerngen raises an exception when receiving an unrecognised
117117
opname"""
118-
input_string = "CONTEXT BGV 16384 4\nOPERATION a b c\n"
118+
input_string = "CONTEXT BGV 16384 3 2\nOPERATION a b c\n"
119119
result = execute_process(
120120
[kerngen_path],
121121
data_in=input_string,
@@ -129,7 +129,7 @@ def test_unrecognised_opname(kerngen_path):
129129

130130
def test_invalid_scheme(kerngen_path):
131131
"""Test kerngen raises an exception when receiving an invalid scheme"""
132-
input_string = "CONTEXT SCHEME 16384 4\nADD a b c\n"
132+
input_string = "CONTEXT SCHEME 16384 4 3\nADD a b c\n"
133133
result = execute_process(
134134
[kerngen_path],
135135
data_in=input_string,
@@ -142,7 +142,7 @@ def test_invalid_scheme(kerngen_path):
142142
@pytest.mark.parametrize("invalid_poly", [16000, 2**12, 2**13, 2**18])
143143
def test_invalid_poly_order(kerngen_path, invalid_poly):
144144
"""Poly order should be powers of two >= 2^14 and <= 2^17"""
145-
input_string = "CONTEXT BGV " + str(invalid_poly) + " 4\nADD a b c\n"
145+
input_string = "CONTEXT BGV " + str(invalid_poly) + " 4 3\nADD a b c\n"
146146
result = execute_process(
147147
[kerngen_path],
148148
data_in=input_string,
@@ -168,8 +168,12 @@ def test_parse_results_multiple_context():
168168
with pytest.raises(LookupError) as e:
169169
parse_results = ParseResults(
170170
[
171-
Context(scheme="BGV", poly_order=16384, max_rns=1),
172-
Context(scheme="CKKS", poly_order=16384, max_rns=1),
171+
Context(
172+
scheme="BGV", poly_order=16384, key_rns=2, current_rns=1, max_rns=1
173+
),
174+
Context(
175+
scheme="CKKS", poly_order=16384, key_rns=2, current_rns=1, max_rns=1
176+
),
173177
],
174178
{},
175179
)
@@ -181,7 +185,7 @@ def test_parse_results_multiple_context():
181185
def fixture_gen_op_data(request):
182186
"""Given an op name, return both the input and expected output strings"""
183187
in_lines = (
184-
"CONTEXT BGV 16384 4",
188+
"CONTEXT BGV 16384 4 3",
185189
"Data a 2",
186190
"Data b 2",
187191
"Data c 2",

0 commit comments

Comments
 (0)