Skip to content

Context Parameters Refactor - Key/Current RNS #64

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions kerngen/high_parser/options_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,9 @@ def __init__(self, int_min: int, int_max: int, default: int | None):
class OptionsDictFactory(ABC):
"""Abstract class that creates OptionsDict objects"""

MAX_KRNS_DELTA = 128
MAX_DIGIT = 3
MIN_KRNS_DELTA = MIN_DIGIT = 0
MIN_DIGIT = 0
options = {
"krns_delta": OptionsIntBounds(MIN_KRNS_DELTA, MAX_KRNS_DELTA, 0),
"num_digits": OptionsIntBounds(MIN_DIGIT, MAX_DIGIT, None),
}

Expand Down Expand Up @@ -134,6 +132,6 @@ def parse(options: list[str]):
).op_value
except ValueError as err:
raise ValueError(
f"Options must be key/value pairs (e.g. krns_delta=1, num_digits=3): '{option}'"
f"Options must be key/value pairs (e.g. num_digits=3): '{option}'"
) from err
return output_dict
24 changes: 17 additions & 7 deletions kerngen/high_parser/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,17 +148,21 @@ class EmptyLine(BaseModel):
class Context(BaseModel):
"""Class representing a given context of the scheme"""

# required context params
scheme: str
poly_order: int # the N
max_rns: int
key_rns: int
current_rns: int
# optional vars for context
key_rns: int | None
num_digits: int | None

# calculated based on required params
max_rns: int

@classmethod
def from_string(cls, line: str):
"""Construct context from a string"""
scheme, poly_order, max_rns, *optionals = line.split()
scheme, poly_order, key_rns, current_rns, *optionals = line.split()
optional_dict = OptionsDictParser.parse(optionals)
int_poly_order = int(poly_order)
if (
Expand All @@ -170,15 +174,21 @@ def from_string(cls, line: str):
f"Poly order `{int_poly_order}` must be power of two >= {MIN_POLY_SIZE} and < {MAX_POLY_SIZE}"
)

int_max_rns = int(max_rns)
int_key_rns = int_max_rns
int_key_rns += optional_dict.pop("krns_delta")
int_key_rns = int(key_rns)
int_current_rns = int(current_rns)
int_max_rns = int_key_rns - 1

if int_key_rns <= int_current_rns:
raise ValueError(
f"Current RNS must be less than Key RNS: current_rns={current_rns}, key_rns={key_rns}"
)

return cls(
scheme=scheme.upper(),
poly_order=int_poly_order,
max_rns=int_max_rns,
key_rns=int_key_rns,
current_rns=int_current_rns,
max_rns=int_max_rns,
**optional_dict,
)

Expand Down
3 changes: 2 additions & 1 deletion kerngen/pisa_generators/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ def partial_op(
last_q: int,
):
""" "A helper function to perform partial operation, such as add/sub on last half (input1) to all of input0"""

return [
op(
context.label,
Expand All @@ -366,7 +367,7 @@ def partial_op(
)
for part, q, unit in it.product(
range(polys.input_remaining_rns.parts),
range(polys.input_remaining_rns.rns),
range(polys.input_remaining_rns.start_rns, polys.input_remaining_rns.rns),
range(context.units),
)
]
Expand Down
2 changes: 1 addition & 1 deletion kerngen/pisa_generators/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def to_pisa(self) -> list[PIsaOp]:
r2 = Immediate(name="R2", rns=self.context.key_rns)

ls: list[pisa_op] = []
for input_rns_index in range(self.input0.rns):
for input_rns_index in range(self.input0.start_rns, self.input0.rns):
ls.extend(
pisa_op.Muli(
self.context.label,
Expand Down
2 changes: 2 additions & 0 deletions kerngen/pisa_generators/mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def to_pisa(self) -> list[PIsaOp]:
"""Return the p-isa code to perform an mod switch down"""
# Immediates
last_q = self.input0.rns - 1
self.input0.start_rns = (self.context.key_rns - 1) - self.context.current_rns

it = Immediate(name="it" + self.var_suffix)
t = Immediate(name="t", rns=last_q)
one, r2, iq = common_immediates(
Expand Down
2 changes: 2 additions & 0 deletions kerngen/pisa_generators/relin.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ def to_pisa(self) -> list[PIsaOp]:
)

mul_by_rlk = Polys("c2_rlk", parts=2, rns=self.context.key_rns)

mul_by_rlk_modded_down = Polys.from_polys(mul_by_rlk)
mul_by_rlk_modded_down.rns = self.input0.rns

input_last_part, last_coeff, upto_last_coeffs = extract_last_part_polys(
self.input0, self.context.key_rns
)
Expand Down
3 changes: 2 additions & 1 deletion kerngen/pisa_generators/rotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def to_pisa(self) -> list[PIsaOp]:
supports number of digits equal to the RNS size"""
self.output.parts = 2
self.input0.parts = 2

relin_key = KeyPolys(
"gk", parts=2, rns=self.context.key_rns, digits=self.input0.rns
)
Expand Down Expand Up @@ -60,7 +61,7 @@ def to_pisa(self) -> list[PIsaOp]:
Comment("Multiply by rotate key"),
KeyMul(self.context, mul_by_rlk, upto_last_coeffs, relin_key, 1),
Comment("Mod switch down to Q"),
Mod(self.context, mul_by_rlk_modded_down, mul_by_rlk),
Mod(self.context, mul_by_rlk_modded_down, mul_by_rlk, Mod.MOD_P),
INTT(self.context, cd, start_input),
NTT(self.context, cd, cd),
Add(self.context, self.output, cd, first_part_rlk),
Expand Down
30 changes: 17 additions & 13 deletions kerngen/tests/test_kerngen.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_op(kerngen_path, gen_op_data):
def test_missing_context(kerngen_path):
"""Test kerngen raises an exception when context is not the first line of
input"""
input_string = "ADD a b c\nCONTEXT BGV 16384 4\n"
input_string = "ADD a b c\nCONTEXT BGV 16384 4 3\n"
result = execute_process(
[kerngen_path],
data_in=input_string,
Expand All @@ -59,7 +59,7 @@ def test_missing_context(kerngen_path):

def test_multiple_contexts(kerngen_path):
"""Test kerngen raises an exception when more than one context is given"""
input_string = "CONTEXT BGV 16384 4\nData a 2\nCONTEXT BGV 16384 4\n"
input_string = "CONTEXT BGV 16384 4 2\nData a 2\nCONTEXT BGV 16384 4 2\n"
result = execute_process(
[kerngen_path],
data_in=input_string,
Expand All @@ -71,22 +71,22 @@ def test_multiple_contexts(kerngen_path):

def test_context_options_without_key(kerngen_path):
"""Test kerngen raises an exception when more than one context is given"""
input_string = "CONTEXT BGV 16384 4 1\nData a 2\n"
input_string = "CONTEXT BGV 16384 3 2 1\nData a 2\n"
result = execute_process(
[kerngen_path],
data_in=input_string,
)
assert not result.stdout
assert (
"ValueError: Options must be key/value pairs (e.g. krns_delta=1, num_digits=3): '1'"
"ValueError: Options must be key/value pairs (e.g. num_digits=3): '1'"
in result.stderr
)
assert result.returncode != 0


def test_context_unsupported_options_variable(kerngen_path):
"""Test kerngen raises an exception when more than one context is given"""
input_string = "CONTEXT BGV 16384 4 test=3\nData a 2\n"
input_string = "CONTEXT BGV 16384 3 2 test=3\nData a 2\n"
result = execute_process(
[kerngen_path],
data_in=input_string,
Expand All @@ -99,14 +99,14 @@ def test_context_unsupported_options_variable(kerngen_path):
@pytest.mark.parametrize("invalid", [-1, 256, 0.1, "str"])
def test_context_option_invalid_values(kerngen_path, invalid):
"""Test kerngen raises an exception if value is out of range for correct key"""
input_string = f"CONTEXT BGV 16384 4 krns_delta={invalid}\nData a 2\n"
input_string = f"CONTEXT BGV 16384 3 2 num_digits={invalid}\nData a 2\n"
result = execute_process(
[kerngen_path],
data_in=input_string,
)
assert not result.stdout
assert (
f"ValueError: Options must be key/value pairs (e.g. krns_delta=1, num_digits=3): 'krns_delta={invalid}'"
f"ValueError: Options must be key/value pairs (e.g. num_digits=3): 'num_digits={invalid}'"
in result.stderr
)
assert result.returncode != 0
Expand All @@ -115,7 +115,7 @@ def test_context_option_invalid_values(kerngen_path, invalid):
def test_unrecognised_opname(kerngen_path):
"""Test kerngen raises an exception when receiving an unrecognised
opname"""
input_string = "CONTEXT BGV 16384 4\nOPERATION a b c\n"
input_string = "CONTEXT BGV 16384 3 2\nOPERATION a b c\n"
result = execute_process(
[kerngen_path],
data_in=input_string,
Expand All @@ -129,7 +129,7 @@ def test_unrecognised_opname(kerngen_path):

def test_invalid_scheme(kerngen_path):
"""Test kerngen raises an exception when receiving an invalid scheme"""
input_string = "CONTEXT SCHEME 16384 4\nADD a b c\n"
input_string = "CONTEXT SCHEME 16384 4 3\nADD a b c\n"
result = execute_process(
[kerngen_path],
data_in=input_string,
Expand All @@ -142,7 +142,7 @@ def test_invalid_scheme(kerngen_path):
@pytest.mark.parametrize("invalid_poly", [16000, 2**12, 2**13, 2**18])
def test_invalid_poly_order(kerngen_path, invalid_poly):
"""Poly order should be powers of two >= 2^14 and <= 2^17"""
input_string = "CONTEXT BGV " + str(invalid_poly) + " 4\nADD a b c\n"
input_string = "CONTEXT BGV " + str(invalid_poly) + " 4 3\nADD a b c\n"
result = execute_process(
[kerngen_path],
data_in=input_string,
Expand All @@ -168,8 +168,12 @@ def test_parse_results_multiple_context():
with pytest.raises(LookupError) as e:
parse_results = ParseResults(
[
Context(scheme="BGV", poly_order=16384, max_rns=1),
Context(scheme="CKKS", poly_order=16384, max_rns=1),
Context(
scheme="BGV", poly_order=16384, key_rns=2, current_rns=1, max_rns=1
),
Context(
scheme="CKKS", poly_order=16384, key_rns=2, current_rns=1, max_rns=1
),
],
{},
)
Expand All @@ -181,7 +185,7 @@ def test_parse_results_multiple_context():
def fixture_gen_op_data(request):
"""Given an op name, return both the input and expected output strings"""
in_lines = (
"CONTEXT BGV 16384 4",
"CONTEXT BGV 16384 4 3",
"Data a 2",
"Data b 2",
"Data c 2",
Expand Down