Skip to content

Commit 1bc6680

Browse files
update data_formats to new format/pylint/etc
1 parent 2d6a17c commit 1bc6680

File tree

8 files changed

+156
-271
lines changed

8 files changed

+156
-271
lines changed

p-isa_tools/data_formats/proto/heracles/data.proto

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ message FHEContext {
4343
repeated uint32 psi = 6; // 2n-th root of Z_{q_i}, implies elsewhere used n-th root omega = psi^2
4444
uint32 q_size = 7; // n(Q), for SEAL it is always ( key_rns_num - 1 ), OpenFHE (cc->GetElementParams()->GetParams().size())
4545
uint32 alpha = 8; // ceil(q_size/dnum)
46-
oneof scheme_specific {
47-
CKKSSpecific ckks_info = 9;
48-
BGVSpecific bgv_info = 10;
46+
oneof scheme_specific {
47+
CKKSSpecific ckks_info = 9;
48+
BGVSpecific bgv_info = 10;
4949
}
5050
}
5151

@@ -76,10 +76,9 @@ message TestVector {
7676

7777
// TODO: merge this with above
7878
message Data {
79-
DCRTPoly dcrtpoly = 1; // this is only used for v2
79+
DCRTPoly dcrtpoly = 1; // this is only used for v2
8080
}
8181

82-
8382
// MAIN INTERFACE FOR CONSUMERs &
8483
// ROOT TYPES FOR SERIALIZATIONS
8584
//=================================
@@ -129,17 +128,16 @@ message BGVPlaintextSpecific {
129128
}
130129

131130
message CKKSSpecific {
132-
Keys keys = 1;
133-
uint32 composite_degree = 2; // BASE_NUM_LEVELS_TO_DROP
131+
Keys keys = 1;
132+
uint32 composite_degree = 2; // BASE_NUM_LEVELS_TO_DROP
134133
// Scaling factors
135-
repeated double scaling_factor_real = 3; // size: q_size (CryptoParametersRNS->GetScalingFactorReal())
136-
repeated double scaling_factor_real_big = 4; // size: q_size - 1(CryptoParametersRNS->GetScalingFactorRealBig())
134+
repeated double scaling_factor_real = 3; // size: q_size (CryptoParametersRNS->GetScalingFactorReal())
135+
repeated double scaling_factor_real_big = 4; // size: q_size - 1(CryptoParametersRNS->GetScalingFactorRealBig())
137136
map<string, uint32> metadata_extra = 5;
138137
}
139138

140-
141139
message Keys {
142-
KeySwitch relin_key = 1;
140+
KeySwitch relin_key = 1;
143141
map<uint32, KeySwitch> rotation_keys = 2;
144142
}
145143

@@ -169,7 +167,7 @@ message Plaintext {
169167

170168
// DCRTPoly form
171169
message DCRTPoly {
172-
repeated Polynomial polys = 1; // size = order
170+
repeated Polynomial polys = 1; // size = order
173171
bool in_ntt_form = 2;
174172
}
175173

@@ -182,8 +180,6 @@ message Polynomial {
182180
message RNSPolynomial {
183181
repeated uint32 coeffs = 1; // repeated a power-of-two times with power-of-two
184182
uint32 modulus = 2; // need modulus in case curr_rns < max_rns, as it will use part of Q and all of P
185-
186-
187183
}
188184

189185
message HECRNSPolynomial {

p-isa_tools/data_formats/proto/heracles/fhe_trace.proto

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,56 +18,56 @@ message Trace {
1818
// Sequence of HE instructions
1919
repeated Instruction instructions = 1;
2020
heracles.common.Scheme scheme = 2;
21-
uint32 N = 3; // poly modulus degree
21+
uint32 N = 3; // poly modulus degree
2222
uint32 key_rns_num = 4;
23-
uint32 q_size = 5; // n(Q)
24-
uint32 dnum = 6; // digit size
25-
uint32 alpha = 7; // ceil(n(Q)/dnum) Note: key_rns_num=n(Q) + n(P)
23+
uint32 q_size = 5; // n(Q)
24+
uint32 dnum = 6; // digit size
25+
uint32 alpha = 7; // ceil(n(Q)/dnum) Note: key_rns_num=n(Q) + n(P)
2626
}
2727

2828
message Instruction {
29-
string op = 1;
30-
uint32 plaintext_index = 2; // which plaintext algebra used, can be ignored for CKKS. Used as index into `plaintext_specific` field of `heracles.data.BGVSpecific` object inside `heracles.data.FHEContext`.
31-
Operands args = 3; // inputs/outputs and additional params
32-
string evalop_name = 4; // (OpenFHE specific) Evaluator level call tracking, helps identifying what eval op invoked atomic ops
29+
string op = 1;
30+
uint32 plaintext_index = 2; // which plaintext algebra used, can be ignored for CKKS. Used as index into `plaintext_specific` field of `heracles.data.BGVSpecific` object inside `heracles.data.FHEContext`.
31+
Operands args = 3; // inputs/outputs and additional params
32+
string evalop_name = 4; // (OpenFHE specific) Evaluator level call tracking, helps identifying what eval op invoked atomic ops
3333
}
3434

3535
message Operands {
36-
repeated OperandObject dests = 1;
37-
repeated OperandObject srcs = 2;
38-
map<string, Parameter> params = 3;
36+
repeated OperandObject dests = 1;
37+
repeated OperandObject srcs = 2;
38+
map<string, Parameter> params = 3;
3939
}
4040

4141
message Parameter {
42-
string value = 1;
43-
ValueType type = 2;
42+
string value = 1;
43+
ValueType type = 2;
4444
}
4545
enum ValueType {
46-
UINT32 = 0 [
46+
UINT32 = 0 [
4747
(valuetype_name) = "UINT32"
4848
];
49-
UINT64 = 1 [
49+
UINT64 = 1 [
5050
(valuetype_name) = "UINT64"
5151
];
52-
INT32 = 2 [
52+
INT32 = 2 [
5353
(valuetype_name) = "INT32"
5454
];
55-
INT64 = 3 [
55+
INT64 = 3 [
5656
(valuetype_name) = "INT64"
5757
];
58-
FLOAT = 4 [
58+
FLOAT = 4 [
5959
(valuetype_name) = "FLOAT"
6060
];
61-
DOUBLE = 5 [
61+
DOUBLE = 5 [
6262
(valuetype_name) = "DOUBLE"
6363
];
64-
STRING = 6 [
64+
STRING = 6 [
6565
(valuetype_name) = "STRING"
6666
];
6767
}
6868

6969
message OperandObject {
7070
string symbol_name = 1;
71-
uint32 num_rns = 2; // size = curr_rns of dcrtpoly
72-
uint32 order = 3; // typically 2 for ct/pt (can be 3), single DCRTPoly will always be 1
71+
uint32 num_rns = 2; // size = curr_rns of dcrtpoly
72+
uint32 order = 3; // typically 2 for ct/pt (can be 3), single DCRTPoly will always be 1
7373
}

p-isa_tools/data_formats/python/heracles/data/io.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,27 @@
44

55
# TODO: create also C++ variants of below; given how simple and stable these functions should be just in replicated form, not shared code
66

7-
import heracles.proto.data_pb2 as hpd
8-
import heracles.proto.common_pb2 as hpc
9-
import heracles.proto.fhe_trace_pb2 as hpf
10-
import google.protobuf.json_format as gpj
117
import json
12-
import heracles.data.transform as hdt
13-
from google.protobuf.json_format import MessageToJson, MessageToDict
14-
from glob import glob
15-
import re
168
import sys
179

10+
import heracles.proto.data_pb2 as hpd
11+
from google.protobuf.json_format import MessageToDict
12+
1813
# load & store functions
1914
# ===============================
2015

2116

2217
def parse_manifest(filename: str) -> dict:
23-
manifest = dict()
24-
with open(filename, "r") as fp:
18+
manifest: dict = {}
19+
with open(filename) as fp:
2520
cur_field = None
2621
found_first_field = False
2722
for linenum, cur_line in enumerate(fp):
2823
cur_line = cur_line.rstrip()
2924
if cur_line.startswith("[") and cur_line.endswith("]"):
3025
cur_field = cur_line[1:-1]
3126
found_first_field = True
32-
manifest[cur_field] = dict()
27+
manifest[cur_field] = {}
3328
continue
3429

3530
if not found_first_field:
@@ -60,7 +55,7 @@ def generate_manifest(filename: str, manifest: dict):
6055
# re-check
6156
def store_hec_context_json(filename: str, context: hpd.FHEContext):
6257
print(
63-
f"Warning: Dumping FHE Context data trace to json can take a long time",
58+
"Warning: Dumping FHE Context data trace to json can take a long time",
6459
file=sys.stderr,
6560
)
6661
with open(filename, "w") as fp:
@@ -70,7 +65,7 @@ def store_hec_context_json(filename: str, context: hpd.FHEContext):
7065
# re-check
7166
def store_testvector_json(filename: str, testvector: hpd.TestVector):
7267
print(
73-
f"Warning: Dumping TestVector data trace to json can take a long time",
68+
"Warning: Dumping TestVector data trace to json can take a long time",
7469
file=sys.stderr,
7570
)
7671
with open(filename, "w") as fp:
@@ -95,15 +90,13 @@ def load_hec_context_from_manifest(manifest: dict) -> hpd.FHEContext:
9590

9691

9792
def store_hec_context(filename: str, context_pb: hpd.FHEContext) -> dict:
98-
hec_context_manifest = {"context": dict()}
93+
hec_context_manifest: dict = {"context": {}}
9994
tmp_context = hpd.FHEContext()
10095
tmp_context.CopyFrom(context_pb)
10196

10297
if tmp_context.ByteSize() > 1 << 30:
103-
hec_context_manifest["rotation_keys"] = dict()
104-
for gkct, (ge, gk_pb) in enumerate(
105-
tmp_context.ckks_info.keys.rotation_keys.items()
106-
):
98+
hec_context_manifest["rotation_keys"] = {}
99+
for gkct, (ge, gk_pb) in enumerate(tmp_context.ckks_info.keys.rotation_keys.items()):
107100
parts_fn = f"{filename}_hec_context_part_{gkct + 1}"
108101
hec_context_manifest["rotation_keys"][ge] = parts_fn
109102
with open(parts_fn, "wb") as fp:
@@ -137,7 +130,7 @@ def load_testvector_from_manifest(manifest: dict) -> hpd.TestVector:
137130

138131

139132
def store_testvector(filename: str, testvector_pb: hpd.TestVector) -> dict:
140-
testvector_manifest = {"testvector": dict()}
133+
testvector_manifest: dict = {"testvector": {}}
141134
if testvector_pb.ByteSize() > 1 << 30:
142135
for tvct, (sym, data_pb) in enumerate(testvector_pb.sym_data_map.items()):
143136
parts_fn = f"{filename}_testvector_part_{tvct}"
@@ -171,9 +164,7 @@ def load_data_trace(filename: str) -> tuple[hpd.FHEContext, hpd.TestVector]:
171164
)
172165

173166

174-
def store_data_trace(
175-
filename: str, context_pb: hpd.FHEContext, testvector_pb: hpd.TestVector
176-
):
167+
def store_data_trace(filename: str, context_pb: hpd.FHEContext, testvector_pb: hpd.TestVector):
177168
generate_manifest(
178169
filename,
179170
{

0 commit comments

Comments
 (0)