Commit 559f37f

Add types to the graph.

This adds a fourth node type, and a fourth edge flow, both called
"type". The idea is to represent types as first-class elements in the
graph representation. This allows greater compositionality by breaking
up composite types into subcomponents, and decreases the vocabulary
size required to achieve a given coverage.

Background
----------

Currently, type information is stored in the "text" field of nodes for
constants and variables, e.g.:

    node { type: VARIABLE text: "i8" }

There are two issues with this:

* Composite types end up with long textual representations, e.g.
  "struct foo { i32 a; i32 b; ... }". Since there is an unbounded
  number of possible structs, this prevents 100% vocabulary coverage on
  any IR with structs (or other composite types).

* In the future, we will want to encode different information on data
  nodes, such as embedding literal values. Moving the type information
  out of the data node "frees up" space for something else.

Overview
--------

This changes the representation to represent types as first-class
elements in the graph. A "type" node represents a type using its "text"
field, and a new "type" edge connects this type to variables or
constants of that type, e.g. a variable "int x" could be represented as:

    node { type: VARIABLE text: "var" }
    node { type: TYPE text: "i32" }
    edge { flow: TYPE source: 1 }

Composite types
---------------

Types may be composed by connecting multiple type nodes using type
edges. This allows you to break down complex types into a graph of
primitive parts. The meaning of composite types will depend on the IR
being targeted; the remainder describes the process for LLVM-IR.

Pointer types
-------------

A pointer is a composite of two types:

    [variable] <- [pointer] <- [pointed-type]

For example:

    int32_t* instance;

Would be represented as:

    node { type: TYPE text: "i32" }
    node { type: TYPE text: "*" }
    node { type: VARIABLE text: "var" }
    edge { flow: TYPE target: 1 }
    edge { flow: TYPE source: 1 target: 2 }

Where variables/constants of this type receive an incoming type edge
from the [pointer] node, which in turn receives an incoming type edge
from the [pointed-type] node.

One [pointer] node is generated for each unique pointer type. If a
graph contains multiple pointer types, there will be multiple [pointer]
nodes, one for each pointed type.

Struct types
------------

A struct is a composite type where each member is a type node that
points to the parent node. Variable/constant instances of a struct
receive an incoming type edge from the root struct node. Note that the
graph of type nodes representing a composite struct type may be
cyclical, since a struct can contain a pointer of the same type (think
of a binary tree implementation).

The type edges from member nodes to the parent struct are positional.
The position indicates the element number. E.g. for a struct with three
elements, the incoming type edges to the struct node will have
positions 0, 1, and 2.

This example struct:

    struct s {
      int8_t a;
      int8_t b;
      struct s* c;
    }

    struct s instance;

Would be represented as:

    node { type: TYPE text: "struct" }
    node { type: TYPE text: "i8" }
    node { type: TYPE text: "*" }
    node { type: VARIABLE text: "var" }
    edge { flow: TYPE target: 1 }
    edge { flow: TYPE target: 1 position: 1 }
    edge { flow: TYPE target: 2 position: 2 }
    edge { flow: TYPE source: 2 }
    edge { flow: TYPE target: 3 }

Array Types
-----------

An array is a composite type [variable] <- [array] <- [element-type].

For example, the array:

    int a[10];

Would be represented as:

    node { type: TYPE text: "i32" }
    node { type: TYPE text: "[]" }
    node { type: VARIABLE text: "var" }
    edge { flow: TYPE target: 1 }
    edge { flow: TYPE source: 1 target: 2 }

github.com//issues/82
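
The pointer layout described above can be sketched in a few lines of Python. This is only an illustration under stated assumptions: the `Node`, `Edge`, and `Graph` classes below are hypothetical stand-ins for the protocol buffer messages, not the project's API, and node indices are simply list positions.

```python
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import List


class NodeType(Enum):
    # Hypothetical mirror of the node types named in this commit.
    INSTRUCTION = auto()
    VARIABLE = auto()
    CONSTANT = auto()
    TYPE = auto()


@dataclass
class Node:
    type: NodeType
    text: str


@dataclass
class Edge:
    # Only type edges are modelled here, so no flow field is needed.
    source: int
    target: int
    position: int = 0


@dataclass
class Graph:
    nodes: List[Node] = field(default_factory=list)
    type_edges: List[Edge] = field(default_factory=list)

    def add_node(self, node_type: NodeType, text: str) -> int:
        """Append a node and return its index."""
        self.nodes.append(Node(node_type, text))
        return len(self.nodes) - 1

    def add_type_edge(self, source: int, target: int, position: int = 0) -> None:
        """Connect a type node to the node that it describes."""
        self.type_edges.append(Edge(source, target, position))


def build_pointer_example() -> Graph:
    """Build the graph for `int32_t* instance;` as described above."""
    graph = Graph()
    i32 = graph.add_node(NodeType.TYPE, "i32")
    pointer = graph.add_node(NodeType.TYPE, "*")
    var = graph.add_node(NodeType.VARIABLE, "var")
    graph.add_type_edge(i32, pointer)  # [pointer] <- [pointed-type]
    graph.add_type_edge(pointer, var)  # [variable] <- [pointer]
    return graph
```

Calling `build_pointer_example()` yields three nodes and two type edges, `0 -> 1` and `1 -> 2`, matching the two `edge` messages in the pointer example.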
1 parent e2cb3e1 commit 559f37f

11 files changed, +228 -25 lines changed

programl/graph/format/graphviz_converter.cc

+18 -8

@@ -134,7 +134,7 @@ labm8::Status SerializeGraphVizToString(const ProgramGraph& graph,
 
     // Determine the subgraph to add this node to.
     boost::subgraph<GraphvizGraph>* dst = &external;
-    if (i && node.type() != Node::CONSTANT) {
+    if (i && (node.type() == Node::INSTRUCTION || node.type() == Node::VARIABLE)) {
       dst = &functionGraphs[node.function()].get();
     }
 
@@ -192,29 +192,33 @@ labm8::Status SerializeGraphVizToString(const ProgramGraph& graph,
       }
       labm8::TruncateWithEllipsis(text, kMaximumLabelLen);
       attributes["label"] = text;
+      attributes["style"] = "filled";
 
       // Set the node shape.
       switch (node.type()) {
         case Node::INSTRUCTION:
           attributes["shape"] = "box";
-          attributes["style"] = "filled";
           attributes["fillcolor"] = "#3c78d8";
           attributes["fontcolor"] = "#ffffff";
           break;
         case Node::VARIABLE:
           attributes["shape"] = "ellipse";
-          attributes["style"] = "filled";
           attributes["fillcolor"] = "#f4cccc";
           attributes["color"] = "#990000";
           attributes["fontcolor"] = "#990000";
           break;
         case Node::CONSTANT:
-          attributes["shape"] = "diamond";
-          attributes["style"] = "filled";
+          attributes["shape"] = "octagon";
           attributes["fillcolor"] = "#e99c9c";
           attributes["color"] = "#990000";
           attributes["fontcolor"] = "#990000";
           break;
+        case Node::TYPE:
+          attributes["shape"] = "diamond";
+          attributes["fillcolor"] = "#cccccc";
+          attributes["color"] = "#cccccc";
+          attributes["fontcolor"] = "#222222";
+          break;
       }
     }
 
@@ -242,15 +246,21 @@ labm8::Status SerializeGraphVizToString(const ProgramGraph& graph,
         attributes["color"] = "#65ae4d";
         attributes["weight"] = "1";
         break;
+      case Edge::TYPE:
+        attributes["color"] = "#aaaaaa";
+        attributes["weight"] = "1";
+        attributes["penwidth"] = "1.5";
+        break;
     }
 
     // Set the edge label.
     if (edge.position()) {
       // Position labels for control edge are drawn close to the originating
-      // instruction. For data edges, they are drawn closer to the consuming
-      // instruction.
+      // instruction. For control edges, they are drawn close to the branching
+      // instruction. For data and type edges, they are drawn close to the
+      // consuming node.
       const string label =
-          edge.flow() == Edge::DATA ? "headlabel" : "taillabel";
+          edge.flow() == Edge::CONTROL ? "taillabel" : "headlabel";
       attributes[label] = std::to_string(edge.position());
       attributes["labelfontcolor"] = attributes["color"];
     }

programl/graph/program_graph_builder.cc

+24

@@ -69,6 +69,10 @@ Node* ProgramGraphBuilder::AddConstant(const string& text) {
   return AddNode(Node::CONSTANT, text);
 }
 
+Node* ProgramGraphBuilder::AddType(const string& text) {
+  return AddNode(Node::TYPE, text);
+}
+
 labm8::StatusOr<Edge*> ProgramGraphBuilder::AddControlEdge(int32_t position,
                                                            const Node* source,
                                                            const Node* target) {
@@ -145,6 +149,26 @@ labm8::StatusOr<Edge*> ProgramGraphBuilder::AddCallEdge(const Node* source,
   return AddEdge(Edge::CALL, /*position=*/0, source, target);
 }
 
+labm8::StatusOr<Edge*> ProgramGraphBuilder::AddTypeEdge(int32_t position,
+                                                        const Node* source,
+                                                        const Node* target) {
+  DCHECK(source) << "nullptr argument";
+  DCHECK(target) << "nullptr argument";
+
+  if (source->type() != Node::TYPE) {
+    return Status(labm8::error::Code::INVALID_ARGUMENT,
+                  "Invalid source type ({}) for type edge. Expected type",
+                  Node::Type_Name(source->type()));
+  }
+  if (target->type() == Node::INSTRUCTION) {
+    return Status(labm8::error::Code::INVALID_ARGUMENT,
+                  "Invalid destination type (instruction) for type edge. "
+                  "Expected {variable,constant,type}");
+  }
+
+  return AddEdge(Edge::TYPE, position, source, target);
+}
+
 labm8::StatusOr<ProgramGraph> ProgramGraphBuilder::Build() {
   if (options().strict()) {
     RETURN_IF_ERROR(ValidateGraph());
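
The two checks in `AddTypeEdge()` amount to a small rule: a type edge must start at a type node, and it may end at anything except an instruction. A minimal Python restatement of that rule follows; the `NodeType` enum and `check_type_edge()` are hypothetical names used only for illustration, and the sketch raises an exception where the C++ returns a `Status`.

```python
from enum import Enum, auto


class NodeType(Enum):
    # Hypothetical mirror of the node types used in the diff.
    INSTRUCTION = auto()
    VARIABLE = auto()
    CONSTANT = auto()
    TYPE = auto()


def check_type_edge(source: NodeType, target: NodeType) -> None:
    """Raise ValueError if a type edge from source to target is not allowed."""
    if source is not NodeType.TYPE:
        raise ValueError(
            f"Invalid source type ({source.name}) for type edge. Expected type"
        )
    if target is NodeType.INSTRUCTION:
        raise ValueError(
            "Invalid destination type (instruction) for type edge. "
            "Expected {variable,constant,type}"
        )
```

Under this rule `check_type_edge(NodeType.TYPE, NodeType.TYPE)` passes, which is what allows composite types to be chained, while a variable can never be the source of a type edge.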

programl/graph/program_graph_builder.h

+7 -1

@@ -64,6 +64,8 @@ class ProgramGraphBuilder {
 
   Node* AddConstant(const string& text);
 
+  Node* AddType(const string& text);
+
   // Edge factories.
   [[nodiscard]] labm8::StatusOr<Edge*> AddControlEdge(int32_t position,
                                                       const Node* source,
@@ -76,6 +78,10 @@ class ProgramGraphBuilder {
   [[nodiscard]] labm8::StatusOr<Edge*> AddCallEdge(const Node* source,
                                                    const Node* target);
 
+  [[nodiscard]] labm8::StatusOr<Edge*> AddTypeEdge(int32_t position,
+                                                   const Node* source,
+                                                   const Node* target);
+
   const Node* GetRootNode() const { return &graph_.node(0); }
 
   // Return the graph protocol buffer.
@@ -123,7 +129,7 @@ class ProgramGraphBuilder {
   int32_t GetIndex(const Function* function);
   int32_t GetIndex(const Node* node);
 
-  // Maps which covert store the index of objects in repeated field lists.
+  // Maps that store the index of objects in repeated field lists.
   absl::flat_hash_map<Module*, int32_t> moduleIndices_;
   absl::flat_hash_map<Function*, int32_t> functionIndices_;
   absl::flat_hash_map<Node*, int32_t> nodeIndices_;

programl/ir/llvm/inst2vec_encoder.py

+7

@@ -122,6 +122,7 @@ def Encode(
         # Add the node features.
         var_embedding = self.dictionary["!IDENTIFIER"]
         const_embedding = self.dictionary["!IMMEDIATE"]
+        type_embedding = self.dictionary["!IMMEDIATE"]  # Types are immediates
 
         text_index = 0
         for node in proto.node:
@@ -143,6 +144,12 @@ def Encode(
                 node.features.feature["inst2vec_embedding"].int64_list.value.append(
                     const_embedding
                 )
+            elif node.type == node_pb2.Node.TYPE:
+                node.features.feature["inst2vec_embedding"].int64_list.value.append(
+                    type_embedding
+                )
+            else:
+                raise TypeError(f"Unknown node type {node}")
 
         proto.features.feature["inst2vec_annotated"].int64_list.value.append(1)
         return proto
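
The effect of this hunk on non-instruction nodes can be summarised as a lookup: variables map to the `!IDENTIFIER` entry, while constants and the new type nodes both map to `!IMMEDIATE`. The sketch below is a simplified stand-in, assuming node types are passed as plain strings and the dictionary is an ordinary `dict`; `placeholder_embedding()` is a hypothetical helper, and instruction nodes are deliberately left out because their handling is not part of the lines shown.

```python
def placeholder_embedding(node_type: str, dictionary: dict) -> int:
    """Return the vocabulary index used for a non-instruction node.

    node_type is one of "VARIABLE", "CONSTANT" or "TYPE" (an assumption of
    this sketch; the real encoder compares protocol buffer enum values).
    """
    if node_type == "VARIABLE":
        return dictionary["!IDENTIFIER"]
    if node_type in ("CONSTANT", "TYPE"):
        # Types are treated as immediates, so they share the constant entry.
        return dictionary["!IMMEDIATE"]
    raise TypeError(f"Unknown node type {node_type}")
```

One consequence of sharing `!IMMEDIATE` is that this encoder alone cannot tell a type node from a constant node; that distinction has to come from the node's `type` field or its `text`.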

programl/ir/llvm/internal/BUILD

+1

@@ -43,6 +43,7 @@ cc_library(
         "@com_google_absl//absl/container:flat_hash_set",
         "@labm8//labm8/cpp:status_macros",
         "@labm8//labm8/cpp:statusor",
+        "@labm8//labm8/cpp:logging",
         "@labm8//labm8/cpp:string",
         "@llvm//10.0.0",
     ],

programl/ir/llvm/internal/program_graph_builder.cc

+110 -9

@@ -20,6 +20,7 @@
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
+#include "labm8/cpp/logging.h"
 #include "labm8/cpp/status_macros.h"
 #include "labm8/cpp/string.h"
 #include "llvm/IR/BasicBlock.h"
@@ -39,14 +40,18 @@ namespace ir {
 namespace llvm {
 namespace internal {
 
+namespace {
+
+BytesList* getStringsList(ProgramGraph* programGraph) {
+  return (*programGraph->mutable_features()->mutable_feature())["strings"].mutable_bytes_list();
+}
+
+}  // anonymous namespace
+
 ProgramGraphBuilder::ProgramGraphBuilder(const ProgramGraphOptions& options)
-    : programl::graph::ProgramGraphBuilder(),
-      options_(options),
+    : programl::graph::ProgramGraphBuilder(options),
       blockCount_(0),
-      stringsList_((*GetMutableProgramGraph()
-                         ->mutable_features()
-                         ->mutable_feature())["strings"]
-                       .mutable_bytes_list()) {
+      stringsList_(getStringsList(GetMutableProgramGraph())) {
   // Add an empty
   graph::AddScalarFeature(GetMutableRootNode(), "llvm_string", AddString(""));
 }
@@ -357,29 +362,125 @@ Node* ProgramGraphBuilder::AddLlvmInstruction(
 Node* ProgramGraphBuilder::AddLlvmVariable(const ::llvm::Instruction* operand,
                                            const programl::Function* function) {
   const LlvmTextComponents text = textEncoder_.Encode(operand);
-  Node* node = AddVariable(text.lhs_type, function);
+  Node* node = AddVariable("var", function);
   node->set_block(blockCount_);
   graph::AddScalarFeature(node, "llvm_string", AddString(text.lhs));
 
+  compositeTypeParts_.clear();  // Reset after previous call.
+  Node* type = GetOrCreateType(operand->getType());
+  CHECK(AddTypeEdge(/*position=*/0, type, node).ok());
+
   return node;
 }
 
 Node* ProgramGraphBuilder::AddLlvmVariable(const ::llvm::Argument* argument,
                                            const programl::Function* function) {
   const LlvmTextComponents text = textEncoder_.Encode(argument);
-  Node* node = AddVariable(text.lhs_type, function);
+  Node* node = AddVariable("var", function);
   node->set_block(blockCount_);
   graph::AddScalarFeature(node, "llvm_string", AddString(text.lhs));
 
+  compositeTypeParts_.clear();  // Reset after previous call.
+  Node* type = GetOrCreateType(argument->getType());
+  CHECK(AddTypeEdge(/*position=*/0, type, node).ok());
+
   return node;
 }
 
 Node* ProgramGraphBuilder::AddLlvmConstant(const ::llvm::Constant* constant) {
   const LlvmTextComponents text = textEncoder_.Encode(constant);
-  Node* node = AddConstant(text.lhs_type);
+  Node* node = AddConstant("val");
   node->set_block(blockCount_);
   graph::AddScalarFeature(node, "llvm_string", AddString(text.text));
 
+  compositeTypeParts_.clear();  // Reset after previous call.
+  Node* type = GetOrCreateType(constant->getType());
+  CHECK(AddTypeEdge(/*position=*/0, type, node).ok());
+
+  return node;
+}
+
+Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::Type* type) {
+  // Dispatch to the type-specific handlers.
+  if (::llvm::dyn_cast<::llvm::StructType>(type)) {
+    return AddLlvmType(::llvm::dyn_cast<::llvm::StructType>(type));
+  } else if (::llvm::dyn_cast<::llvm::PointerType>(type)) {
+    return AddLlvmType(::llvm::dyn_cast<::llvm::PointerType>(type));
+  } else if (::llvm::dyn_cast<::llvm::FunctionType>(type)) {
+    return AddLlvmType(::llvm::dyn_cast<::llvm::FunctionType>(type));
+  } else if (::llvm::dyn_cast<::llvm::ArrayType>(type)) {
+    return AddLlvmType(::llvm::dyn_cast<::llvm::ArrayType>(type));
+  } else if (::llvm::dyn_cast<::llvm::VectorType>(type)) {
+    return AddLlvmType(::llvm::dyn_cast<::llvm::VectorType>(type));
+  } else {
+    const LlvmTextComponents text = textEncoder_.Encode(type);
+    Node *node = AddType(text.text);
+    graph::AddScalarFeature(node, "llvm_string", AddString(text.text));
+    return node;
+  }
+}
+
+Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::StructType* type) {
+  Node* node = AddType("struct");
+  compositeTypeParts_[type] = node;
+  graph::AddScalarFeature(node, "llvm_string",
+                          AddString(textEncoder_.Encode(type).text));
+
+  // Add types for the struct elements, and add type edges.
+  for (int i = 0; i < type->getNumElements(); ++i) {
+    const auto& member = type->elements()[i];
+    // Re-use the type if it already exists to prevent duplication of member
+    // types.
+    auto memberNode = GetOrCreateType(member);
+    CHECK(AddTypeEdge(/*position=*/i, memberNode, node).ok());
+  }
+
+  return node;
+}
+
+Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::PointerType* type) {
+  Node* node = AddType("*");
+  graph::AddScalarFeature(node, "llvm_string",
+                          AddString(textEncoder_.Encode(type).text));
+
+  auto elementType = type->getElementType();
+  auto parent = compositeTypeParts_.find(elementType);
+  if (parent == compositeTypeParts_.end()) {
+    // Re-use the type if it already exists to prevent duplication.
+    auto elementNode = GetOrCreateType(type->getElementType());
+    CHECK(AddTypeEdge(/*position=*/0, elementNode, node).ok());
+  } else {
+    // Bottom-out for self-referencing types.
+    CHECK(AddTypeEdge(/*position=*/0, node, parent->second).ok());
+  }
+
+  return node;
+}
+
+Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::FunctionType* type) {
+  Node* node = AddType("fn");
+  graph::AddScalarFeature(node, "llvm_string",
+                          AddString(textEncoder_.Encode(type).text));
+  return node;
+}
+
+Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::ArrayType* type) {
+  Node* node = AddType("[]");
+  graph::AddScalarFeature(node, "llvm_string",
+                          AddString(textEncoder_.Encode(type).text));
+  // Re-use the type if it already exists to prevent duplication.
+  auto elementType = GetOrCreateType(type->getElementType());
+  CHECK(AddTypeEdge(/*position=*/0, elementType, node).ok());
+  return node;
+}
+
+Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::VectorType* type) {
+  Node* node = AddType("vector");
+  graph::AddScalarFeature(node, "llvm_string",
+                          AddString(textEncoder_.Encode(type).text));
+  // Re-use the type if it already exists to prevent duplication.
+  auto elementType = GetOrCreateType(type->getElementType());
+  CHECK(AddTypeEdge(/*position=*/0, elementType, node).ok());
   return node;
 }
 
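
The body of `GetOrCreateType()` is not among the hunks shown here, so the following Python sketch only illustrates the general idea of memoised type creation with self-referencing structs. It is an assumption-laden simplification: `Ty` and `TypeGraph` are hypothetical names, and the sketch uses a single cache that is filled before recursing into members, rather than the separate `types_` and `compositeTypeParts_` maps declared in this commit.

```python
class Ty:
    """Hypothetical stand-in for an LLVM type: a label plus member types."""

    def __init__(self, kind, members=None):
        self.kind = kind
        self.members = members or []


class TypeGraph:
    def __init__(self):
        self.nodes = []   # Node texts, indexed by node id.
        self.edges = []   # (source, target, position) type edges.
        self._cache = {}  # id(Ty) -> node id.

    def get_or_create(self, ty):
        """Return the node id for ty, creating nodes and edges on first use."""
        key = id(ty)
        if key in self._cache:
            return self._cache[key]
        index = len(self.nodes)
        self.nodes.append(ty.kind)
        # Register the node before visiting members so that a member that
        # refers back to this type finds it and the recursion terminates.
        self._cache[key] = index
        for position, member in enumerate(ty.members):
            member_index = self.get_or_create(member)
            self.edges.append((member_index, index, position))
        return index


# struct s { int8_t a; int8_t b; struct s* c; }
i8 = Ty("i8")
s = Ty("struct")
pointer = Ty("*", [s])
s.members = [i8, i8, pointer]

graph = TypeGraph()
root = graph.get_or_create(s)
```

In this sketch the struct contributes three nodes (`struct`, `i8`, `*`), the shared `i8` node is reused for both fields, and the pointer's edge back to the struct closes the cycle without creating a second struct node.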

programl/ir/llvm/internal/program_graph_builder.h

+32

@@ -71,6 +71,12 @@ class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {
 
   void Clear();
 
+  // Return the node representing a type. If no node already exists
+  // for this type, a new node is created and added to the graph. In
+  // the case of composite types, multiple new nodes may be added by
+  // this call, and the root type returned.
+  Node* GetOrCreateType(const ::llvm::Type* type);
+
  protected:
   [[nodiscard]] labm8::StatusOr<FunctionEntryExits> VisitFunction(
       const ::llvm::Function& function, const Function* functionMessage);
@@ -90,6 +96,12 @@ class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {
   Node* AddLlvmVariable(const ::llvm::Argument* argument,
                         const Function* function);
   Node* AddLlvmConstant(const ::llvm::Constant* constant);
+  Node* AddLlvmType(const ::llvm::Type* type);
+  Node* AddLlvmType(const ::llvm::StructType* type);
+  Node* AddLlvmType(const ::llvm::PointerType* type);
+  Node* AddLlvmType(const ::llvm::FunctionType* type);
+  Node* AddLlvmType(const ::llvm::ArrayType* type);
+  Node* AddLlvmType(const ::llvm::VectorType* type);
 
   // Add a string to the strings list and return its position.
   //
@@ -118,6 +130,26 @@ class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {
   absl::flat_hash_map<string, int32_t> stringsListPositions_;
   // The underlying storage for the strings table.
   BytesList* stringsList_;
+
+  // A map from an LLVM type to the node message that represents it.
+  absl::flat_hash_map<const ::llvm::Type*, Node*> types_;
+
+  // When adding a new type to the graph we need to know whether the type that
+  // we are adding is part of a composite type that references itself. For
+  // example:
+  //
+  //     struct BinaryTree {
+  //       int data;
+  //       struct BinaryTree* left;
+  //       struct BinaryTree* right;
+  //     }
+  //
+  // When the recursive GetOrCreateType() resolves the "left" member, it needs
+  // to know that the parent BinaryTree type has already been processed. This
+  // map stores the Nodes corresponding to any parent structs that have been
+  // already added in a call to GetOrCreateType(). It must be cleared between
+  // calls.
+  absl::flat_hash_map<const ::llvm::Type*, Node*> compositeTypeParts_;
 };
 
 }  // namespace internal

programl/ir/llvm/py/BUILD

+1

@@ -39,6 +39,7 @@ py_test(
     srcs = ["llvm_test.py"],
     deps = [
         ":llvm",
+        "//programl/proto:edge_py",
         "//programl/proto:node_py",
         "//programl/proto:program_graph_options_py",
         "//programl/proto:program_graph_py",
