Skip to content

Add types to graph (cherry-picked commit) #199

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .github/workflows/pre_commit.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ jobs:
go-version: ^1.13.1
- name: Install buildifier
run: |
go get github.com/bazelbuild/buildtools/buildifier
go install github.com/bazelbuild/buildtools/buildifier@latest
buildifier --version
- name: Install prototool
run: |
GO111MODULE=on go get github.com/uber/prototool/cmd/prototool@dev
GO111MODULE=on go install github.com/uber/prototool/cmd/prototool@dev
prototool version
- name: Install Python 3.8
uses: actions/setup-python@v2
Expand All @@ -43,7 +43,6 @@ jobs:
python3 -m pip install --upgrade wheel
python3 -m pip install -r tools/requirements.pre_commit.txt
python3 -m isort --version
python3 -m black --version
python3 -m pre_commit --version
- name: Run pre-commit checks
# TODO(github.com/facebookresearch/CompilerGym/issues/1): Disable
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ repos:
hooks:
- id: isort
- repo: https://github.com/psf/black
rev: 20.8b1
rev: 22.3.0
hooks:
- id: black
language_version: python3.8
Expand Down
26 changes: 18 additions & 8 deletions programl/graph/format/graphviz_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,27 +170,31 @@ class GraphVizSerializer {
template <typename T>
void SetVertexAttributes(const Node& node, T& attributes) {
attributes["label"] = GetNodeLabel(node);
attributes["style"] = "filled";
switch (node.type()) {
case Node::INSTRUCTION:
attributes["shape"] = "box";
attributes["style"] = "filled";
attributes["fillcolor"] = "#3c78d8";
attributes["fontcolor"] = "#ffffff";
break;
case Node::VARIABLE:
attributes["shape"] = "ellipse";
attributes["style"] = "filled";
attributes["fillcolor"] = "#f4cccc";
attributes["color"] = "#990000";
attributes["fontcolor"] = "#990000";
break;
case Node::CONSTANT:
attributes["shape"] = "diamond";
attributes["style"] = "filled";
attributes["shape"] = "octagon";
attributes["fillcolor"] = "#e99c9c";
attributes["color"] = "#990000";
attributes["fontcolor"] = "#990000";
break;
case Node::TYPE:
attributes["shape"] = "diamond";
attributes["fillcolor"] = "#cccccc";
attributes["color"] = "#cccccc";
attributes["fontcolor"] = "#222222";
break;
default:
LOG(FATAL) << "unreachable";
}
Expand All @@ -204,7 +208,7 @@ class GraphVizSerializer {
const Node& node = graph_.node(i);
// Determine the subgraph to add this node to.
boost::subgraph<GraphvizGraph>* dst = defaultGraph;
if (i && node.type() != Node::CONSTANT) {
if (i && (node.type() == Node::INSTRUCTION || node.type() == Node::VARIABLE)) {
dst = &(*functionGraphs)[node.function()].get();
}
auto vertex = add_vertex(i, *dst);
Expand All @@ -229,16 +233,22 @@ class GraphVizSerializer {
attributes["color"] = "#65ae4d";
attributes["weight"] = "1";
break;
case Edge::TYPE:
attributes["color"] = "#aaaaaa";
attributes["weight"] = "1";
attributes["penwidth"] = "1.5";
break;
default:
LOG(FATAL) << "unreachable";
}

// Set the edge label.
if (edge.position()) {
// Position labels for control edge are drawn close to the originating
// instruction. For data edges, they are drawn closer to the consuming
// instruction.
const string label = edge.flow() == Edge::DATA ? "headlabel" : "taillabel";
// instruction. For control edges, they are drawn close to the branching
// instruction. For data and type edges, they are drawn close to the
// consuming node.
const string label = edge.flow() == Edge::CONTROL ? "taillabel" : "headlabel";
attributes[label] = std::to_string(edge.position());
attributes["labelfontcolor"] = attributes["color"];
}
Expand Down
21 changes: 21 additions & 0 deletions programl/graph/program_graph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ Node* ProgramGraphBuilder::AddVariable(const string& text, const Function* funct

// Node factory for constants: forwards `text` to AddNode() with the
// Node::CONSTANT type and returns the node that AddNode() yields.
Node* ProgramGraphBuilder::AddConstant(const string& text) {
  return AddNode(Node::CONSTANT, text);
}

// Node factory for types: forwards `text` to AddNode() with the Node::TYPE
// type and returns the node that AddNode() yields.
Node* ProgramGraphBuilder::AddType(const string& text) {
  return AddNode(Node::TYPE, text);
}

labm8::StatusOr<Edge*> ProgramGraphBuilder::AddControlEdge(int32_t position, const Node* source,
const Node* target) {
DCHECK(source) << "nullptr argument";
Expand Down Expand Up @@ -131,6 +133,25 @@ labm8::StatusOr<Edge*> ProgramGraphBuilder::AddCallEdge(const Node* source, cons
return AddEdge(Edge::CALL, /*position=*/0, source, target);
}

// Adds a TYPE edge between two nodes after validating the endpoint types.
//
// Args:
//   position: The edge position, forwarded unchanged to AddEdge().
//   source: The originating node. Must have type Node::TYPE.
//   target: The receiving node. Must not have type Node::INSTRUCTION.
//
// Returns:
//   The result of AddEdge(Edge::TYPE, ...), or an INVALID_ARGUMENT status if
//   either endpoint has a disallowed node type. The source is checked first.
labm8::StatusOr<Edge*> ProgramGraphBuilder::AddTypeEdge(int32_t position, const Node* source,
                                                        const Node* target) {
  DCHECK(source) << "nullptr argument";
  DCHECK(target) << "nullptr argument";

  // Only type nodes may originate a type edge.
  const bool sourceIsType = (source->type() == Node::TYPE);
  if (!sourceIsType) {
    return Status(labm8::error::Code::INVALID_ARGUMENT,
                  "Invalid source type ({}) for type edge. Expected type",
                  Node::Type_Name(source->type()));
  }

  // Instructions are the one node type that cannot receive a type edge.
  const bool targetIsInstruction = (target->type() == Node::INSTRUCTION);
  if (targetIsInstruction) {
    return Status(labm8::error::Code::INVALID_ARGUMENT,
                  "Invalid destination type (instruction) for type edge. "
                  "Expected {variable,constant,type}");
  }

  return AddEdge(Edge::TYPE, position, source, target);
}

labm8::StatusOr<ProgramGraph> ProgramGraphBuilder::Build() {
if (options().strict()) {
RETURN_IF_ERROR(ValidateGraph());
Expand Down
7 changes: 6 additions & 1 deletion programl/graph/program_graph_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class ProgramGraphBuilder {

Node* AddConstant(const string& text);

Node* AddType(const string& text);

// Edge factories.
[[nodiscard]] labm8::StatusOr<Edge*> AddControlEdge(int32_t position, const Node* source,
const Node* target);
Expand All @@ -73,6 +75,9 @@ class ProgramGraphBuilder {

[[nodiscard]] labm8::StatusOr<Edge*> AddCallEdge(const Node* source, const Node* target);

[[nodiscard]] labm8::StatusOr<Edge*> AddTypeEdge(int32_t position, const Node* source,
const Node* target);

const Node* GetRootNode() const { return &graph_.node(0); }

// Return the graph protocol buffer.
Expand Down Expand Up @@ -116,7 +121,7 @@ class ProgramGraphBuilder {
int32_t GetIndex(const Function* function);
int32_t GetIndex(const Node* node);

// Maps which covert store the index of objects in repeated field lists.
// Maps that store the index of objects in repeated field lists.
absl::flat_hash_map<Module*, int32_t> moduleIndices_;
absl::flat_hash_map<Function*, int32_t> functionIndices_;
absl::flat_hash_map<Node*, int32_t> nodeIndices_;
Expand Down
7 changes: 7 additions & 0 deletions programl/ir/llvm/inst2vec_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def Encode(self, proto: ProgramGraph, ir: Optional[str] = None) -> ProgramGraph:
# Add the node features.
var_embedding = self.dictionary["!IDENTIFIER"]
const_embedding = self.dictionary["!IMMEDIATE"]
type_embedding = self.dictionary["!IMMEDIATE"] # Types are immediates

text_index = 0
for node in proto.node:
Expand All @@ -113,6 +114,12 @@ def Encode(self, proto: ProgramGraph, ir: Optional[str] = None) -> ProgramGraph:
node.features.feature["inst2vec_embedding"].int64_list.value.append(
const_embedding
)
elif node.type == node_pb2.Node.TYPE:
node.features.feature["inst2vec_embedding"].int64_list.value.append(
type_embedding
)
else:
raise TypeError(f"Unknown node type {node}")

proto.features.feature["inst2vec_annotated"].int64_list.value.append(1)
return proto
Expand Down
119 changes: 116 additions & 3 deletions programl/ir/llvm/internal/program_graph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "labm8/cpp/logging.h"
#include "labm8/cpp/status_macros.h"
#include "labm8/cpp/string.h"
#include "llvm/IR/BasicBlock.h"
Expand Down Expand Up @@ -323,29 +324,131 @@ Node* ProgramGraphBuilder::AddLlvmInstruction(const ::llvm::Instruction* instruc
// Adds a variable node for the value of the LLVM instruction `operand`, scoped
// to `function`. The node's block is set to the current blockCount_, the
// encoded left-hand-side text is stored as the "full_text" feature, and a type
// edge (position 0) is added from the node for the operand's LLVM type to the
// new variable node. The process aborts (CHECK) if the type edge is rejected.
Node* ProgramGraphBuilder::AddLlvmVariable(const ::llvm::Instruction* operand,
const programl::Function* function) {
const LlvmTextComponents text = textEncoder_.Encode(operand);
// NOTE(review): the two consecutive declarations of `node` below look like the
// removed and added lines of a diff shown side by side; only one can remain.
// Presumably the "var" form is the one to keep, since the type is attached via
// the type edge further down -- confirm.
Node* node = AddVariable(text.lhs_type, function);
Node* node = AddVariable("var", function);
node->set_block(blockCount_);
graph::AddScalarFeature(node, "full_text", text.lhs);

compositeTypeParts_.clear(); // Reset after previous call.
// Look up the node for this LLVM type, creating it if it does not exist yet.
Node* type = GetOrCreateType(operand->getType());
CHECK(AddTypeEdge(/*position=*/0, type, node).ok());

return node;
}

// Adds a variable node for the LLVM function argument `argument`, scoped to
// `function`. The node's block is set to the current blockCount_, the encoded
// left-hand-side text is stored as the "full_text" feature, and a type edge
// (position 0) is added from the node for the argument's LLVM type to the new
// variable node. The process aborts (CHECK) if the type edge is rejected.
Node* ProgramGraphBuilder::AddLlvmVariable(const ::llvm::Argument* argument,
const programl::Function* function) {
const LlvmTextComponents text = textEncoder_.Encode(argument);
// NOTE(review): the two consecutive declarations of `node` below look like the
// removed and added lines of a diff shown side by side; only one can remain.
// Presumably the "var" form is the one to keep, since the type is attached via
// the type edge further down -- confirm.
Node* node = AddVariable(text.lhs_type, function);
Node* node = AddVariable("var", function);
node->set_block(blockCount_);
graph::AddScalarFeature(node, "full_text", text.lhs);

compositeTypeParts_.clear(); // Reset after previous call.
// Look up the node for this LLVM type, creating it if it does not exist yet.
Node* type = GetOrCreateType(argument->getType());
CHECK(AddTypeEdge(/*position=*/0, type, node).ok());

return node;
}

// Adds a constant node for the LLVM constant `constant`. The node's block is
// set to the current blockCount_, the full encoded text is stored as the
// "full_text" feature, and a type edge (position 0) is added from the node for
// the constant's LLVM type to the new constant node. The process aborts
// (CHECK) if the type edge is rejected.
Node* ProgramGraphBuilder::AddLlvmConstant(const ::llvm::Constant* constant) {
const LlvmTextComponents text = textEncoder_.Encode(constant);
// NOTE(review): the two consecutive declarations of `node` below look like the
// removed and added lines of a diff shown side by side; only one can remain.
// Presumably the "val" form is the one to keep, since the type is attached via
// the type edge further down -- confirm.
Node* node = AddConstant(text.lhs_type);
Node* node = AddConstant("val");
node->set_block(blockCount_);
graph::AddScalarFeature(node, "full_text", text.text);

compositeTypeParts_.clear(); // Reset after previous call.
// Look up the node for this LLVM type, creating it if it does not exist yet.
Node* type = GetOrCreateType(constant->getType());
CHECK(AddTypeEdge(/*position=*/0, type, node).ok());

return node;
}

// Adds a type node for an arbitrary LLVM type.
//
// Struct, pointer, function, array and vector types are dispatched to their
// dedicated AddLlvmType() overloads. Every other type becomes a single type
// node whose text and "full_text" feature are both the encoded type text.
//
// This always creates new node(s); it does not consult the types_ cache.
//
// Args:
//   type: The LLVM type to add.
//
// Returns:
//   The newly added node representing `type`.
Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::Type* type) {
  // Bind each dyn_cast result once so that the cast is not repeated just to
  // pass the derived pointer to the matching overload.
  if (const auto* structType = ::llvm::dyn_cast<::llvm::StructType>(type)) {
    return AddLlvmType(structType);
  }
  if (const auto* pointerType = ::llvm::dyn_cast<::llvm::PointerType>(type)) {
    return AddLlvmType(pointerType);
  }
  if (const auto* functionType = ::llvm::dyn_cast<::llvm::FunctionType>(type)) {
    return AddLlvmType(functionType);
  }
  if (const auto* arrayType = ::llvm::dyn_cast<::llvm::ArrayType>(type)) {
    return AddLlvmType(arrayType);
  }
  if (const auto* vectorType = ::llvm::dyn_cast<::llvm::VectorType>(type)) {
    return AddLlvmType(vectorType);
  }

  // Not one of the specially handled types: represent it by its encoded text.
  const LlvmTextComponents text = textEncoder_.Encode(type);
  Node* node = AddType(text.text);
  graph::AddScalarFeature(node, "full_text", text.text);
  return node;
}

// Adds a "struct" type node for an LLVM struct type, plus one freshly created
// type node per member. Each member node is connected to the struct node by a
// type edge whose position is the member's index within the struct.
//
// Args:
//   type: The LLVM struct type to add.
//
// Returns:
//   The newly added node representing the struct.
Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::StructType* type) {
  Node* node = AddType("struct");
  // Record the struct before visiting its members, so that a pointer member
  // referring back to this struct can be resolved through compositeTypeParts_.
  compositeTypeParts_[type] = node;
  graph::AddScalarFeature(node, "full_text", textEncoder_.Encode(type).text);

  // Add types for the struct elements, and add type edges. Iterating over the
  // element list directly avoids comparing a signed index against the unsigned
  // element count.
  int position = 0;
  for (const ::llvm::Type* member : type->elements()) {
    // Don't re-use member types in structs, always create a new type. For
    // example, the code:
    //
    //     struct S {
    //       int a;
    //       int b;
    //     };
    //     int c;
    //     int d;
    //
    // would produce four type nodes: one for the struct S itself, one for
    // S.a, one for S.b, and one which is shared by c and d.
    Node* memberNode = AddLlvmType(member);
    CHECK(AddTypeEdge(position, memberNode, node).ok());
    ++position;
  }

  return node;
}

// Adds a "*" type node for an LLVM pointer type and links the pointed-to type
// to it with a type edge at position 0.
//
// If the pointed-to type is recorded in compositeTypeParts_, the recorded node
// is used as the edge source. Otherwise the pointed-to type is resolved with
// GetOrCreateType().
Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::PointerType* type) {
  Node* pointerNode = AddType("*");
  graph::AddScalarFeature(pointerNode, "full_text", textEncoder_.Encode(type).text);

  const ::llvm::Type* pointee = type->getElementType();
  const auto enclosing = compositeTypeParts_.find(pointee);

  Node* pointeeNode = nullptr;
  if (enclosing != compositeTypeParts_.end()) {
    // Bottom-out for self-referencing types.
    pointeeNode = enclosing->second;
  } else {
    // Re-use the type if it already exists to prevent duplication.
    pointeeNode = GetOrCreateType(pointee);
  }
  CHECK(AddTypeEdge(/*position=*/0, pointeeNode, pointerNode).ok());

  return pointerNode;
}

// Adds a single type node for an LLVM function type. The encoded text of the
// function type is used both as the node text and as the "full_text" feature.
// No type edges are added by this overload.
Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::FunctionType* type) {
  const auto encoded = textEncoder_.Encode(type);
  Node* functionTypeNode = AddType(encoded.text);
  graph::AddScalarFeature(functionTypeNode, "full_text", encoded.text);
  return functionTypeNode;
}

// Adds a "[]" type node for an LLVM array type and links the element type to
// it with a type edge at position 0.
Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::ArrayType* type) {
  Node* arrayNode = AddType("[]");
  graph::AddScalarFeature(arrayNode, "full_text", textEncoder_.Encode(type).text);

  // The element type goes through the cache, so an existing node is shared
  // rather than duplicated.
  Node* elementNode = GetOrCreateType(type->getElementType());
  CHECK(AddTypeEdge(/*position=*/0, elementNode, arrayNode).ok());

  return arrayNode;
}

// Adds a "vector" type node for an LLVM vector type and links the element
// type to it with a type edge at position 0.
Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::VectorType* type) {
  Node* vectorNode = AddType("vector");
  graph::AddScalarFeature(vectorNode, "full_text", textEncoder_.Encode(type).text);

  // The element type goes through the cache, so an existing node is shared
  // rather than duplicated.
  Node* elementNode = GetOrCreateType(type->getElementType());
  CHECK(AddTypeEdge(/*position=*/0, elementNode, vectorNode).ok());

  return vectorNode;
}

Expand Down Expand Up @@ -461,6 +564,16 @@ void ProgramGraphBuilder::Clear() {
programl::graph::ProgramGraphBuilder::Clear();
}

// Returns the node that represents `type`, creating it on first use.
//
// A node already recorded in types_ is returned as-is. Otherwise the type is
// added with AddLlvmType() and the resulting node is recorded in types_ before
// being returned.
Node* ProgramGraphBuilder::GetOrCreateType(const ::llvm::Type* type) {
  const auto cached = types_.find(type);
  if (cached != types_.end()) {
    return cached->second;
  }

  // The map entry is written only after AddLlvmType() returns, because some
  // overloads call back into GetOrCreateType() for their element types.
  Node* created = AddLlvmType(type);
  types_[type] = created;
  return created;
}

} // namespace internal
} // namespace llvm
} // namespace ir
Expand Down
32 changes: 32 additions & 0 deletions programl/ir/llvm/internal/program_graph_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {

void Clear();

// Return the node representing a type. If no node already exists
// for this type, a new node is created and added to the graph. In
// the case of composite types, multiple new nodes may be added by
// this call, and the root type returned.
Node* GetOrCreateType(const ::llvm::Type* type);

protected:
[[nodiscard]] labm8::StatusOr<FunctionEntryExits> VisitFunction(const ::llvm::Function& function,
const Function* functionMessage);
Expand All @@ -85,6 +91,12 @@ class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {
Node* AddLlvmVariable(const ::llvm::Instruction* operand, const Function* function);
Node* AddLlvmVariable(const ::llvm::Argument* argument, const Function* function);
Node* AddLlvmConstant(const ::llvm::Constant* constant);
Node* AddLlvmType(const ::llvm::Type* type);
Node* AddLlvmType(const ::llvm::StructType* type);
Node* AddLlvmType(const ::llvm::PointerType* type);
Node* AddLlvmType(const ::llvm::FunctionType* type);
Node* AddLlvmType(const ::llvm::ArrayType* type);
Node* AddLlvmType(const ::llvm::VectorType* type);

private:
TextEncoder textEncoder_;
Expand All @@ -99,6 +111,26 @@ class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {
// populated by VisitBasicBlock() and consumed once all functions have been
// visited.
absl::flat_hash_map<const ::llvm::Constant*, std::vector<PositionalNode>> constants_;

// A map from an LLVM type to the node message that represents it.
absl::flat_hash_map<const ::llvm::Type*, Node*> types_;

// When adding a new type to the graph we need to know whether the type that
// we are adding is part of a composite type that references itself. For
// example:
//
// struct BinaryTree {
// int data;
// struct BinaryTree* left;
// struct BinaryTree* right;
// }
//
// When the recursive GetOrCreateType() resolves the "left" member, it needs
// to know that the parent BinaryTree type has already been processed. This
// map stores the Nodes corresponding to any parent structs that have been
// already added in a call to GetOrCreateType(). It must be cleared between
// calls.
absl::flat_hash_map<const ::llvm::Type*, Node*> compositeTypeParts_;
};

} // namespace internal
Expand Down
Loading