Skip to content

Commit c36b0b4

Browse files
committed
Add types to the graph.
This adds a fourth node type, and a fourth edge flow, both called "type". The idea is to represent types as first-class elements in the graph representation. This allows greater compositionality by breaking up composite types into subcomponents, and decreases the required vocabulary size required to achieve a given coverage. Background ---------- Currently, type information is stored in the "text" field of nodes for constants and variables, e.g.: node { type: VARIABLE text: "i8" } There are two issues with this: * Composite types end up with long textual representations, e.g. "struct foo { i32 a; i32 b; ... }". Since there is an unbounded number of possible structs, this prevents 100% vocabulary coverage on any IR with structs (or other composite types). * In the future, we will want to encode different information on data nodes, such as embedding literal values. Moving the type information out of the data node "frees up" space for something else. Overview -------- This changes the representation to represent types as first-class elements in the graph. A "type" node represents a type using its "text" field, and a new "type" edge connects this type to variables or constants of that type, e.g.: node { type: VARIABLE text: "var" } node { type: TYPE text: "i8" } edge { flow: TYPE source: 1 } Composite types --------------- Types may be composed by connecting of many type nodes using type edges. This allows you to break down complex types into a graph of primitive parts. The meaning of composite types will depend on the type of IR, the remainder describes the process for LLVM-IR. Pointer types ------------- A pointer is a composite of two types: [pointer] <- [pointed-type] For example: int32_t* instance; Would be represented as: node { type: TYPE text: "i32" } node { type: TYPE text: "*" } node { type: VARIABLE text: "var" } edge { text: TYPE target: 1 } edge { text: TYPE source: 1 target: 2 } Where variables/constants of this type receive an incoming type edge from the [pointer] node, which in turn receives an incoming type edge from the [pointed-type] node. One [pointer] node is generated for each unique pointer type. If a graph contains multiple pointer types, there will be multiple [pointer] nodes. Struct types ------------ A struct is a compsite type where each member is a node type which points to the parent node. Variable/constant instances of a struct receive an incoming type edge from the root struct node. Note that the graph of type nodes representing a composite struct type may be cyclical, since a struct can contain a pointer of the same type (think of a binary tree implementation). The type edges from member nodes to the parent struct are positional. The position indicates the element number. E.g. for a struct with three elements, the incoming type edges to the struct node will have positions 0, 1, and 2. This example struct: struct s { int8_t a; int8_t b; struct s* c; } struct s instance; Would be represented as: node { type: TYPE text: "struct" } node { type: TYPE text: "i8" } node { type: TYPE text: "*" } node { type: VARIABLE text: "var" } edge { flow: TYPE target: 1 } edge { flow: TYPE target: 1 position: 1 } edge { flow: TYPE target: 2 position: 2 } edge { flow: TYPE source: 2 } edge { flow: TYPE target: 3 } Array Types ----------- An array is a composite type [array] <- [element-type]. For example, the array: int a[10]; Would be represented as: node { type: TYPE text: "i32" } node { type: TYPE text: "[]" } node { type: VARIABLE text: "var" } edge { flow: TYPE target: 1 } edge { flow: TYPE source: 1 target: 2 } github.com//issues/82 Signed-off-by: format 2020.06.15 <github.com/ChrisCummins/format>
1 parent 3e24ddc commit c36b0b4

10 files changed

+202
-19
lines changed

programl/graph/format/graphviz_converter.cc

+18-8
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ labm8::Status SerializeGraphVizToString(const ProgramGraph& graph,
134134

135135
// Determine the subgraph to add this node to.
136136
boost::subgraph<GraphvizGraph>* dst = &external;
137-
if (i && node.type() != Node::CONSTANT) {
137+
if (i && (node.type() == Node::INSTRUCTION || node.type() == Node::VARIABLE)) {
138138
dst = &functionGraphs[node.function()].get();
139139
}
140140

@@ -192,29 +192,33 @@ labm8::Status SerializeGraphVizToString(const ProgramGraph& graph,
192192
}
193193
labm8::TruncateWithEllipsis(text, kMaximumLabelLen);
194194
attributes["label"] = text;
195+
attributes["style"] = "filled";
195196

196197
// Set the node shape.
197198
switch (node.type()) {
198199
case Node::INSTRUCTION:
199200
attributes["shape"] = "box";
200-
attributes["style"] = "filled";
201201
attributes["fillcolor"] = "#3c78d8";
202202
attributes["fontcolor"] = "#ffffff";
203203
break;
204204
case Node::VARIABLE:
205205
attributes["shape"] = "ellipse";
206-
attributes["style"] = "filled";
207206
attributes["fillcolor"] = "#f4cccc";
208207
attributes["color"] = "#990000";
209208
attributes["fontcolor"] = "#990000";
210209
break;
211210
case Node::CONSTANT:
212-
attributes["shape"] = "diamond";
213-
attributes["style"] = "filled";
211+
attributes["shape"] = "octagon";
214212
attributes["fillcolor"] = "#e99c9c";
215213
attributes["color"] = "#990000";
216214
attributes["fontcolor"] = "#990000";
217215
break;
216+
case Node::TYPE:
217+
attributes["shape"] = "diamond";
218+
attributes["fillcolor"] = "#cccccc";
219+
attributes["color"] = "#cccccc";
220+
attributes["fontcolor"] = "#222222";
221+
break;
218222
}
219223
}
220224

@@ -242,15 +246,21 @@ labm8::Status SerializeGraphVizToString(const ProgramGraph& graph,
242246
attributes["color"] = "#65ae4d";
243247
attributes["weight"] = "1";
244248
break;
249+
case Edge::TYPE:
250+
attributes["color"] = "#aaaaaa";
251+
attributes["weight"] = "1";
252+
attributes["penwidth"] = "1.5";
253+
break;
245254
}
246255

247256
// Set the edge label.
248257
if (edge.position()) {
249258
// Position labels for control edge are drawn close to the originating
250-
// instruction. For data edges, they are drawn closer to the consuming
251-
// instruction.
259+
// instruction. For control edges, they are drawn close to the branching
260+
// instruction. For data and type edges, they are drawn close to the
261+
// consuming node.
252262
const string label =
253-
edge.flow() == Edge::DATA ? "headlabel" : "taillabel";
263+
edge.flow() == Edge::CONTROL ? "taillabel" : "headlabel";
254264
attributes[label] = std::to_string(edge.position());
255265
attributes["labelfontcolor"] = attributes["color"];
256266
}

programl/graph/program_graph_builder.cc

+24
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ Node* ProgramGraphBuilder::AddConstant(const string& text) {
6767
return AddNode(Node::CONSTANT, text);
6868
}
6969

70+
Node* ProgramGraphBuilder::AddType(const string& text) {
71+
return AddNode(Node::TYPE, text);
72+
}
73+
7074
labm8::StatusOr<Edge*> ProgramGraphBuilder::AddControlEdge(int32_t position,
7175
const Node* source,
7276
const Node* target) {
@@ -143,6 +147,26 @@ labm8::StatusOr<Edge*> ProgramGraphBuilder::AddCallEdge(const Node* source,
143147
return AddEdge(Edge::CALL, /*position=*/0, source, target);
144148
}
145149

150+
labm8::StatusOr<Edge*> ProgramGraphBuilder::AddTypeEdge(int32_t position,
151+
const Node* source,
152+
const Node* target) {
153+
DCHECK(source) << "nullptr argument";
154+
DCHECK(target) << "nullptr argument";
155+
156+
if (source->type() != Node::TYPE) {
157+
return Status(labm8::error::Code::INVALID_ARGUMENT,
158+
"Invalid source type ({}) for type edge. Expected type",
159+
Node::Type_Name(source->type()));
160+
}
161+
if (target->type() == Node::INSTRUCTION) {
162+
return Status(labm8::error::Code::INVALID_ARGUMENT,
163+
"Invalid destination type (instruction) for type edge. "
164+
"Expected {variable,constant,type}");
165+
}
166+
167+
return AddEdge(Edge::TYPE, position, source, target);
168+
}
169+
146170
labm8::StatusOr<ProgramGraph> ProgramGraphBuilder::Build() {
147171
// Check that all nodes except the root are connected. The root is allowed to
148172
// have no connections in the case where it is an empty graph.

programl/graph/program_graph_builder.h

+7-1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ class ProgramGraphBuilder {
6161

6262
Node* AddConstant(const string& text);
6363

64+
Node* AddType(const string& text);
65+
6466
// Edge factories.
6567
[[nodiscard]] labm8::StatusOr<Edge*> AddControlEdge(int32_t position,
6668
const Node* source,
@@ -73,6 +75,10 @@ class ProgramGraphBuilder {
7375
[[nodiscard]] labm8::StatusOr<Edge*> AddCallEdge(const Node* source,
7476
const Node* target);
7577

78+
[[nodiscard]] labm8::StatusOr<Edge*> AddTypeEdge(int32_t position,
79+
const Node* source,
80+
const Node* target);
81+
7682
const Node* GetRootNode() const { return &graph_.node(0); }
7783

7884
// Return the graph protocol buffer.
@@ -113,7 +119,7 @@ class ProgramGraphBuilder {
113119
int32_t GetIndex(const Function* function);
114120
int32_t GetIndex(const Node* node);
115121

116-
// Maps which covert store the index of objects in repeated field lists.
122+
// Maps that store the index of objects in repeated field lists.
117123
absl::flat_hash_map<Module*, int32_t> moduleIndices_;
118124
absl::flat_hash_map<Function*, int32_t> functionIndices_;
119125
absl::flat_hash_map<Node*, int32_t> nodeIndices_;

programl/ir/llvm/inst2vec_encoder.py

+7
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def Encode(
122122
# Add the node features.
123123
var_embedding = self.dictionary["!IDENTIFIER"]
124124
const_embedding = self.dictionary["!IMMEDIATE"]
125+
type_embedding = self.dictionary["!IMMEDIATE"] # Types are immediates
125126

126127
text_index = 0
127128
for node in proto.node:
@@ -143,6 +144,12 @@ def Encode(
143144
node.features.feature["inst2vec_embedding"].int64_list.value.append(
144145
const_embedding
145146
)
147+
elif node.type == node_pb2.Node.TYPE:
148+
node.features.feature["inst2vec_embedding"].int64_list.value.append(
149+
type_embedding
150+
)
151+
else:
152+
raise TypeError(f"Unknown node type {node}")
146153

147154
proto.features.feature["inst2vec_annotated"].int64_list.value.append(1)
148155
return proto

programl/ir/llvm/internal/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ cc_library(
4343
"@com_google_absl//absl/container:flat_hash_set",
4444
"@labm8//labm8/cpp:status_macros",
4545
"@labm8//labm8/cpp:statusor",
46+
"@labm8//labm8/cpp:logging",
4647
"@labm8//labm8/cpp:string",
4748
"@llvm//10.0.0",
4849
],

programl/ir/llvm/internal/program_graph_builder.cc

+100-3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#include "absl/container/flat_hash_map.h"
2222
#include "absl/container/flat_hash_set.h"
23+
#include "labm8/cpp/logging.h"
2324
#include "labm8/cpp/status_macros.h"
2425
#include "labm8/cpp/string.h"
2526
#include "llvm/IR/BasicBlock.h"
@@ -357,29 +358,125 @@ Node* ProgramGraphBuilder::AddLlvmInstruction(
357358
Node* ProgramGraphBuilder::AddLlvmVariable(const ::llvm::Instruction* operand,
358359
const programl::Function* function) {
359360
const LlvmTextComponents text = textEncoder_.Encode(operand);
360-
Node* node = AddVariable(text.lhs_type, function);
361+
Node* node = AddVariable("var", function);
361362
node->set_block(blockCount_);
362363
graph::AddScalarFeature(node, "llvm_string", AddString(text.lhs));
363364

365+
compositeTypeParts_.clear(); // Reset after previous call.
366+
Node* type = GetOrCreateType(operand->getType());
367+
CHECK(AddTypeEdge(/*position=*/0, type, node).ok());
368+
364369
return node;
365370
}
366371

367372
Node* ProgramGraphBuilder::AddLlvmVariable(const ::llvm::Argument* argument,
368373
const programl::Function* function) {
369374
const LlvmTextComponents text = textEncoder_.Encode(argument);
370-
Node* node = AddVariable(text.lhs_type, function);
375+
Node* node = AddVariable("var", function);
371376
node->set_block(blockCount_);
372377
graph::AddScalarFeature(node, "llvm_string", AddString(text.lhs));
373378

379+
compositeTypeParts_.clear(); // Reset after previous call.
380+
Node* type = GetOrCreateType(argument->getType());
381+
CHECK(AddTypeEdge(/*position=*/0, type, node).ok());
382+
374383
return node;
375384
}
376385

377386
Node* ProgramGraphBuilder::AddLlvmConstant(const ::llvm::Constant* constant) {
378387
const LlvmTextComponents text = textEncoder_.Encode(constant);
379-
Node* node = AddConstant(text.lhs_type);
388+
Node* node = AddConstant("val");
380389
node->set_block(blockCount_);
381390
graph::AddScalarFeature(node, "llvm_string", AddString(text.text));
382391

392+
compositeTypeParts_.clear(); // Reset after previous call.
393+
Node* type = GetOrCreateType(constant->getType());
394+
CHECK(AddTypeEdge(/*position=*/0, type, node).ok());
395+
396+
return node;
397+
}
398+
399+
Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::Type* type) {
400+
// Dispatch to the type-specific handlers.
401+
if (::llvm::dyn_cast<::llvm::StructType>(type)) {
402+
return AddLlvmType(::llvm::dyn_cast<::llvm::StructType>(type));
403+
} else if (::llvm::dyn_cast<::llvm::PointerType>(type)) {
404+
return AddLlvmType(::llvm::dyn_cast<::llvm::PointerType>(type));
405+
} else if (::llvm::dyn_cast<::llvm::FunctionType>(type)) {
406+
return AddLlvmType(::llvm::dyn_cast<::llvm::FunctionType>(type));
407+
} else if (::llvm::dyn_cast<::llvm::ArrayType>(type)) {
408+
return AddLlvmType(::llvm::dyn_cast<::llvm::ArrayType>(type));
409+
} else if (::llvm::dyn_cast<::llvm::VectorType>(type)) {
410+
return AddLlvmType(::llvm::dyn_cast<::llvm::VectorType>(type));
411+
} else {
412+
const LlvmTextComponents text = textEncoder_.Encode(type);
413+
Node *node = AddType(text.text);
414+
graph::AddScalarFeature(node, "llvm_string", AddString(text.text));
415+
return node;
416+
}
417+
}
418+
419+
Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::StructType* type) {
420+
Node* node = AddType("struct");
421+
compositeTypeParts_[type] = node;
422+
graph::AddScalarFeature(node, "llvm_string",
423+
AddString(textEncoder_.Encode(type).text));
424+
425+
// Add types for the struct elements, and add type edges.
426+
for (int i = 0; i < type->getNumElements(); ++i) {
427+
const auto& member = type->elements()[i];
428+
// Re-use the type if it already exists to prevent duplication of member
429+
// types.
430+
auto memberNode = GetOrCreateType(member);
431+
CHECK(AddTypeEdge(/*position=*/i, memberNode, node).ok());
432+
}
433+
434+
return node;
435+
}
436+
437+
Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::PointerType* type) {
438+
Node* node = AddType("*");
439+
graph::AddScalarFeature(node, "llvm_string",
440+
AddString(textEncoder_.Encode(type).text));
441+
442+
auto elementType = type->getElementType();
443+
auto parent = compositeTypeParts_.find(elementType);
444+
if (parent == compositeTypeParts_.end()) {
445+
// Re-use the type if it already exists to prevent duplication.
446+
auto elementNode = GetOrCreateType(type->getElementType());
447+
CHECK(AddTypeEdge(/*position=*/0, elementNode, node).ok());
448+
} else {
449+
// Bottom-out for self-referencing types.
450+
CHECK(AddTypeEdge(/*position=*/0, node, parent->second).ok());
451+
}
452+
453+
return node;
454+
}
455+
456+
Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::FunctionType* type) {
457+
Node* node = AddType("fn");
458+
graph::AddScalarFeature(node, "llvm_string",
459+
AddString(textEncoder_.Encode(type).text));
460+
return node;
461+
}
462+
463+
Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::ArrayType* type) {
464+
Node* node = AddType("[]");
465+
graph::AddScalarFeature(node, "llvm_string",
466+
AddString(textEncoder_.Encode(type).text));
467+
// Re-use the type if it already exists to prevent duplication.
468+
auto elementType = GetOrCreateType(type->getElementType());
469+
CHECK(AddTypeEdge(/*position=*/0, elementType, node).ok());
470+
return node;
471+
}
472+
473+
Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::VectorType* type) {
474+
Node* node = AddType("vector");
475+
graph::AddScalarFeature(node, "llvm_string",
476+
AddString(textEncoder_.Encode(type).text));
477+
// Re-use the type if it already exists to prevent duplication.
478+
auto elementType = GetOrCreateType(type->getElementType());
479+
CHECK(AddTypeEdge(/*position=*/0, elementType, node).ok());
383480
return node;
384481
}
385482

programl/ir/llvm/internal/program_graph_builder.h

+32
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {
7171

7272
void Clear();
7373

74+
// Return the node representing a type. If no node already exists
75+
// for this type, a new node is created and added to the graph. In
76+
// the case of composite types, multiple new nodes may be added by
77+
// this call, and the root type returned.
78+
Node* GetOrCreateType(const ::llvm::Type* type);
79+
7480
protected:
7581
[[nodiscard]] labm8::StatusOr<FunctionEntryExits> VisitFunction(
7682
const ::llvm::Function& function, const Function* functionMessage);
@@ -90,6 +96,12 @@ class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {
9096
Node* AddLlvmVariable(const ::llvm::Argument* argument,
9197
const Function* function);
9298
Node* AddLlvmConstant(const ::llvm::Constant* constant);
99+
Node* AddLlvmType(const ::llvm::Type* type);
100+
Node* AddLlvmType(const ::llvm::StructType* type);
101+
Node* AddLlvmType(const ::llvm::PointerType* type);
102+
Node* AddLlvmType(const ::llvm::FunctionType* type);
103+
Node* AddLlvmType(const ::llvm::ArrayType* type);
104+
Node* AddLlvmType(const ::llvm::VectorType* type);
93105

94106
// Add a string to the strings list and return its position.
95107
//
@@ -120,6 +132,26 @@ class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {
120132
absl::flat_hash_map<string, int32_t> stringsListPositions_;
121133
// The underlying storage for the strings table.
122134
BytesList* stringsList_;
135+
136+
// A map from an LLVM type to the node message that represents it.
137+
absl::flat_hash_map<const ::llvm::Type*, Node*> types_;
138+
139+
// When adding a new type to the graph we need to know whether the type that
140+
// we are adding is part of a composite type that references itself. For
141+
// example:
142+
//
143+
// struct BinaryTree {
144+
// int data;
145+
// struct BinaryTree* left;
146+
// struct BinaryTree* right;
147+
// }
148+
//
149+
// When the recursive GetOrCreateType() resolves the "left" member, it needs
150+
// to know that the parent BinaryTree type has already been processed. This
151+
// map stores the Nodes corresponding to any parent structs that have been
152+
// already added in a call to GetOrCreateType(). It must be cleared between
153+
// calls.
154+
absl::flat_hash_map<const ::llvm::Type*, Node*> compositeTypeParts_;
123155
};
124156

125157
} // namespace internal

programl/ir/llvm/py/llvm_test.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ def test_simple_ir():
5656
assert len(proto.module) == 1
5757
assert proto.module[0].name == "foo.c"
5858

59-
assert len(proto.node) == 6
6059
assert proto.node[0].text == "<root>"
60+
assert len(proto.node) == 7
6161
assert proto.node[0].type == node_pb2.Node.INSTRUCTION
6262

6363
assert proto.node[1].text == "add"
@@ -68,20 +68,24 @@ def test_simple_ir():
6868
assert proto.node[2].type == node_pb2.Node.INSTRUCTION
6969
assert NodeFullText(proto, proto.node[2]) == "ret i32 %3"
7070

71-
assert proto.node[3].text == "i32"
71+
assert proto.node[3].text == "var"
7272
assert proto.node[3].type == node_pb2.Node.VARIABLE
7373
assert NodeFullText(proto, proto.node[3]) == "i32 %3"
7474

75-
# Use startswith() to compare names for these last two variables as thier
76-
# order may differ.
7775
assert proto.node[4].text == "i32"
78-
assert proto.node[4].type == node_pb2.Node.VARIABLE
79-
assert NodeFullText(proto, proto.node[4]).startswith("i32 %")
76+
assert proto.node[4].type == node_pb2.Node.TYPE
77+
assert NodeFullText(proto, proto.node[4]) == "i32"
8078

81-
assert proto.node[5].text == "i32"
79+
# Use startswith() to compare names for these last two variables as thier
80+
# order may differ.
81+
assert proto.node[5].text == "var"
8282
assert proto.node[5].type == node_pb2.Node.VARIABLE
8383
assert NodeFullText(proto, proto.node[5]).startswith("i32 %")
8484

85+
assert proto.node[6].text == "var"
86+
assert proto.node[6].type == node_pb2.Node.VARIABLE
87+
assert NodeFullText(proto, proto.node[6]).startswith("i32 %")
88+
8589

8690
def test_opt_level():
8791
"""Test equivalence of nodes that pre-process to the same text."""

0 commit comments

Comments
 (0)