Skip to content

Commit 34ab42a

Browse files
authored
Fix option generator (#53)
Option values were being dumped as `MessageLiteral` rather than parsing the AST elements properly. This is now fixed. Complex optional types are allowed. Fixes #51
1 parent 1ca76d6 commit 34ab42a

File tree

2 files changed

+205
-0
lines changed

2 files changed

+205
-0
lines changed

proto_schema_parser/generator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ def generate(self, file: ast.File) -> str:
4343
lines.append(f"option {element.name} = true;")
4444
else:
4545
lines.append(f"option {element.name} = false;")
46+
elif isinstance(element.value, ast.MessageLiteral):
47+
value = self._generate_option_value(element.value, 0)
48+
lines.append(f"option {element.name} = {value};")
4649
else:
4750
lines.append(f'option {element.name} = "{element.value}";')
4851
elif isinstance(element, ast.Message):

tests/test_generator.py

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1218,3 +1218,205 @@ def test_generate_service_with_additional_bindings():
12181218
)
12191219

12201220
assert result == expected
1221+
1222+
1223+
def test_generate_option_with_complex_nested_message_literal_swagger():
1224+
"""Test that a complex message literal with nested security definitions is generated correctly."""
1225+
file = ast.File(
1226+
file_elements=[
1227+
ast.Option(
1228+
name="(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_swagger)",
1229+
value=ast.MessageLiteral(
1230+
fields=[
1231+
ast.MessageLiteralField(
1232+
name="security_definitions",
1233+
value=ast.MessageLiteral(
1234+
fields=[
1235+
ast.MessageLiteralField(
1236+
name="security",
1237+
value=ast.MessageLiteral(
1238+
fields=[
1239+
ast.MessageLiteralField(
1240+
name="key", value="ApiKey"
1241+
),
1242+
ast.MessageLiteralField(
1243+
name="value",
1244+
value=ast.MessageLiteral(
1245+
fields=[
1246+
ast.MessageLiteralField(
1247+
name="type",
1248+
value=ast.Identifier(
1249+
name="TYPE_API_KEY"
1250+
),
1251+
),
1252+
ast.MessageLiteralField(
1253+
name="in",
1254+
value=ast.Identifier(
1255+
name="IN_HEADER"
1256+
),
1257+
),
1258+
ast.MessageLiteralField(
1259+
name="name",
1260+
value="Authorization",
1261+
),
1262+
]
1263+
),
1264+
),
1265+
]
1266+
),
1267+
)
1268+
]
1269+
),
1270+
),
1271+
ast.MessageLiteralField(
1272+
name="security",
1273+
value=ast.MessageLiteral(
1274+
fields=[
1275+
ast.MessageLiteralField(
1276+
name="security_requirement",
1277+
value=ast.MessageLiteral(
1278+
fields=[
1279+
ast.MessageLiteralField(
1280+
name="key", value="ApiKey"
1281+
),
1282+
ast.MessageLiteralField(
1283+
name="value",
1284+
value=ast.MessageLiteral(fields=[]),
1285+
),
1286+
]
1287+
),
1288+
)
1289+
]
1290+
),
1291+
),
1292+
]
1293+
),
1294+
)
1295+
]
1296+
)
1297+
1298+
result = Generator().generate(file)
1299+
expected = """option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_swagger) = {
1300+
security_definitions: {
1301+
security: {
1302+
key: "ApiKey",
1303+
value: {
1304+
type: TYPE_API_KEY,
1305+
in: IN_HEADER,
1306+
name: "Authorization"
1307+
}
1308+
}
1309+
},
1310+
security: {
1311+
security_requirement: {
1312+
key: "ApiKey",
1313+
value: {}
1314+
}
1315+
}
1316+
};"""
1317+
1318+
assert result == expected
1319+
1320+
1321+
def test_generate_multiple_options_with_complex_message_literals():
1322+
"""Test that multiple options with complex message literals are generated correctly."""
1323+
file = ast.File(
1324+
file_elements=[
1325+
ast.Option(name="go_package", value="go.etcd.io/etcd/api/v3/etcdserverpb"),
1326+
ast.Option(name="(gogoproto.marshaler_all)", value=True),
1327+
ast.Option(name="(gogoproto.unmarshaler_all)", value=True),
1328+
ast.Option(
1329+
name="(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_swagger)",
1330+
value=ast.MessageLiteral(
1331+
fields=[
1332+
ast.MessageLiteralField(
1333+
name="security_definitions",
1334+
value=ast.MessageLiteral(
1335+
fields=[
1336+
ast.MessageLiteralField(
1337+
name="security",
1338+
value=ast.MessageLiteral(
1339+
fields=[
1340+
ast.MessageLiteralField(
1341+
name="key", value="ApiKey"
1342+
),
1343+
ast.MessageLiteralField(
1344+
name="value",
1345+
value=ast.MessageLiteral(
1346+
fields=[
1347+
ast.MessageLiteralField(
1348+
name="type",
1349+
value=ast.Identifier(
1350+
name="TYPE_API_KEY"
1351+
),
1352+
),
1353+
ast.MessageLiteralField(
1354+
name="in",
1355+
value=ast.Identifier(
1356+
name="IN_HEADER"
1357+
),
1358+
),
1359+
ast.MessageLiteralField(
1360+
name="name",
1361+
value="Authorization",
1362+
),
1363+
]
1364+
),
1365+
),
1366+
]
1367+
),
1368+
)
1369+
]
1370+
),
1371+
),
1372+
ast.MessageLiteralField(
1373+
name="security",
1374+
value=ast.MessageLiteral(
1375+
fields=[
1376+
ast.MessageLiteralField(
1377+
name="security_requirement",
1378+
value=ast.MessageLiteral(
1379+
fields=[
1380+
ast.MessageLiteralField(
1381+
name="key", value="ApiKey"
1382+
),
1383+
ast.MessageLiteralField(
1384+
name="value",
1385+
value=ast.MessageLiteral(fields=[]),
1386+
),
1387+
]
1388+
),
1389+
)
1390+
]
1391+
),
1392+
),
1393+
]
1394+
),
1395+
),
1396+
]
1397+
)
1398+
1399+
result = Generator().generate(file)
1400+
expected = """option go_package = "go.etcd.io/etcd/api/v3/etcdserverpb";
1401+
option (gogoproto.marshaler_all) = true;
1402+
option (gogoproto.unmarshaler_all) = true;
1403+
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_swagger) = {
1404+
security_definitions: {
1405+
security: {
1406+
key: "ApiKey",
1407+
value: {
1408+
type: TYPE_API_KEY,
1409+
in: IN_HEADER,
1410+
name: "Authorization"
1411+
}
1412+
}
1413+
},
1414+
security: {
1415+
security_requirement: {
1416+
key: "ApiKey",
1417+
value: {}
1418+
}
1419+
}
1420+
};"""
1421+
1422+
assert result == expected

0 commit comments

Comments
 (0)