Skip to content

Commit 3709e9a

Browse files
committed
llvm: Deduplicate LLVM-IR strings.
This changes the format of the LLVM-IR program graphs to store a list of unique strings, rather than LLVM-IR strings in each node. We use a graph-level "strings" feature to store a list of the original LLVM-IR string corresponding to each graph nodes. This allows to us to refer to the same string from multiple nodes without duplication. This breaks compatability with the inst2vec encoder on program graphs generated prior to this commit. Signed-off-by: format 2020.06.15 <github.com/ChrisCummins/format>
1 parent a2caa52 commit 3709e9a

File tree

6 files changed

+80
-22
lines changed

6 files changed

+80
-22
lines changed

programl/graph/program_graph_builder.h

+3
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ class ProgramGraphBuilder {
103103
inline Edge* AddEdge(const Edge::Flow& flow, int32_t position, const Node* source,
104104
const Node* target);
105105

106+
// Return a mutable pointer to the root node in the graph.
107+
Node* GetMutableRootNode() { return graph_.mutable_node(0); }
108+
106109
// Return a mutable pointer to the graph protocol buffer.
107110
ProgramGraph* GetMutableProgramGraph() { return &graph_; }
108111

programl/ir/llvm/inst2vec_encoder.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,10 @@
4343
)
4444

4545

46-
def NodeFullText(node: node_pb2.Node) -> str:
46+
def NodeFullText(graph: program_graph_pb2.ProgramGraph, node: node_pb2.Node) -> str:
4747
"""Get the full text of a node, or an empty string if not set."""
48-
if len(node.features.feature["full_text"].bytes_list.value):
49-
return node.features.feature["full_text"].bytes_list.value[0].decode("utf-8")
50-
return ""
48+
idx = node.features.feature["llvm_string"].int64_list.value[0]
49+
return graph.features.feature["strings"].bytes_list.value[idx].decode("utf-8")
5150

5251

5352
class Inst2vecEncoder(object):
@@ -85,7 +84,7 @@ def Encode(
8584
"""
8685
# Gather the instruction texts to pre-process.
8786
lines = [
88-
[NodeFullText(node)]
87+
[NodeFullText(proto, node)]
8988
for node in proto.node
9089
if node.type == node_pb2.Node.INSTRUCTION
9190
]

programl/ir/llvm/inst2vec_encoder_test.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,21 @@ def AddVariable(self, full_text: str):
5555

5656
def Build(self):
5757
proto = super(Inst2vecGraphBuilder, self).Build()
58+
59+
# Add the root node string feature.
60+
proto.node[0].features.feature["llvm_string"].int64_list.value[:] = [0]
61+
62+
# Build the strings list.
63+
strings_list = list(set(self.full_texts.values()))
64+
proto.features.feature["strings"].bytes_list.value[:] = [
65+
string.encode("utf-8") for string in strings_list
66+
]
67+
68+
# Add the string indices.
5869
for node, full_text in self.full_texts.items():
59-
proto.node[node].features.feature["full_text"].bytes_list.value.append(
60-
full_text.encode("utf-8")
61-
)
70+
idx = strings_list.index(full_text)
71+
node_feature = proto.node[node].features.feature["llvm_string"]
72+
node_feature.int64_list.value.append(idx)
6273
return proto
6374

6475

programl/ir/llvm/internal/program_graph_builder.cc

+36-5
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,16 @@ namespace ir {
3939
namespace llvm {
4040
namespace internal {
4141

42+
ProgramGraphBuilder::ProgramGraphBuilder(const ProgramGraphOptions& options)
43+
: programl::graph::ProgramGraphBuilder(),
44+
options_(options),
45+
blockCount_(0),
46+
stringsList_((*GetMutableProgramGraph()->mutable_features()->mutable_feature())["strings"]
47+
.mutable_bytes_list()) {
48+
// Add an empty
49+
graph::AddScalarFeature(GetMutableRootNode(), "llvm_string", AddString(""));
50+
}
51+
4252
labm8::StatusOr<BasicBlockEntryExit> ProgramGraphBuilder::VisitBasicBlock(
4353
const ::llvm::BasicBlock& block, const Function* functionMessage, InstructionMap* instructions,
4454
ArgumentConsumerMap* argumentConsumers, std::vector<DataEdge>* dataEdgesToAdd) {
@@ -184,7 +194,7 @@ labm8::StatusOr<FunctionEntryExits> ProgramGraphBuilder::VisitFunction(
184194

185195
if (function.isDeclaration()) {
186196
Node* node = AddInstruction("; undefined function", functionMessage);
187-
graph::AddScalarFeature(node, "full_text", "");
197+
graph::AddScalarFeature(node, "llvm_string", AddString(""));
188198
functionEntryExits.first = node;
189199
functionEntryExits.second.push_back(node);
190200
return functionEntryExits;
@@ -305,7 +315,7 @@ Node* ProgramGraphBuilder::AddLlvmInstruction(const ::llvm::Instruction* instruc
305315
const LlvmTextComponents text = textEncoder_.Encode(instruction);
306316
Node* node = AddInstruction(text.opcode_name, function);
307317
node->set_block(blockCount_);
308-
graph::AddScalarFeature(node, "full_text", text.text);
318+
graph::AddScalarFeature(node, "llvm_string", AddString(text.text));
309319

310320
// Add profiling information features, if available.
311321
uint64_t profTotalWeight;
@@ -327,7 +337,7 @@ Node* ProgramGraphBuilder::AddLlvmVariable(const ::llvm::Instruction* operand,
327337
const LlvmTextComponents text = textEncoder_.Encode(operand);
328338
Node* node = AddVariable(text.lhs_type, function);
329339
node->set_block(blockCount_);
330-
graph::AddScalarFeature(node, "full_text", text.lhs);
340+
graph::AddScalarFeature(node, "llvm_string", AddString(text.lhs));
331341

332342
return node;
333343
}
@@ -337,7 +347,7 @@ Node* ProgramGraphBuilder::AddLlvmVariable(const ::llvm::Argument* argument,
337347
const LlvmTextComponents text = textEncoder_.Encode(argument);
338348
Node* node = AddVariable(text.lhs_type, function);
339349
node->set_block(blockCount_);
340-
graph::AddScalarFeature(node, "full_text", text.lhs);
350+
graph::AddScalarFeature(node, "llvm_string", AddString(text.lhs));
341351

342352
return node;
343353
}
@@ -346,7 +356,7 @@ Node* ProgramGraphBuilder::AddLlvmConstant(const ::llvm::Constant* constant) {
346356
const LlvmTextComponents text = textEncoder_.Encode(constant);
347357
Node* node = AddConstant(text.lhs_type);
348358
node->set_block(blockCount_);
349-
graph::AddScalarFeature(node, "full_text", text.text);
359+
graph::AddScalarFeature(node, "llvm_string", AddString(text.text));
350360

351361
return node;
352362
}
@@ -436,6 +446,27 @@ void ProgramGraphBuilder::Clear() {
436446
programl::graph::ProgramGraphBuilder::Clear();
437447
}
438448

449+
Node* ProgramGraphBuilder::GetOrCreateType(const ::llvm::Type* type) {
450+
auto it = types_.find(type);
451+
if (it == types_.end()) {
452+
Node* node = AddLlvmType(type);
453+
types_[type] = node;
454+
return node;
455+
}
456+
return it->second;
457+
}
458+
459+
int32_t ProgramGraphBuilder::AddString(const string& text) {
460+
auto it = stringsListPositions_.find(text);
461+
if (it == stringsListPositions_.end()) {
462+
int32_t index = stringsListPositions_.size();
463+
stringsListPositions_[text] = index;
464+
stringsList_->add_value(text);
465+
return index;
466+
}
467+
return it->second;
468+
}
469+
439470
} // namespace internal
440471
} // namespace llvm
441472
} // namespace ir

programl/ir/llvm/internal/program_graph_builder.h

+14-2
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@ using ArgumentConsumerMap =
6464
// A specialized program graph builder for LLVM-IR.
6565
class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {
6666
public:
67-
explicit ProgramGraphBuilder(const ProgramGraphOptions& options)
68-
: programl::graph::ProgramGraphBuilder(options), blockCount_(0) {}
67+
explicit ProgramGraphBuilder(const ProgramGraphOptions& options);
6968

7069
[[nodiscard]] labm8::StatusOr<ProgramGraph> Build(const ::llvm::Module& module);
7170

@@ -87,6 +86,13 @@ class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {
8786
Node* AddLlvmVariable(const ::llvm::Argument* argument, const Function* function);
8887
Node* AddLlvmConstant(const ::llvm::Constant* constant);
8988

89+
// Add a string to the strings list and return its position.
90+
//
91+
// We use a graph-level "strings" feature to store a list of the original
92+
// LLVM-IR string corresponding to each graph nodes. This allows to us to
93+
// refer to the same string from multiple nodes without duplication.
94+
int32_t AddString(const string& text);
95+
9096
private:
9197
TextEncoder textEncoder_;
9298

@@ -100,6 +106,12 @@ class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {
100106
// populated by VisitBasicBlock() and consumed once all functions have been
101107
// visited.
102108
absl::flat_hash_map<const ::llvm::Constant*, std::vector<PositionalNode>> constants_;
109+
110+
// A mapping from string table value to its position in the "strings_table"
111+
// graph-level feature.
112+
absl::flat_hash_map<string, int32_t> stringsListPositions_;
113+
// The underlying storage for the strings table.
114+
BytesList* stringsList_;
103115
};
104116

105117
} // namespace internal

programl/ir/llvm/py/llvm_test.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,10 @@
3737
"""
3838

3939

40-
def GetStringScalar(proto, name):
41-
return proto.features.feature[name].bytes_list.value[0].decode("utf-8")
40+
def NodeFullText(graph: program_graph_pb2.ProgramGraph, node: node_pb2.Node) -> str:
41+
"""Get the full text of a node, or an empty string if not set."""
42+
idx = node.features.feature["llvm_string"].int64_list.value[0]
43+
return graph.features.feature["strings"].bytes_list.value[idx].decode("utf-8")
4244

4345

4446
def test_simple_ir():
@@ -56,25 +58,25 @@ def test_simple_ir():
5658

5759
assert proto.node[1].text == "add"
5860
assert proto.node[1].type == node_pb2.Node.INSTRUCTION
59-
assert GetStringScalar(proto.node[1], "full_text") == "%3 = add nsw i32 %1, %0"
61+
assert NodeFullText(proto, proto.node[1]) == "%3 = add nsw i32 %1, %0"
6062

6163
assert proto.node[2].text == "ret"
6264
assert proto.node[2].type == node_pb2.Node.INSTRUCTION
63-
assert GetStringScalar(proto.node[2], "full_text") == "ret i32 %3"
65+
assert NodeFullText(proto, proto.node[2]) == "ret i32 %3"
6466

6567
assert proto.node[3].text == "i32"
6668
assert proto.node[3].type == node_pb2.Node.VARIABLE
67-
assert GetStringScalar(proto.node[3], "full_text") == "i32 %3"
69+
assert NodeFullText(proto, proto.node[3]) == "i32 %3"
6870

6971
# Use startswith() to compare names for these last two variables as thier
7072
# order may differ.
7173
assert proto.node[4].text == "i32"
7274
assert proto.node[4].type == node_pb2.Node.VARIABLE
73-
assert GetStringScalar(proto.node[4], "full_text").startswith("i32 %")
75+
assert NodeFullText(proto, proto.node[4]).startswith("i32 %")
7476

7577
assert proto.node[5].text == "i32"
7678
assert proto.node[5].type == node_pb2.Node.VARIABLE
77-
assert GetStringScalar(proto.node[5], "full_text").startswith("i32 %")
79+
assert NodeFullText(proto, proto.node[5]).startswith("i32 %")
7880

7981

8082
def test_opt_level():

0 commit comments

Comments
 (0)