Skip to content

[ossdataflowengine] code clean up and remove redundant check #2736

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package io.joern.dataflowengineoss.queryengine
import io.joern.dataflowengineoss.queryengine.QueryEngineStatistics.{PATH_CACHE_HITS, PATH_CACHE_MISSES}
import io.joern.dataflowengineoss.semanticsloader.Semantics
import io.shiftleft.codepropertygraph.generated.nodes._
import io.shiftleft.semanticcpg.language.{toCfgNodeMethods, toExpressionMethods, _}
import io.shiftleft.semanticcpg.language._

import java.util.concurrent.Callable
import scala.collection.mutable
Expand Down Expand Up @@ -33,7 +33,7 @@ class TaskSolver(task: ReachableByTask, context: EngineContext, sources: Set[Cfg
val table: mutable.Map[TaskFingerprint, Vector[ReachableByResult]] = mutable.Map()
results(task.sink, path, table, task.callSiteStack)
// TODO why do we update the call depth here?
val finalResults = table.get(task.fingerprint).get.map { r =>
val finalResults = table(task.fingerprint).map { r =>
r.copy(
taskStack = r.taskStack.dropRight(1) :+ r.fingerprint.copy(callDepth = task.callDepth),
path = r.path ++ task.initialPath
Expand Down Expand Up @@ -68,20 +68,20 @@ class TaskSolver(task: ReachableByTask, context: EngineContext, sources: Set[Cfg
* @param callSiteStack
* This stack holds all call sites we expanded to arrive at the generation of the current task
*/
private def results[NodeType <: CfgNode](
private def results(
sink: CfgNode,
path: Vector[PathElement],
table: mutable.Map[TaskFingerprint, Vector[ReachableByResult]],
callSiteStack: List[Call]
)(implicit semantics: Semantics): Vector[ReachableByResult] = {

val curNode = path.head.node
val curNode = path.head.node.asInstanceOf[CfgNode]

/** For each parent of the current node, determined via `expandIn`, check if results are available in the result
* table. If not, determine results recursively.
*/
def computeResultsForParents() = {
deduplicateWithinTask(expandIn(curNode.asInstanceOf[CfgNode], path, callSiteStack).iterator.flatMap { parent =>
deduplicateWithinTask(expandIn(curNode, path, callSiteStack).iterator.flatMap { parent =>
createResultsFromCacheOrCompute(parent, path)
}.toVector)
}
Expand Down Expand Up @@ -117,14 +117,14 @@ class TaskSolver(task: ReachableByTask, context: EngineContext, sources: Set[Cfg
}

def createResultsFromCacheOrCompute(elemToPrepend: PathElement, path: Vector[PathElement]) = {
val cachedResult = createFromTable(table, elemToPrepend, task.callSiteStack, path, task.callDepth)
if (cachedResult.isDefined) {
QueryEngineStatistics.incrementBy(PATH_CACHE_HITS, 1L)
cachedResult.get
} else {
QueryEngineStatistics.incrementBy(PATH_CACHE_MISSES, 1L)
val newPath = elemToPrepend +: path
results(sink, newPath, table, callSiteStack)
createFromTable(table, elemToPrepend, task.callSiteStack, path, task.callDepth) match {
case Some(result) =>
QueryEngineStatistics.incrementBy(PATH_CACHE_HITS, 1L)
result
case None =>
QueryEngineStatistics.incrementBy(PATH_CACHE_MISSES, 1L)
val newPath = elemToPrepend +: path
results(sink, newPath, table, callSiteStack)
}
}

Expand Down Expand Up @@ -163,7 +163,7 @@ class TaskSolver(task: ReachableByTask, context: EngineContext, sources: Set[Cfg
*/
val res = curNode match {
// Case 1: we have reached a source => return result and continue traversing (expand into parents)
case x if sources.contains(x.asInstanceOf[NodeType]) =>
case x if sources.contains(x) =>
if (x.isInstanceOf[MethodParameterIn]) {
Vector(
ReachableByResult(task.taskStack, path),
Expand All @@ -172,43 +172,32 @@ class TaskSolver(task: ReachableByTask, context: EngineContext, sources: Set[Cfg
} else {
Vector(ReachableByResult(task.taskStack, path)) ++ computeResultsForParents()
}
// Case 2: we have reached a method parameter (that isn't a source) => return partial result and stop traversing
// Case 2: we have reached a method parameter (that isn't a source)
// => return partial result and stop traversing
case _: MethodParameterIn =>
Vector(ReachableByResult(task.taskStack, path, partial = true))
// Case 3: we have reached a call to an internal method without semantic (return value) and
// this isn't the start node => return partial result and stop traversing
case call: Call
if isCallToInternalMethodWithoutSemantic(call)
&& !isArgOrRetOfMethodWeCameFrom(call, path) =>
// Case 3: we have reached a call to an internal method without semantic (return value)
// => return partial result and stop traversing
case call: Call if isCallToInternalMethodWithoutSemantic(call) =>
createPartialResultForOutputArgOrRet()

// Case 4: we have reached an argument to an internal method without semantic (output argument) and
// this isn't the start node nor is it the argument for the parameter we just expanded => return partial result and stop traversing
// Case 4: we have reached an argument to an internal method without semantic (output argument) and this isn't the start node
// => return partial result and stop traversing
case arg: Expression
if path.size > 1
&& arg.inCall.toList.exists(c => isCallToInternalMethodWithoutSemantic(c))
&& !arg.inCall.headOption.exists(x => isArgOrRetOfMethodWeCameFrom(x, path)) =>
&& arg.inCall.toList.exists(c => isCallToInternalMethodWithoutSemantic(c)) =>
createPartialResultForOutputArgOrRet()

case _: MethodRef => createPartialResultForOutputArgOrRet()

// All other cases: expand into parents
case _ =>
computeResultsForParents()
case _ => computeResultsForParents()
}
val key = TaskFingerprint(curNode.asInstanceOf[CfgNode], task.callSiteStack, task.callDepth)
val key = TaskFingerprint(curNode, task.callSiteStack, task.callDepth)
table.updateWith(key) {
case Some(existingValue) => Some(existingValue ++ res)
case None => Some(res)
}
res
}

private def isArgOrRetOfMethodWeCameFrom(call: Call, path: Vector[PathElement]): Boolean =
path match {
case Vector(_, PathElement(x: MethodReturn, _, _, _, _), _*) => methodsForCall(call).contains(x.method)
case Vector(_, PathElement(x: MethodParameterIn, _, _, _, _), _*) => methodsForCall(call).contains(x.method)
case _ => false
}

}