Skip to content

Commit 9df26f8

Browse files
committed
Add types to the graph.
This adds a fourth node type, and a fourth edge flow, both called "type". The idea is to represent types as first-class elements in the graph representation. This allows greater compositionality by breaking up composite types into subcomponents, and decreases the required vocabulary size required to achieve a given coverage. Background ---------- Currently, type information is stored in the "text" field of nodes for constants and variables, e.g.: node { type: VARIABLE text: "i8" } There are two issues with this: * Composite types end up with long textual representations, e.g. "struct foo { i32 a; i32 b; ... }". Since there is an unbounded number of possible structs, this prevents 100% vocabulary coverage on any IR with structs (or other composite types). * In the future, we will want to encode different information on data nodes, such as embedding literal values. Moving the type information out of the data node "frees up" space for something else. Overview -------- This changes the representation to represent types as first-class elements in the graph. A "type" node represents a type using its "text" field, and a new "type" edge connects this type to variables or constants of that type, e.g. a variable "int x" could be represented as: node { type: VARIABLE text: "var" } node { type: TYPE text: "i32" } edge { flow: TYPE source: 1 } Composite types --------------- Types may be composed by connecting multiple type nodes using type edges. This allows you to break down complex types into a graph of primitive parts. The meaning of composite types will depend on the IR being targetted, the remainder describes the process for LLVM-IR. Pointer types ------------- A pointer is a composite of two types: [variable] <- [pointer] <- [pointed-type] For example: int32_t* instance; Would be represented as: node { type: TYPE text: "i32" } node { type: TYPE text: "*" } node { type: VARIABLE text: "var" } edge { text: TYPE target: 1 } edge { text: TYPE source: 1 target: 2 } Where variables/constants of this type receive an incoming type edge from the [pointer] node, which in turn receives an incoming type edge from the [pointed-type] node. One [pointer] node is generated for each unique pointer type. If a graph contains multiple pointer types, there will be multiple [pointer] nodes, one for each pointed type. Struct types ------------ A struct is a compsite type where each member is a node type which points to the parent node. Variable/constant instances of a struct receive an incoming type edge from the root struct node. Note that the graph of type nodes representing a composite struct type may be cyclical, since a struct can contain a pointer of the same type (think of a binary tree implementation). For all other member types, a new type node is produced. For example, a struct with two integer members will produce two integer type nodes, they are not shared. The type edges from member nodes to the parent struct are positional. The position indicates the element number. E.g. for a struct with three elements, the incoming type edges to the struct node will have positions 0, 1, and 2. This example struct: struct s { int8_t a; int8_t b; struct s* c; } struct s instance; Would be represented as: node { type: TYPE text: "struct" } node { type: TYPE text: "i8" } node { type: TYPE text: "i8" } node { type: TYPE text: "*" } node { type: VARIABLE text: "var" } edge { flow: TYPE target: 1 } edge { flow: TYPE target: 2 position: 1 } edge { flow: TYPE target: 3 position: 2 } edge { flow: TYPE source: 3 } edge { flow: TYPE target: 4 } Array Types ----------- An array is a composite type [variable] <- [array] <- [element-type]. For example, the array: int a[10]; Would be represented as: node { type: TYPE text: "i32" } node { type: TYPE text: "[]" } node { type: VARIABLE text: "var" } edge { flow: TYPE target: 1 } edge { flow: TYPE source: 1 target: 2 } Function Pointers ----------------- A function pointer is represented by a type node that uniquely identifies the *signature* of a function, i.e. its return type and parameter types. The caveat of this is that pointers to different functions which have the same signature will resolve to the same type node. Additionally, there is no edge connecting a function pointer type and the instructions which belong to this function. github.com//issues/82
1 parent b2ef0af commit 9df26f8

11 files changed

+235
-24
lines changed

programl/graph/format/graphviz_converter.cc

+18-8
Original file line numberDiff line numberDiff line change
@@ -169,27 +169,31 @@ class GraphVizSerializer {
169169
template <typename T>
170170
void SetVertexAttributes(const Node& node, T& attributes) {
171171
attributes["label"] = GetNodeLabel(node);
172+
attributes["style"] = "filled";
172173
switch (node.type()) {
173174
case Node::INSTRUCTION:
174175
attributes["shape"] = "box";
175-
attributes["style"] = "filled";
176176
attributes["fillcolor"] = "#3c78d8";
177177
attributes["fontcolor"] = "#ffffff";
178178
break;
179179
case Node::VARIABLE:
180180
attributes["shape"] = "ellipse";
181-
attributes["style"] = "filled";
182181
attributes["fillcolor"] = "#f4cccc";
183182
attributes["color"] = "#990000";
184183
attributes["fontcolor"] = "#990000";
185184
break;
186185
case Node::CONSTANT:
187-
attributes["shape"] = "diamond";
188-
attributes["style"] = "filled";
186+
attributes["shape"] = "octagon";
189187
attributes["fillcolor"] = "#e99c9c";
190188
attributes["color"] = "#990000";
191189
attributes["fontcolor"] = "#990000";
192190
break;
191+
case Node::TYPE:
192+
attributes["shape"] = "diamond";
193+
attributes["fillcolor"] = "#cccccc";
194+
attributes["color"] = "#cccccc";
195+
attributes["fontcolor"] = "#222222";
196+
break;
193197
}
194198
}
195199

@@ -201,7 +205,7 @@ class GraphVizSerializer {
201205
const Node& node = graph_.node(i);
202206
// Determine the subgraph to add this node to.
203207
boost::subgraph<GraphvizGraph>* dst = defaultGraph;
204-
if (i && node.type() != Node::CONSTANT) {
208+
if (i && (node.type() == Node::INSTRUCTION || node.type() == Node::VARIABLE)) {
205209
dst = &(*functionGraphs)[node.function()].get();
206210
}
207211
auto vertex = add_vertex(i, *dst);
@@ -226,14 +230,20 @@ class GraphVizSerializer {
226230
attributes["color"] = "#65ae4d";
227231
attributes["weight"] = "1";
228232
break;
233+
case Edge::TYPE:
234+
attributes["color"] = "#aaaaaa";
235+
attributes["weight"] = "1";
236+
attributes["penwidth"] = "1.5";
237+
break;
229238
}
230239

231240
// Set the edge label.
232241
if (edge.position()) {
233242
// Position labels for control edge are drawn close to the originating
234-
// instruction. For data edges, they are drawn closer to the consuming
235-
// instruction.
236-
const string label = edge.flow() == Edge::DATA ? "headlabel" : "taillabel";
243+
// instruction. For control edges, they are drawn close to the branching
244+
// instruction. For data and type edges, they are drawn close to the
245+
// consuming node.
246+
const string label = edge.flow() == Edge::CONTROL ? "taillabel" : "headlabel";
237247
attributes[label] = std::to_string(edge.position());
238248
attributes["labelfontcolor"] = attributes["color"];
239249
}

programl/graph/program_graph_builder.cc

+21
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ Node* ProgramGraphBuilder::AddVariable(const string& text, const Function* funct
6464

6565
Node* ProgramGraphBuilder::AddConstant(const string& text) { return AddNode(Node::CONSTANT, text); }
6666

67+
Node* ProgramGraphBuilder::AddType(const string& text) { return AddNode(Node::TYPE, text); }
68+
6769
labm8::StatusOr<Edge*> ProgramGraphBuilder::AddControlEdge(int32_t position, const Node* source,
6870
const Node* target) {
6971
DCHECK(source) << "nullptr argument";
@@ -131,6 +133,25 @@ labm8::StatusOr<Edge*> ProgramGraphBuilder::AddCallEdge(const Node* source, cons
131133
return AddEdge(Edge::CALL, /*position=*/0, source, target);
132134
}
133135

136+
labm8::StatusOr<Edge*> ProgramGraphBuilder::AddTypeEdge(int32_t position, const Node* source,
137+
const Node* target) {
138+
DCHECK(source) << "nullptr argument";
139+
DCHECK(target) << "nullptr argument";
140+
141+
if (source->type() != Node::TYPE) {
142+
return Status(labm8::error::Code::INVALID_ARGUMENT,
143+
"Invalid source type ({}) for type edge. Expected type",
144+
Node::Type_Name(source->type()));
145+
}
146+
if (target->type() == Node::INSTRUCTION) {
147+
return Status(labm8::error::Code::INVALID_ARGUMENT,
148+
"Invalid destination type (instruction) for type edge. "
149+
"Expected {variable,constant,type}");
150+
}
151+
152+
return AddEdge(Edge::TYPE, position, source, target);
153+
}
154+
134155
labm8::StatusOr<ProgramGraph> ProgramGraphBuilder::Build() {
135156
if (options().strict()) {
136157
RETURN_IF_ERROR(ValidateGraph());

programl/graph/program_graph_builder.h

+6-1
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ class ProgramGraphBuilder {
6464

6565
Node* AddConstant(const string& text);
6666

67+
Node* AddType(const string& text);
68+
6769
// Edge factories.
6870
[[nodiscard]] labm8::StatusOr<Edge*> AddControlEdge(int32_t position, const Node* source,
6971
const Node* target);
@@ -73,6 +75,9 @@ class ProgramGraphBuilder {
7375

7476
[[nodiscard]] labm8::StatusOr<Edge*> AddCallEdge(const Node* source, const Node* target);
7577

78+
[[nodiscard]] labm8::StatusOr<Edge*> AddTypeEdge(int32_t position, const Node* source,
79+
const Node* target);
80+
7681
const Node* GetRootNode() const { return &graph_.node(0); }
7782

7883
// Return the graph protocol buffer.
@@ -119,7 +124,7 @@ class ProgramGraphBuilder {
119124
int32_t GetIndex(const Function* function);
120125
int32_t GetIndex(const Node* node);
121126

122-
// Maps which covert store the index of objects in repeated field lists.
127+
// Maps that store the index of objects in repeated field lists.
123128
absl::flat_hash_map<Module*, int32_t> moduleIndices_;
124129
absl::flat_hash_map<Function*, int32_t> functionIndices_;
125130
absl::flat_hash_map<Node*, int32_t> nodeIndices_;

programl/ir/llvm/inst2vec_encoder.py

+7
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def Encode(
112112
# Add the node features.
113113
var_embedding = self.dictionary["!IDENTIFIER"]
114114
const_embedding = self.dictionary["!IMMEDIATE"]
115+
type_embedding = self.dictionary["!IMMEDIATE"] # Types are immediates
115116

116117
text_index = 0
117118
for node in proto.node:
@@ -133,6 +134,12 @@ def Encode(
133134
node.features.feature["inst2vec_embedding"].int64_list.value.append(
134135
const_embedding
135136
)
137+
elif node.type == node_pb2.Node.TYPE:
138+
node.features.feature["inst2vec_embedding"].int64_list.value.append(
139+
type_embedding
140+
)
141+
else:
142+
raise TypeError(f"Unknown node type {node}")
136143

137144
proto.features.feature["inst2vec_annotated"].int64_list.value.append(1)
138145
return proto

programl/ir/llvm/internal/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ cc_library(
4141
"//programl/proto:program_graph_options_cc",
4242
"@com_google_absl//absl/container:flat_hash_map",
4343
"@com_google_absl//absl/container:flat_hash_set",
44+
"@labm8//labm8/cpp:logging",
4445
"@labm8//labm8/cpp:status_macros",
4546
"@labm8//labm8/cpp:statusor",
4647
"@labm8//labm8/cpp:string",

programl/ir/llvm/internal/program_graph_builder.cc

+116-7
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#include "absl/container/flat_hash_map.h"
2222
#include "absl/container/flat_hash_set.h"
23+
#include "labm8/cpp/logging.h"
2324
#include "labm8/cpp/status_macros.h"
2425
#include "labm8/cpp/string.h"
2526
#include "llvm/IR/BasicBlock.h"
@@ -39,12 +40,18 @@ namespace ir {
3940
namespace llvm {
4041
namespace internal {
4142

43+
namespace {
44+
45+
BytesList* getStringsList(ProgramGraph* programGraph) {
46+
return (*programGraph->mutable_features()->mutable_feature())["strings"].mutable_bytes_list();
47+
}
48+
49+
} // anonymous namespace
50+
4251
ProgramGraphBuilder::ProgramGraphBuilder(const ProgramGraphOptions& options)
43-
: programl::graph::ProgramGraphBuilder(),
44-
options_(options),
52+
: programl::graph::ProgramGraphBuilder(options),
4553
blockCount_(0),
46-
stringsList_((*GetMutableProgramGraph()->mutable_features()->mutable_feature())["strings"]
47-
.mutable_bytes_list()) {
54+
stringsList_(getStringsList(GetMutableProgramGraph())) {
4855
// Add an empty
4956
graph::AddScalarFeature(GetMutableRootNode(), "llvm_string", AddString(""));
5057
}
@@ -335,29 +342,131 @@ Node* ProgramGraphBuilder::AddLlvmInstruction(const ::llvm::Instruction* instruc
335342
Node* ProgramGraphBuilder::AddLlvmVariable(const ::llvm::Instruction* operand,
336343
const programl::Function* function) {
337344
const LlvmTextComponents text = textEncoder_.Encode(operand);
338-
Node* node = AddVariable(text.lhs_type, function);
345+
Node* node = AddVariable("var", function);
339346
node->set_block(blockCount_);
340347
graph::AddScalarFeature(node, "llvm_string", AddString(text.lhs));
341348

349+
compositeTypeParts_.clear(); // Reset after previous call.
350+
Node* type = GetOrCreateType(operand->getType());
351+
CHECK(AddTypeEdge(/*position=*/0, type, node).ok());
352+
342353
return node;
343354
}
344355

345356
Node* ProgramGraphBuilder::AddLlvmVariable(const ::llvm::Argument* argument,
346357
const programl::Function* function) {
347358
const LlvmTextComponents text = textEncoder_.Encode(argument);
348-
Node* node = AddVariable(text.lhs_type, function);
359+
Node* node = AddVariable("var", function);
349360
node->set_block(blockCount_);
350361
graph::AddScalarFeature(node, "llvm_string", AddString(text.lhs));
351362

363+
compositeTypeParts_.clear(); // Reset after previous call.
364+
Node* type = GetOrCreateType(argument->getType());
365+
CHECK(AddTypeEdge(/*position=*/0, type, node).ok());
366+
352367
return node;
353368
}
354369

355370
Node* ProgramGraphBuilder::AddLlvmConstant(const ::llvm::Constant* constant) {
356371
const LlvmTextComponents text = textEncoder_.Encode(constant);
357-
Node* node = AddConstant(text.lhs_type);
372+
Node* node = AddConstant("val");
358373
node->set_block(blockCount_);
359374
graph::AddScalarFeature(node, "llvm_string", AddString(text.text));
360375

376+
compositeTypeParts_.clear(); // Reset after previous call.
377+
Node* type = GetOrCreateType(constant->getType());
378+
CHECK(AddTypeEdge(/*position=*/0, type, node).ok());
379+
380+
return node;
381+
}
382+
383+
Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::Type* type) {
384+
// Dispatch to the type-specific handlers.
385+
if (::llvm::dyn_cast<::llvm::StructType>(type)) {
386+
return AddLlvmType(::llvm::dyn_cast<::llvm::StructType>(type));
387+
} else if (::llvm::dyn_cast<::llvm::PointerType>(type)) {
388+
return AddLlvmType(::llvm::dyn_cast<::llvm::PointerType>(type));
389+
} else if (::llvm::dyn_cast<::llvm::FunctionType>(type)) {
390+
return AddLlvmType(::llvm::dyn_cast<::llvm::FunctionType>(type));
391+
} else if (::llvm::dyn_cast<::llvm::ArrayType>(type)) {
392+
return AddLlvmType(::llvm::dyn_cast<::llvm::ArrayType>(type));
393+
} else if (::llvm::dyn_cast<::llvm::VectorType>(type)) {
394+
return AddLlvmType(::llvm::dyn_cast<::llvm::VectorType>(type));
395+
} else {
396+
const LlvmTextComponents text = textEncoder_.Encode(type);
397+
Node* node = AddType(text.text);
398+
graph::AddScalarFeature(node, "llvm_string", AddString(text.text));
399+
return node;
400+
}
401+
}
402+
403+
Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::StructType* type) {
404+
Node* node = AddType("struct");
405+
compositeTypeParts_[type] = node;
406+
graph::AddScalarFeature(node, "llvm_string", AddString(textEncoder_.Encode(type).text));
407+
408+
// Add types for the struct elements, and add type edges.
409+
for (int i = 0; i < type->getNumElements(); ++i) {
410+
const auto& member = type->elements()[i];
411+
// Don't re-use member types in structs, always create a new type. For
412+
// example, the code:
413+
//
414+
// struct S {
415+
// int a;
416+
// int b;
417+
// };
418+
// int c;
419+
// int d;
420+
//
421+
// would produce four type nodes: one for S.a, one for S.b, and one which
422+
// is shared by c and d.
423+
Node* memberNode = AddLlvmType(member);
424+
CHECK(AddTypeEdge(/*position=*/i, memberNode, node).ok());
425+
}
426+
427+
return node;
428+
}
429+
430+
Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::PointerType* type) {
431+
Node* node = AddType("*");
432+
graph::AddScalarFeature(node, "llvm_string", AddString(textEncoder_.Encode(type).text));
433+
434+
auto elementType = type->getElementType();
435+
auto parent = compositeTypeParts_.find(elementType);
436+
if (parent == compositeTypeParts_.end()) {
437+
// Re-use the type if it already exists to prevent duplication.
438+
auto elementNode = GetOrCreateType(type->getElementType());
439+
CHECK(AddTypeEdge(/*position=*/0, elementNode, node).ok());
440+
} else {
441+
// Bottom-out for self-referencing types.
442+
CHECK(AddTypeEdge(/*position=*/0, parent->second, node).ok());
443+
}
444+
445+
return node;
446+
}
447+
448+
Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::FunctionType* type) {
449+
const std::string signature = textEncoder_.Encode(type).text;
450+
Node* node = AddType(signature);
451+
graph::AddScalarFeature(node, "llvm_string", AddString(signature));
452+
return node;
453+
}
454+
455+
Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::ArrayType* type) {
456+
Node* node = AddType("[]");
457+
graph::AddScalarFeature(node, "llvm_string", AddString(textEncoder_.Encode(type).text));
458+
// Re-use the type if it already exists to prevent duplication.
459+
auto elementType = GetOrCreateType(type->getElementType());
460+
CHECK(AddTypeEdge(/*position=*/0, elementType, node).ok());
461+
return node;
462+
}
463+
464+
Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::VectorType* type) {
465+
Node* node = AddType("vector");
466+
graph::AddScalarFeature(node, "llvm_string", AddString(textEncoder_.Encode(type).text));
467+
// Re-use the type if it already exists to prevent duplication.
468+
auto elementType = GetOrCreateType(type->getElementType());
469+
CHECK(AddTypeEdge(/*position=*/0, elementType, node).ok());
361470
return node;
362471
}
363472

programl/ir/llvm/internal/program_graph_builder.h

+32
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {
7070

7171
void Clear();
7272

73+
// Return the node representing a type. If no node already exists
74+
// for this type, a new node is created and added to the graph. In
75+
// the case of composite types, multiple new nodes may be added by
76+
// this call, and the root type returned.
77+
Node* GetOrCreateType(const ::llvm::Type* type);
78+
7379
protected:
7480
[[nodiscard]] labm8::StatusOr<FunctionEntryExits> VisitFunction(const ::llvm::Function& function,
7581
const Function* functionMessage);
@@ -85,6 +91,12 @@ class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {
8591
Node* AddLlvmVariable(const ::llvm::Instruction* operand, const Function* function);
8692
Node* AddLlvmVariable(const ::llvm::Argument* argument, const Function* function);
8793
Node* AddLlvmConstant(const ::llvm::Constant* constant);
94+
Node* AddLlvmType(const ::llvm::Type* type);
95+
Node* AddLlvmType(const ::llvm::StructType* type);
96+
Node* AddLlvmType(const ::llvm::PointerType* type);
97+
Node* AddLlvmType(const ::llvm::FunctionType* type);
98+
Node* AddLlvmType(const ::llvm::ArrayType* type);
99+
Node* AddLlvmType(const ::llvm::VectorType* type);
88100

89101
// Add a string to the strings list and return its position.
90102
//
@@ -112,6 +124,26 @@ class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {
112124
absl::flat_hash_map<string, int32_t> stringsListPositions_;
113125
// The underlying storage for the strings table.
114126
BytesList* stringsList_;
127+
128+
// A map from an LLVM type to the node message that represents it.
129+
absl::flat_hash_map<const ::llvm::Type*, Node*> types_;
130+
131+
// When adding a new type to the graph we need to know whether the type that
132+
// we are adding is part of a composite type that references itself. For
133+
// example:
134+
//
135+
// struct BinaryTree {
136+
// int data;
137+
// struct BinaryTree* left;
138+
// struct BinaryTree* right;
139+
// }
140+
//
141+
// When the recursive GetOrCreateType() resolves the "left" member, it needs
142+
// to know that the parent BinaryTree type has already been processed. This
143+
// map stores the Nodes corresponding to any parent structs that have been
144+
// already added in a call to GetOrCreateType(). It must be cleared between
145+
// calls.
146+
absl::flat_hash_map<const ::llvm::Type*, Node*> compositeTypeParts_;
115147
};
116148

117149
} // namespace internal

programl/ir/llvm/py/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ py_test(
3939
srcs = ["llvm_test.py"],
4040
deps = [
4141
":llvm",
42+
"//programl/proto:edge_py",
4243
"//programl/proto:node_py",
4344
"//programl/proto:program_graph_options_py",
4445
"//programl/proto:program_graph_py",

0 commit comments

Comments
 (0)