Skip to content

Commit ef202dc

Browse files
authored
Fix Callgraph by filtering out impossible virtual call targets (#586)
1 parent a1a1a47 commit ef202dc

File tree

3 files changed

+105 -86 lines changed

include/phasar/PhasarLLVM/ControlFlow/Resolver/Resolver.h

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -39,11 +39,16 @@ enum class CallGraphAnalysisType;
3939
class LLVMBasedICFG;
4040
class LLVMPointsToInfo;
4141

42-
std::optional<unsigned> getVFTIndex(const llvm::CallBase *CallSite);
42+
[[nodiscard]] std::optional<unsigned>
43+
getVFTIndex(const llvm::CallBase *CallSite);
4344

44-
const llvm::StructType *getReceiverType(const llvm::CallBase *CallSite);
45+
[[nodiscard]] const llvm::StructType *
46+
getReceiverType(const llvm::CallBase *CallSite);
4547

46-
std::string getReceiverTypeName(const llvm::CallBase &CallSite);
48+
[[nodiscard]] std::string getReceiverTypeName(const llvm::CallBase *CallSite);
49+
50+
[[nodiscard]] bool isConsistentCall(const llvm::CallBase *CallSite,
51+
const llvm::Function *DestFun);
4752

4853
class Resolver {
4954
protected:

lib/PhasarLLVM/ControlFlow/Resolver/OTFResolver.cpp

Lines changed: 74 additions & 75 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
#include "phasar/PhasarLLVM/ControlFlow/Resolver/OTFResolver.h"
1111

1212
#include "phasar/PhasarLLVM/ControlFlow/LLVMBasedICFG.h"
13+
#include "phasar/PhasarLLVM/ControlFlow/Resolver/Resolver.h"
1314
#include "phasar/PhasarLLVM/TypeHierarchy/LLVMTypeHierarchy.h"
1415
#include "phasar/PhasarLLVM/Utils/LLVMShorthands.h"
1516
#include "phasar/Utils/Logger.h"
@@ -114,7 +115,8 @@ auto OTFResolver::resolveVirtualCall(const llvm::CallBase *CallSite)
114115
}
115116
const auto *Callee = VFs[VtableIndex];
116117
if (Callee == nullptr || !Callee->hasName() ||
117-
Callee->getName() == LLVMTypeHierarchy::PureVirtualCallName) {
118+
Callee->getName() == LLVMTypeHierarchy::PureVirtualCallName ||
119+
!isConsistentCall(CallSite, Callee)) {
118120
continue;
119121
}
120122
PossibleCallTargets.insert(Callee);
@@ -130,96 +132,93 @@ auto OTFResolver::resolveVirtualCall(const llvm::CallBase *CallSite)
130132

131133
auto OTFResolver::resolveFunctionPointer(const llvm::CallBase *CallSite)
132134
-> FunctionSetTy {
135+
if (!CallSite->getCalledOperand()) {
136+
return {};
137+
}
138+
133139
FunctionSetTy Callees;
134-
if (CallSite->getCalledOperand() &&
135-
CallSite->getCalledOperand()->getType()->isPointerTy()) {
136-
if (const auto *FTy = llvm::dyn_cast<llvm::FunctionType>(
137-
CallSite->getCalledOperand()->getType()->getPointerElementType())) {
138140

139-
auto PTS = PT.getAliasSet(CallSite->getCalledOperand(), CallSite);
141+
auto PTS = PT.getAliasSet(CallSite->getCalledOperand(), CallSite);
140142

141-
llvm::SmallVector<const llvm::GlobalVariable *, 2> GlobalVariableWL;
142-
llvm::SmallVector<const llvm::ConstantAggregate *> ConstantAggregateWL;
143-
llvm::SmallPtrSet<const llvm::ConstantAggregate *, 4>
144-
VisitedConstantAggregates;
143+
llvm::SmallVector<const llvm::GlobalVariable *, 2> GlobalVariableWL;
144+
llvm::SmallVector<const llvm::ConstantAggregate *> ConstantAggregateWL;
145+
llvm::SmallPtrSet<const llvm::ConstantAggregate *, 4>
146+
VisitedConstantAggregates;
145147

146-
for (const auto *P : *PTS) {
147-
if (!llvm::isa<llvm::Constant>(P)) {
148-
continue;
149-
}
148+
for (const auto *P : *PTS) {
149+
if (!llvm::isa<llvm::Constant>(P)) {
150+
continue;
151+
}
150152

151-
GlobalVariableWL.clear();
152-
ConstantAggregateWL.clear();
153+
GlobalVariableWL.clear();
154+
ConstantAggregateWL.clear();
153155

154-
if (P->getType()->isPointerTy() &&
155-
P->getType()->getPointerElementType()->isFunctionTy()) {
156-
if (const auto *F = llvm::dyn_cast<llvm::Function>(P)) {
157-
if (matchesSignature(F, FTy, false)) {
158-
Callees.insert(F);
159-
}
160-
}
156+
if (P->getType()->isPointerTy() &&
157+
P->getType()->getPointerElementType()->isFunctionTy()) {
158+
if (const auto *F = llvm::dyn_cast<llvm::Function>(P)) {
159+
if (isConsistentCall(CallSite, F)) {
160+
Callees.insert(F);
161161
}
162+
}
163+
}
162164

163-
if (const auto *GVP = llvm::dyn_cast<llvm::GlobalVariable>(P)) {
164-
GlobalVariableWL.push_back(GVP);
165-
} else if (const auto *CE = llvm::dyn_cast<llvm::ConstantExpr>(P)) {
166-
for (const auto &Op : CE->operands()) {
167-
if (const auto *GVOp = llvm::dyn_cast<llvm::GlobalVariable>(Op)) {
168-
GlobalVariableWL.push_back(GVOp);
169-
}
170-
}
165+
if (const auto *GVP = llvm::dyn_cast<llvm::GlobalVariable>(P)) {
166+
GlobalVariableWL.push_back(GVP);
167+
} else if (const auto *CE = llvm::dyn_cast<llvm::ConstantExpr>(P)) {
168+
for (const auto &Op : CE->operands()) {
169+
if (const auto *GVOp = llvm::dyn_cast<llvm::GlobalVariable>(Op)) {
170+
GlobalVariableWL.push_back(GVOp);
171171
}
172+
}
173+
}
172174

173-
if (GlobalVariableWL.empty()) {
174-
continue;
175-
}
175+
if (GlobalVariableWL.empty()) {
176+
continue;
177+
}
176178

177-
for (const auto *GV : GlobalVariableWL) {
178-
if (!GV->hasInitializer()) {
179-
continue;
180-
}
181-
const auto *InitConst = GV->getInitializer();
182-
if (const auto *InitConstAggregate =
183-
llvm::dyn_cast<llvm::ConstantAggregate>(InitConst)) {
184-
ConstantAggregateWL.push_back(InitConstAggregate);
179+
for (const auto *GV : GlobalVariableWL) {
180+
if (!GV->hasInitializer()) {
181+
continue;
182+
}
183+
const auto *InitConst = GV->getInitializer();
184+
if (const auto *InitConstAggregate =
185+
llvm::dyn_cast<llvm::ConstantAggregate>(InitConst)) {
186+
ConstantAggregateWL.push_back(InitConstAggregate);
187+
}
188+
}
189+
190+
VisitedConstantAggregates.clear();
191+
192+
while (!ConstantAggregateWL.empty()) {
193+
const auto *ConstAggregateItem = ConstantAggregateWL.pop_back_val();
194+
// We may have already processed the item, avoid an infinite loop
195+
if (!VisitedConstantAggregates.insert(ConstAggregateItem).second) {
196+
continue;
197+
}
198+
for (const auto &Op : ConstAggregateItem->operands()) {
199+
if (const auto *CE = llvm::dyn_cast<llvm::ConstantExpr>(Op)) {
200+
if (CE->getType()->isPointerTy() && CE->isCast()) {
201+
if (const auto *F =
202+
llvm::dyn_cast<llvm::Function>(CE->getOperand(0));
203+
F && isConsistentCall(CallSite, F)) {
204+
Callees.insert(F);
205+
}
185206
}
186207
}
187208

188-
VisitedConstantAggregates.clear();
189-
190-
while (!ConstantAggregateWL.empty()) {
191-
const auto *ConstAggregateItem = ConstantAggregateWL.pop_back_val();
192-
// We may have already processed the item, avoid an infinite loop
193-
if (!VisitedConstantAggregates.insert(ConstAggregateItem).second) {
209+
if (const auto *F = llvm::dyn_cast<llvm::Function>(Op)) {
210+
if (isConsistentCall(CallSite, F)) {
211+
Callees.insert(F);
212+
}
213+
} else if (auto *CA = llvm::dyn_cast<llvm::ConstantAggregate>(Op)) {
214+
ConstantAggregateWL.push_back(CA);
215+
} else if (auto *GV = llvm::dyn_cast<llvm::GlobalVariable>(Op)) {
216+
if (!GV->hasInitializer()) {
194217
continue;
195218
}
196-
for (const auto &Op : ConstAggregateItem->operands()) {
197-
if (const auto *CE = llvm::dyn_cast<llvm::ConstantExpr>(Op)) {
198-
if (CE->getType()->isPointerTy() &&
199-
CE->getType()->getPointerElementType() == FTy &&
200-
CE->isCast()) {
201-
if (const auto *F =
202-
llvm::dyn_cast<llvm::Function>(CE->getOperand(0))) {
203-
Callees.insert(F);
204-
}
205-
}
206-
}
207-
208-
if (const auto *F = llvm::dyn_cast<llvm::Function>(Op)) {
209-
if (matchesSignature(F, FTy, false)) {
210-
Callees.insert(F);
211-
}
212-
} else if (auto *CA = llvm::dyn_cast<llvm::ConstantAggregate>(Op)) {
213-
ConstantAggregateWL.push_back(CA);
214-
} else if (auto *GV = llvm::dyn_cast<llvm::GlobalVariable>(Op)) {
215-
if (!GV->hasInitializer()) {
216-
continue;
217-
}
218-
if (auto *GVCA = llvm::dyn_cast<llvm::ConstantAggregate>(
219-
GV->getInitializer())) {
220-
ConstantAggregateWL.push_back(GVCA);
221-
}
222-
}
219+
if (auto *GVCA = llvm::dyn_cast<llvm::ConstantAggregate>(
220+
GV->getInitializer())) {
221+
ConstantAggregateWL.push_back(GVCA);
223222
}
224223
}
225224
}

lib/PhasarLLVM/ControlFlow/Resolver/Resolver.cpp

Lines changed: 23 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -38,9 +38,7 @@
3838
#include <optional>
3939
#include <set>
4040

41-
namespace psr {
42-
43-
std::optional<unsigned> getVFTIndex(const llvm::CallBase *CallSite) {
41+
std::optional<unsigned> psr::getVFTIndex(const llvm::CallBase *CallSite) {
4442
// deal with a virtual member function
4543
// retrieve the vtable entry that is called
4644
const auto *Load =
@@ -59,7 +57,7 @@ std::optional<unsigned> getVFTIndex(const llvm::CallBase *CallSite) {
5957
return std::nullopt;
6058
}
6159

62-
const llvm::StructType *getReceiverType(const llvm::CallBase *CallSite) {
60+
const llvm::StructType *psr::getReceiverType(const llvm::CallBase *CallSite) {
6361
if (CallSite->arg_empty() ||
6462
(CallSite->hasStructRetAttr() && CallSite->arg_size() < 2)) {
6563
return nullptr;
@@ -86,25 +84,42 @@ const llvm::StructType *getReceiverType(const llvm::CallBase *CallSite) {
8684
return nullptr;
8785
}
8886

89-
std::string getReceiverTypeName(const llvm::CallBase *CallSite) {
87+
std::string psr::getReceiverTypeName(const llvm::CallBase *CallSite) {
9088
const auto *RT = getReceiverType(CallSite);
9189
if (RT) {
9290
return RT->getName().str();
9391
}
9492
return "";
9593
}
9694

95+
bool psr::isConsistentCall(const llvm::CallBase *CallSite,
96+
const llvm::Function *DestFun) {
97+
if (CallSite->arg_size() < DestFun->arg_size()) {
98+
return false;
99+
}
100+
if (CallSite->arg_size() != DestFun->arg_size() && !DestFun->isVarArg()) {
101+
return false;
102+
}
103+
if (!matchesSignature(DestFun, CallSite->getFunctionType(), false)) {
104+
return false;
105+
}
106+
return true;
107+
}
108+
109+
namespace psr {
110+
97111
Resolver::Resolver(LLVMProjectIRDB &IRDB) : IRDB(IRDB), TH(nullptr) {}
98112

99113
Resolver::Resolver(LLVMProjectIRDB &IRDB, LLVMTypeHierarchy &TH)
100114
: IRDB(IRDB), TH(&TH) {}
101115

102116
const llvm::Function *
103117
Resolver::getNonPureVirtualVFTEntry(const llvm::StructType *T, unsigned Idx,
104-
const llvm::CallBase * /*CallSite*/) {
118+
const llvm::CallBase *CallSite) {
105119
if (TH && TH->hasVFTable(T)) {
106120
const auto *Target = TH->getVFTable(T)->getFunction(Idx);
107-
if (Target && Target->getName() != "__cxa_pure_virtual") {
121+
if (Target && Target->getName() != LLVMTypeHierarchy::PureVirtualCallName &&
122+
isConsistentCall(CallSite, Target)) {
108123
return Target;
109124
}
110125
}
@@ -133,7 +148,7 @@ auto Resolver::resolveFunctionPointer(const llvm::CallBase *CallSite)
133148
if (const auto *FTy = llvm::dyn_cast<llvm::FunctionType>(
134149
CallSite->getCalledOperand()->getType()->getPointerElementType())) {
135150
for (const auto *F : IRDB.getAllFunctions()) {
136-
if (matchesSignature(F, FTy)) {
151+
if (isConsistentCall(CallSite, F)) {
137152
CalleeTargets.insert(F);
138153
}
139154
}

0 commit comments

Comments
 (0)