Commit 16aa6c7

Add types to the graph.
This adds a fourth node type, and a fourth edge flow, both called "type". The idea is to represent types as first-class elements in the graph representation. This allows greater compositionality by breaking up composite types into subcomponents, and decreases the vocabulary size required to achieve a given coverage.

Background
----------

Currently, type information is stored in the "text" field of nodes for constants and variables, e.g.:

    node { type: VARIABLE text: "i8" }

There are two issues with this:

* Composite types end up with long textual representations, e.g. "struct foo { i32 a; i32 b; ... }". Since there is an unbounded number of possible structs, this prevents 100% vocabulary coverage on any IR with structs (or other composite types).

* In the future, we will want to encode different information on data nodes, such as embedding literal values. Moving the type information out of the data node "frees up" space for something else.

Overview
--------

This changes the representation to represent types as first-class elements in the graph. A "type" node represents a type using its "text" field, and a new "type" edge connects this type to variables or constants of that type, e.g. a variable "int x" could be represented as:

    node { type: VARIABLE text: "var" }
    node { type: TYPE text: "i32" }
    edge { flow: TYPE source: 1 }

Composite types
---------------

Types may be composed by connecting multiple type nodes using type edges. This allows you to break down complex types into a graph of primitive parts. The meaning of composite types will depend on the IR being targeted; the remainder describes the process for LLVM-IR.

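As a rough illustration of the overview above, the following Python sketch builds the "int x" example. The `Node`, `Edge`, and `Graph` dataclasses are hypothetical stand-ins for the protocol buffer messages, not the ProGraML API; only the field names (`type`, `text`, `flow`, `source`, `target`, `position`) are taken from the commit message.

```python
from dataclasses import dataclass, field
from typing import List, Set


# Hypothetical stand-ins for the Node and Edge messages. Unset integer
# fields default to 0, as they do in the examples in the commit message.
@dataclass
class Node:
    type: str
    text: str


@dataclass
class Edge:
    flow: str
    source: int = 0
    target: int = 0
    position: int = 0


@dataclass
class Graph:
    node: List[Node] = field(default_factory=list)
    edge: List[Edge] = field(default_factory=list)

    def add_node(self, type_: str, text: str) -> int:
        """Append a node and return its index in the node list."""
        self.node.append(Node(type_, text))
        return len(self.node) - 1


def int_x_graph() -> Graph:
    """Build the graph for a variable declared as `int x`."""
    graph = Graph()
    var = graph.add_node("VARIABLE", "var")  # index 0
    i32 = graph.add_node("TYPE", "i32")  # index 1
    # The type edge runs from the type node to the variable it describes.
    graph.edge.append(Edge("TYPE", source=i32, target=var))
    return graph


def vocabulary(graph: Graph) -> Set[str]:
    """The set of distinct node texts that a model would need to embed."""
    return {node.text for node in graph.node}
```

With this layout the variable node always carries the text "var", so the vocabulary grows with the number of distinct type parts and not with the number of distinct composite types.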
Pointer types
-------------

A pointer is a composite of two types:

    [variable] <- [pointer] <- [pointed-type]

For example:

    int32_t* instance;

Would be represented as:

    node { type: TYPE text: "i32" }
    node { type: TYPE text: "*" }
    node { type: VARIABLE text: "var" }
    edge { flow: TYPE target: 1 }
    edge { flow: TYPE source: 1 target: 2 }

Where variables/constants of this type receive an incoming type edge from the [pointer] node, which in turn receives an incoming type edge from the [pointed-type] node.

One [pointer] node is generated for each unique pointer type. If a graph contains multiple pointer types, there will be multiple [pointer] nodes, one for each pointed type.

Struct types
------------

A struct is a composite type where each member is a type node that points to the parent node. Variable/constant instances of a struct receive an incoming type edge from the root struct node.

Note that the graph of type nodes representing a composite struct type may be cyclical, since a struct can contain a pointer of the same type (think of a binary tree implementation). For all other member types, a new type node is produced. For example, a struct with two integer members will produce two integer type nodes; they are not shared.

The type edges from member nodes to the parent struct are positional. The position indicates the element number. E.g. for a struct with three elements, the incoming type edges to the struct node will have positions 0, 1, and 2.

This example struct:

    struct s {
      int8_t a;
      int8_t b;
      struct s* c;
    }
    struct s instance;

Would be represented as:

    node { type: TYPE text: "struct" }
    node { type: TYPE text: "i8" }
    node { type: TYPE text: "i8" }
    node { type: TYPE text: "*" }
    node { type: VARIABLE text: "var" }
    edge { flow: TYPE source: 1 }
    edge { flow: TYPE source: 2 position: 1 }
    edge { flow: TYPE source: 3 position: 2 }
    edge { flow: TYPE target: 3 }
    edge { flow: TYPE target: 4 }

Array Types
-----------

An array is a composite type [variable] <- [array] <- [element-type]. For example, the array:

    int a[10];

Would be represented as:

    node { type: TYPE text: "i32" }
    node { type: TYPE text: "[]" }
    node { type: VARIABLE text: "var" }
    edge { flow: TYPE target: 1 }
    edge { flow: TYPE source: 1 target: 2 }

Function Pointers
-----------------

A function pointer is represented by a type node that uniquely identifies the *signature* of a function, i.e. its return type and parameter types. The caveat of this is that pointers to different functions which have the same signature will resolve to the same type node. Additionally, there is no edge connecting a function pointer type and the instructions which belong to this function.

github.com//issues/82
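The positional and cyclical properties of the struct example can be checked with a short Python sketch. The node list and `(source, target, position)` tuples below are a hand-written encoding of `struct s`, with member edges pointing at the struct node as the Struct types section describes; the helper names are hypothetical.

```python
from typing import List, Tuple

TypeEdge = Tuple[int, int, int]  # (source, target, position)

# Node texts for: struct s { int8_t a; int8_t b; struct s* c; } instance;
NODES: List[str] = ["struct", "i8", "i8", "*", "var"]

TYPE_EDGES: List[TypeEdge] = [
    (1, 0, 0),  # member a -> struct
    (2, 0, 1),  # member b -> struct
    (3, 0, 2),  # member c (pointer) -> struct
    (0, 3, 0),  # struct -> pointer: the pointed-to type is the struct itself
    (0, 4, 0),  # struct -> variable instance
]


def struct_members(struct_index: int, nodes: List[str],
                   edges: List[TypeEdge]) -> List[str]:
    """Return member type texts ordered by the position of their edge."""
    incoming = [(pos, src) for src, dst, pos in edges if dst == struct_index]
    return [nodes[src] for _, src in sorted(incoming)]


def is_self_referencing(struct_index: int, edges: List[TypeEdge]) -> bool:
    """True if following type edges from the struct leads back to it."""
    seen = set()
    stack = [struct_index]
    while stack:
        current = stack.pop()
        for src, dst, _ in edges:
            if src != current:
                continue
            if dst == struct_index:
                return True
            if dst not in seen:
                seen.add(dst)
                stack.append(dst)
    return False
```

Sorting by position is what lets a consumer tell the two `i8` members apart, since their node texts are identical.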
1 parent b785a1d commit 16aa6c7

7 files changed, +204 -12 lines changed

programl/graph/format/graphviz_converter.cc (+18 -8)

@@ -170,27 +170,31 @@ class GraphVizSerializer {
   template <typename T>
   void SetVertexAttributes(const Node& node, T& attributes) {
     attributes["label"] = GetNodeLabel(node);
+    attributes["style"] = "filled";
     switch (node.type()) {
       case Node::INSTRUCTION:
         attributes["shape"] = "box";
-        attributes["style"] = "filled";
         attributes["fillcolor"] = "#3c78d8";
         attributes["fontcolor"] = "#ffffff";
         break;
       case Node::VARIABLE:
         attributes["shape"] = "ellipse";
-        attributes["style"] = "filled";
         attributes["fillcolor"] = "#f4cccc";
         attributes["color"] = "#990000";
         attributes["fontcolor"] = "#990000";
         break;
       case Node::CONSTANT:
-        attributes["shape"] = "diamond";
-        attributes["style"] = "filled";
+        attributes["shape"] = "octagon";
         attributes["fillcolor"] = "#e99c9c";
         attributes["color"] = "#990000";
         attributes["fontcolor"] = "#990000";
         break;
+      case Node::TYPE:
+        attributes["shape"] = "diamond";
+        attributes["fillcolor"] = "#cccccc";
+        attributes["color"] = "#cccccc";
+        attributes["fontcolor"] = "#222222";
+        break;
       default:
         LOG(FATAL) << "unreachable";
     }
@@ -204,7 +208,7 @@ class GraphVizSerializer {
       const Node& node = graph_.node(i);
       // Determine the subgraph to add this node to.
       boost::subgraph<GraphvizGraph>* dst = defaultGraph;
-      if (i && node.type() != Node::CONSTANT) {
+      if (i && (node.type() == Node::INSTRUCTION || node.type() == Node::VARIABLE)) {
         dst = &(*functionGraphs)[node.function()].get();
       }
       auto vertex = add_vertex(i, *dst);
@@ -229,16 +233,22 @@ class GraphVizSerializer {
         attributes["color"] = "#65ae4d";
         attributes["weight"] = "1";
         break;
+      case Edge::TYPE:
+        attributes["color"] = "#aaaaaa";
+        attributes["weight"] = "1";
+        attributes["penwidth"] = "1.5";
+        break;
       default:
         LOG(FATAL) << "unreachable";
     }
 
     // Set the edge label.
     if (edge.position()) {
       // Position labels for control edge are drawn close to the originating
-      // instruction. For data edges, they are drawn closer to the consuming
-      // instruction.
-      const string label = edge.flow() == Edge::DATA ? "headlabel" : "taillabel";
+      // instruction. For control edges, they are drawn close to the branching
+      // instruction. For data and type edges, they are drawn close to the
+      // consuming node.
+      const string label = edge.flow() == Edge::CONTROL ? "taillabel" : "headlabel";
       attributes[label] = std::to_string(edge.position());
       attributes["labelfontcolor"] = attributes["color"];
     }

programl/graph/program_graph_builder.cc (+21)

@@ -64,6 +64,8 @@ Node* ProgramGraphBuilder::AddVariable(const string& text, const Function* funct
 
 Node* ProgramGraphBuilder::AddConstant(const string& text) { return AddNode(Node::CONSTANT, text); }
 
+Node* ProgramGraphBuilder::AddType(const string& text) { return AddNode(Node::TYPE, text); }
+
 labm8::StatusOr<Edge*> ProgramGraphBuilder::AddControlEdge(int32_t position, const Node* source,
                                                            const Node* target) {
   DCHECK(source) << "nullptr argument";
@@ -131,6 +133,25 @@ labm8::StatusOr<Edge*> ProgramGraphBuilder::AddCallEdge(const Node* source, cons
   return AddEdge(Edge::CALL, /*position=*/0, source, target);
 }
 
+labm8::StatusOr<Edge*> ProgramGraphBuilder::AddTypeEdge(int32_t position, const Node* source,
+                                                        const Node* target) {
+  DCHECK(source) << "nullptr argument";
+  DCHECK(target) << "nullptr argument";
+
+  if (source->type() != Node::TYPE) {
+    return Status(labm8::error::Code::INVALID_ARGUMENT,
+                  "Invalid source type ({}) for type edge. Expected type",
+                  Node::Type_Name(source->type()));
+  }
+  if (target->type() == Node::INSTRUCTION) {
+    return Status(labm8::error::Code::INVALID_ARGUMENT,
+                  "Invalid destination type (instruction) for type edge. "
+                  "Expected {variable,constant,type}");
+  }
+
+  return AddEdge(Edge::TYPE, position, source, target);
+}
+
 labm8::StatusOr<ProgramGraph> ProgramGraphBuilder::Build() {
   if (options().strict()) {
     RETURN_IF_ERROR(ValidateGraph());
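
The endpoint rule enforced by AddTypeEdge() above can be restated as a small Python sketch: the source must be a type node, and the target may be anything except an instruction. The `NodeType` enum and `check_type_edge` function are hypothetical names used for illustration, and raising `ValueError` stands in for returning an INVALID_ARGUMENT status.

```python
from enum import Enum


class NodeType(Enum):
    # Values follow the Node.Type enum in program_graph.proto; the value
    # of INSTRUCTION is an assumption, since the diff does not show it.
    INSTRUCTION = 0
    VARIABLE = 1
    CONSTANT = 2
    TYPE = 3


def check_type_edge(source: NodeType, target: NodeType) -> None:
    """Raise ValueError if a type edge may not connect source to target."""
    if source is not NodeType.TYPE:
        raise ValueError(
            f"Invalid source type ({source.name}) for type edge. Expected type"
        )
    if target is NodeType.INSTRUCTION:
        raise ValueError(
            "Invalid destination type (instruction) for type edge. "
            "Expected {variable,constant,type}"
        )
```

Allowing a type node as the target is what makes composite types possible, since pointer, array, and struct nodes all receive type edges from their component types.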

programl/graph/program_graph_builder.h (+6 -1)

@@ -64,6 +64,8 @@ class ProgramGraphBuilder {
 
   Node* AddConstant(const string& text);
 
+  Node* AddType(const string& text);
+
   // Edge factories.
   [[nodiscard]] labm8::StatusOr<Edge*> AddControlEdge(int32_t position, const Node* source,
                                                       const Node* target);
@@ -73,6 +75,9 @@ class ProgramGraphBuilder {
 
   [[nodiscard]] labm8::StatusOr<Edge*> AddCallEdge(const Node* source, const Node* target);
 
+  [[nodiscard]] labm8::StatusOr<Edge*> AddTypeEdge(int32_t position, const Node* source,
+                                                   const Node* target);
+
   const Node* GetRootNode() const { return &graph_.node(0); }
 
   // Return the graph protocol buffer.
@@ -116,7 +121,7 @@ class ProgramGraphBuilder {
   int32_t GetIndex(const Function* function);
   int32_t GetIndex(const Node* node);
 
-  // Maps which covert store the index of objects in repeated field lists.
+  // Maps that store the index of objects in repeated field lists.
   absl::flat_hash_map<Module*, int32_t> moduleIndices_;
   absl::flat_hash_map<Function*, int32_t> functionIndices_;
   absl::flat_hash_map<Node*, int32_t> nodeIndices_;

programl/ir/llvm/inst2vec_encoder.py (+7)

@@ -92,6 +92,7 @@ def Encode(self, proto: ProgramGraph, ir: Optional[str] = None) -> ProgramGraph:
         # Add the node features.
         var_embedding = self.dictionary["!IDENTIFIER"]
         const_embedding = self.dictionary["!IMMEDIATE"]
+        type_embedding = self.dictionary["!IMMEDIATE"]  # Types are immediates
 
         text_index = 0
         for node in proto.node:
@@ -113,6 +114,12 @@ def Encode(self, proto: ProgramGraph, ir: Optional[str] = None) -> ProgramGraph:
                 node.features.feature["inst2vec_embedding"].int64_list.value.append(
                     const_embedding
                 )
+            elif node.type == node_pb2.Node.TYPE:
+                node.features.feature["inst2vec_embedding"].int64_list.value.append(
+                    type_embedding
+                )
+            else:
+                raise TypeError(f"Unknown node type {node}")
 
         proto.features.feature["inst2vec_annotated"].int64_list.value.append(1)
         return proto

programl/ir/llvm/internal/program_graph_builder.cc (+116 -3)

@@ -20,6 +20,7 @@
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
+#include "labm8/cpp/logging.h"
 #include "labm8/cpp/status_macros.h"
 #include "labm8/cpp/string.h"
 #include "llvm/IR/BasicBlock.h"
@@ -323,29 +324,131 @@ Node* ProgramGraphBuilder::AddLlvmInstruction(const ::llvm::Instruction* instruc
 Node* ProgramGraphBuilder::AddLlvmVariable(const ::llvm::Instruction* operand,
                                            const programl::Function* function) {
   const LlvmTextComponents text = textEncoder_.Encode(operand);
-  Node* node = AddVariable(text.lhs_type, function);
+  Node* node = AddVariable("var", function);
   node->set_block(blockCount_);
   graph::AddScalarFeature(node, "full_text", text.lhs);
 
+  compositeTypeParts_.clear();  // Reset after previous call.
+  Node* type = GetOrCreateType(operand->getType());
+  CHECK(AddTypeEdge(/*position=*/0, type, node).ok());
+
   return node;
 }
 
 Node* ProgramGraphBuilder::AddLlvmVariable(const ::llvm::Argument* argument,
                                            const programl::Function* function) {
   const LlvmTextComponents text = textEncoder_.Encode(argument);
-  Node* node = AddVariable(text.lhs_type, function);
+  Node* node = AddVariable("var", function);
   node->set_block(blockCount_);
   graph::AddScalarFeature(node, "full_text", text.lhs);
 
+  compositeTypeParts_.clear();  // Reset after previous call.
+  Node* type = GetOrCreateType(argument->getType());
+  CHECK(AddTypeEdge(/*position=*/0, type, node).ok());
+
   return node;
 }
 
 Node* ProgramGraphBuilder::AddLlvmConstant(const ::llvm::Constant* constant) {
   const LlvmTextComponents text = textEncoder_.Encode(constant);
-  Node* node = AddConstant(text.lhs_type);
+  Node* node = AddConstant("val");
   node->set_block(blockCount_);
   graph::AddScalarFeature(node, "full_text", text.text);
 
+  compositeTypeParts_.clear();  // Reset after previous call.
+  Node* type = GetOrCreateType(constant->getType());
+  CHECK(AddTypeEdge(/*position=*/0, type, node).ok());
+
+  return node;
+}
+
+Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::Type* type) {
+  // Dispatch to the type-specific handlers.
+  if (::llvm::dyn_cast<::llvm::StructType>(type)) {
+    return AddLlvmType(::llvm::dyn_cast<::llvm::StructType>(type));
+  } else if (::llvm::dyn_cast<::llvm::PointerType>(type)) {
+    return AddLlvmType(::llvm::dyn_cast<::llvm::PointerType>(type));
+  } else if (::llvm::dyn_cast<::llvm::FunctionType>(type)) {
+    return AddLlvmType(::llvm::dyn_cast<::llvm::FunctionType>(type));
+  } else if (::llvm::dyn_cast<::llvm::ArrayType>(type)) {
+    return AddLlvmType(::llvm::dyn_cast<::llvm::ArrayType>(type));
+  } else if (::llvm::dyn_cast<::llvm::VectorType>(type)) {
+    return AddLlvmType(::llvm::dyn_cast<::llvm::VectorType>(type));
+  } else {
+    const LlvmTextComponents text = textEncoder_.Encode(type);
+    Node* node = AddType(text.text);
+    graph::AddScalarFeature(node, "llvm_string", text.text);
+    return node;
+  }
+}
+
+Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::StructType* type) {
+  Node* node = AddType("struct");
+  compositeTypeParts_[type] = node;
+  graph::AddScalarFeature(node, "llvm_string", textEncoder_.Encode(type).text);
+
+  // Add types for the struct elements, and add type edges.
+  for (int i = 0; i < type->getNumElements(); ++i) {
+    const auto& member = type->elements()[i];
+    // Don't re-use member types in structs, always create a new type. For
+    // example, the code:
+    //
+    //     struct S {
+    //       int a;
+    //       int b;
+    //     };
+    //     int c;
+    //     int d;
+    //
+    // would produce four type nodes: one for S.a, one for S.b, and one which
+    // is shared by c and d.
+    Node* memberNode = AddLlvmType(member);
+    CHECK(AddTypeEdge(/*position=*/i, memberNode, node).ok());
+  }
+
+  return node;
+}
+
+Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::PointerType* type) {
+  Node* node = AddType("*");
+  graph::AddScalarFeature(node, "llvm_string", textEncoder_.Encode(type).text);
+
+  auto elementType = type->getElementType();
+  auto parent = compositeTypeParts_.find(elementType);
+  if (parent == compositeTypeParts_.end()) {
+    // Re-use the type if it already exists to prevent duplication.
+    auto elementNode = GetOrCreateType(type->getElementType());
+    CHECK(AddTypeEdge(/*position=*/0, elementNode, node).ok());
+  } else {
+    // Bottom-out for self-referencing types.
+    CHECK(AddTypeEdge(/*position=*/0, parent->second, node).ok());
+  }
+
+  return node;
+}
+
+Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::FunctionType* type) {
+  const std::string signature = textEncoder_.Encode(type).text;
+  Node* node = AddType(signature);
+  graph::AddScalarFeature(node, "llvm_string", signature);
+  return node;
+}
+
+Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::ArrayType* type) {
+  Node* node = AddType("[]");
+  graph::AddScalarFeature(node, "llvm_string", textEncoder_.Encode(type).text);
+  // Re-use the type if it already exists to prevent duplication.
+  auto elementType = GetOrCreateType(type->getElementType());
+  CHECK(AddTypeEdge(/*position=*/0, elementType, node).ok());
+  return node;
+}
+
+Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::VectorType* type) {
+  Node* node = AddType("vector");
+  graph::AddScalarFeature(node, "llvm_string", textEncoder_.Encode(type).text);
+  // Re-use the type if it already exists to prevent duplication.
+  auto elementType = GetOrCreateType(type->getElementType());
+  CHECK(AddTypeEdge(/*position=*/0, elementType, node).ok());
   return node;
 }
 
@@ -461,6 +564,16 @@ void ProgramGraphBuilder::Clear() {
   programl::graph::ProgramGraphBuilder::Clear();
 }
 
+Node* ProgramGraphBuilder::GetOrCreateType(const ::llvm::Type* type) {
+  auto it = types_.find(type);
+  if (it == types_.end()) {
+    Node* node = AddLlvmType(type);
+    types_[type] = node;
+    return node;
+  }
+  return it->second;
+}
+
 }  // namespace internal
 }  // namespace llvm
 }  // namespace ir
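
The interplay between GetOrCreateType() and the struct and pointer handlers above can be sketched in Python. This is a simplified model under stated assumptions: `ToyType` is a hypothetical stand-in for `::llvm::Type` (compared by identity, as LLVM types are uniqued), only integer, pointer, and struct kinds are modelled, and nodes are plain strings in a list.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass(eq=False)  # eq=False keeps identity-based hashing
class ToyType:
    kind: str  # "int", "pointer", or "struct"
    name: str = ""
    elements: List["ToyType"] = field(default_factory=list)


class TypeGraphBuilder:
    def __init__(self) -> None:
        self.nodes: List[str] = []
        self.edges: List[Tuple[int, int, int]] = []  # (source, target, position)
        self.types: Dict[ToyType, int] = {}  # plays the role of types_
        self.composite_parts: Dict[ToyType, int] = {}  # compositeTypeParts_

    def _add_node(self, text: str) -> int:
        self.nodes.append(text)
        return len(self.nodes) - 1

    def get_or_create_type(self, t: ToyType) -> int:
        """Return the node index for a type, creating it on first use."""
        if t not in self.types:
            self.types[t] = self._add_type(t)
        return self.types[t]

    def _add_type(self, t: ToyType) -> int:
        if t.kind == "struct":
            node = self._add_node("struct")
            # Register before visiting members so that a member pointer
            # back to this struct can find it.
            self.composite_parts[t] = node
            for position, member in enumerate(t.elements):
                member_node = self._add_type(member)  # members are never shared
                self.edges.append((member_node, node, position))
            return node
        if t.kind == "pointer":
            node = self._add_node("*")
            pointee = t.elements[0]
            parent = self.composite_parts.get(pointee)
            if parent is None:
                parent = self.get_or_create_type(pointee)
            # Otherwise this bottoms out on the enclosing struct.
            self.edges.append((parent, node, 0))
            return node
        return self._add_node(t.name)


def build_binary_tree() -> TypeGraphBuilder:
    """Build type nodes for: struct T { int data; T* left; T* right; }."""
    i32 = ToyType("int", "i32")
    tree = ToyType("struct")
    ptr = ToyType("pointer", elements=[tree])
    tree.elements = [i32, ptr, ptr]
    builder = TypeGraphBuilder()
    builder.composite_parts.clear()  # reset between top-level calls
    builder.get_or_create_type(tree)
    return builder
```

In this model the struct is only added to `types` after all of its members have been visited, so without the `composite_parts` lookup the pointer member would recurse into the struct again and never terminate.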

programl/ir/llvm/internal/program_graph_builder.h (+32)

@@ -70,6 +70,12 @@ class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {
 
   void Clear();
 
+  // Return the node representing a type. If no node already exists
+  // for this type, a new node is created and added to the graph. In
+  // the case of composite types, multiple new nodes may be added by
+  // this call, and the root type returned.
+  Node* GetOrCreateType(const ::llvm::Type* type);
+
  protected:
   [[nodiscard]] labm8::StatusOr<FunctionEntryExits> VisitFunction(const ::llvm::Function& function,
                                                                   const Function* functionMessage);
@@ -85,6 +91,12 @@ class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {
   Node* AddLlvmVariable(const ::llvm::Instruction* operand, const Function* function);
   Node* AddLlvmVariable(const ::llvm::Argument* argument, const Function* function);
   Node* AddLlvmConstant(const ::llvm::Constant* constant);
+  Node* AddLlvmType(const ::llvm::Type* type);
+  Node* AddLlvmType(const ::llvm::StructType* type);
+  Node* AddLlvmType(const ::llvm::PointerType* type);
+  Node* AddLlvmType(const ::llvm::FunctionType* type);
+  Node* AddLlvmType(const ::llvm::ArrayType* type);
+  Node* AddLlvmType(const ::llvm::VectorType* type);
 
  private:
   TextEncoder textEncoder_;
@@ -99,6 +111,26 @@ class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {
   // populated by VisitBasicBlock() and consumed once all functions have been
   // visited.
   absl::flat_hash_map<const ::llvm::Constant*, std::vector<PositionalNode>> constants_;
+
+  // A map from an LLVM type to the node message that represents it.
+  absl::flat_hash_map<const ::llvm::Type*, Node*> types_;
+
+  // When adding a new type to the graph we need to know whether the type that
+  // we are adding is part of a composite type that references itself. For
+  // example:
+  //
+  //     struct BinaryTree {
+  //       int data;
+  //       struct BinaryTree* left;
+  //       struct BinaryTree* right;
+  //     }
+  //
+  // When the recursive GetOrCreateType() resolves the "left" member, it needs
+  // to know that the parent BinaryTree type has already been processed. This
+  // map stores the Nodes corresponding to any parent structs that have been
+  // already added in a call to GetOrCreateType(). It must be cleared between
+  // calls.
+  absl::flat_hash_map<const ::llvm::Type*, Node*> compositeTypeParts_;
 };
 
 }  // namespace internal

programl/proto/program_graph.proto (+4)

@@ -55,6 +55,8 @@ message Node {
     VARIABLE = 1;
     // A constant.
     CONSTANT = 2;
+    // A type.
+    TYPE = 3;
   }
   // The type of the node.
   Type type = 1;
@@ -92,6 +94,8 @@ message Edge {
     DATA = 1;
     // A call relation.
     CALL = 2;
+    // A type relation.
+    TYPE = 3;
   }
   // The type of relation of this edge.
   Flow flow = 1;
