Skip to content

Commit c5c8ef5

Browse files
[ruby] Lower Array Creation for In Pattern Match (#5229)
Pattern match array creation is now lowered so that both reads and writes assign array indices to where each element much be read and written from.
1 parent 48d26b5 commit c5c8ef5

File tree

4 files changed

+77
-25
lines changed

4 files changed

+77
-25
lines changed

joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForControlStructuresCreator.scala

Lines changed: 66 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
package io.joern.rubysrc2cpg.astcreation
22

33
import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{
4+
ArrayLiteral,
45
ArrayPattern,
56
BinaryExpression,
67
BreakExpression,
78
CaseExpression,
89
ControlFlowStatement,
910
DoWhileExpression,
11+
DummyAst,
12+
DynamicLiteral,
1013
ElseClause,
1114
ForExpression,
1215
IfExpression,
@@ -17,10 +20,9 @@ import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{
1720
NextExpression,
1821
OperatorAssignment,
1922
RescueExpression,
20-
ReturnExpression,
2123
RubyExpression,
22-
SimpleCall,
2324
SimpleIdentifier,
25+
SimpleObjectInstantiation,
2426
SingleAssignment,
2527
SplattingRubyNode,
2628
StatementList,
@@ -32,18 +34,12 @@ import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{
3234
WhenClause,
3335
WhileExpression
3436
}
35-
import io.joern.rubysrc2cpg.parser.RubyJsonHelpers
37+
import io.joern.rubysrc2cpg.datastructures.BlockScope
3638
import io.joern.rubysrc2cpg.passes.Defines
37-
import io.joern.rubysrc2cpg.passes.Defines.RubyOperators
39+
import io.joern.rubysrc2cpg.passes.Defines.getBuiltInType
3840
import io.joern.x2cpg.{Ast, ValidationMode}
41+
import io.shiftleft.codepropertygraph.generated.nodes.{NewBlock, NewFieldIdentifier, NewLiteral, NewLocal}
3942
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, Operators}
40-
import io.shiftleft.codepropertygraph.generated.nodes.{
41-
NewBlock,
42-
NewFieldIdentifier,
43-
NewIdentifier,
44-
NewLiteral,
45-
NewLocal
46-
}
4743

4844
trait AstForControlStructuresCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator =>
4945

@@ -335,16 +331,67 @@ trait AstForControlStructuresCreator(implicit withSchemaValidation: ValidationMo
335331
ifElseChain.iterator.toList
336332
}
337333

338-
def generatedNode: StatementList = node.expression
339-
.map { e =>
340-
val tmp = SimpleIdentifier(None)(e.span.spanStart(this.tmpGen.fresh))
341-
StatementList(
342-
List(SingleAssignment(tmp, "=", e)(e.span)) ++
343-
goCase(Some(tmp))
344-
)(node.span)
334+
val caseExpr = node.expression
335+
.map {
336+
case arrayLiteral: ArrayLiteral =>
337+
val tmp = SimpleIdentifier(None)(arrayLiteral.span.spanStart(this.tmpGen.fresh))
338+
val arrayLiteralAst = DummyAst(astForTempArray(arrayLiteral))(arrayLiteral.span)
339+
(tmp, arrayLiteralAst)
340+
case e =>
341+
val tmp = SimpleIdentifier(None)(e.span.spanStart(this.tmpGen.fresh))
342+
(tmp, e)
345343
}
344+
.map((tmp, e) => StatementList(List(SingleAssignment(tmp, "=", e)(e.span)) ++ goCase(Some(tmp)))(node.span))
346345
.getOrElse(StatementList(goCase(None))(node.span))
347-
astsForStatement(generatedNode)
346+
347+
astsForStatement(caseExpr)
348+
}
349+
350+
private def astForTempArray(node: ArrayLiteral): Ast = {
351+
val tmp = this.tmpGen.fresh
352+
353+
def tmpRubyNode(tmpNode: Option[RubyExpression] = None) =
354+
SimpleIdentifier()(tmpNode.map(_.span).getOrElse(node.span).spanStart(tmp))
355+
356+
def tmpAst(tmpNode: Option[RubyExpression] = None) = astForSimpleIdentifier(tmpRubyNode(tmpNode))
357+
358+
val block = blockNode(node, node.text, Defines.Any)
359+
scope.pushNewScope(BlockScope(block))
360+
val tmpLocal = NewLocal().name(tmp).code(tmp)
361+
scope.addToScope(tmp, tmpLocal)
362+
363+
val arguments = if (node.text.startsWith("%")) {
364+
val argumentsType =
365+
if (node.isStringArray) getBuiltInType(Defines.String)
366+
else getBuiltInType(Defines.Symbol)
367+
node.elements.map {
368+
case element @ StaticLiteral(_) => StaticLiteral(argumentsType)(element.span)
369+
case element @ DynamicLiteral(_, expressions) => DynamicLiteral(argumentsType, expressions)(element.span)
370+
case element => element
371+
}
372+
} else {
373+
node.elements
374+
}
375+
val argumentAsts = arguments.zipWithIndex.map { case (arg, idx) =>
376+
val indices = StaticLiteral(getBuiltInType(Defines.Integer))(arg.span.spanStart(idx.toString)) :: Nil
377+
val base = tmpRubyNode(Option(arg))
378+
val indexAccess = IndexAccess(base, indices)(arg.span.spanStart(s"${base.text}[$idx]"))
379+
val assignment = SingleAssignment(indexAccess, "=", arg)(arg.span.spanStart(s"${indexAccess.text} = ${arg.text}"))
380+
astForExpression(assignment)
381+
}
382+
383+
val arrayInitCall = {
384+
val base = SimpleIdentifier()(node.span.spanStart(Defines.Array))
385+
astForExpression(SimpleObjectInstantiation(base, Nil)(node.span))
386+
}
387+
388+
val assignment =
389+
callNode(node, code(node), Operators.assignment, Operators.assignment, DispatchTypes.STATIC_DISPATCH)
390+
val tmpAssignment = callAst(assignment, tmpAst() :: arrayInitCall :: Nil)
391+
val tmpRetAst = tmpAst(node.elements.lastOption)
392+
393+
scope.popScope()
394+
blockAst(block, tmpAssignment +: argumentAsts :+ tmpRetAst)
348395
}
349396

350397
private def astForOperatorAssignmentExpression(node: OperatorAssignment): Ast = {

joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
6363
case node: AccessModifier => astForSimpleIdentifier(node.toSimpleIdentifier)
6464
case node: ArrayPattern => astForArrayPattern(node)
6565
case node: DummyNode => Ast(node.node)
66+
case node: DummyAst => node.ast
6667
case node: Unknown => astForUnknown(node)
6768
case x =>
6869
logger.warn(s"Unhandled expression of type ${x.getClass.getSimpleName}")

joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package io.joern.rubysrc2cpg.astcreation
22

33
import io.joern.rubysrc2cpg.passes.{Defines, GlobalTypes}
4+
import io.joern.x2cpg.Ast
45
import io.shiftleft.codepropertygraph.generated.nodes.NewNode
56

67
import java.util.Objects
@@ -621,6 +622,10 @@ object RubyIntermediateAst {
621622
*/
622623
final case class DummyNode(node: NewNode)(span: TextSpan) extends RubyExpression(span)
623624

625+
/** A dummy class for wrapping around `Ast` and allowing it to integrate with RubyNode classes.
626+
*/
627+
final case class DummyAst(ast: Ast)(span: TextSpan) extends RubyExpression(span)
628+
624629
final case class UnaryExpression(op: String, expression: RubyExpression)(span: TextSpan) extends RubyExpression(span)
625630

626631
final case class BinaryExpression(lhs: RubyExpression, op: String, rhs: RubyExpression)(span: TextSpan)

joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CaseTests.scala

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package io.joern.rubysrc2cpg.querying
22

3-
import io.joern.rubysrc2cpg.passes.Defines.RubyOperators
43
import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture
54
import io.shiftleft.codepropertygraph.generated.Operators
65
import io.shiftleft.codepropertygraph.generated.nodes.*
@@ -115,11 +114,11 @@ class CaseTests extends RubyCode2CpgFixture {
115114

116115
val block @ List(_) = cpg.method.name("class_for").block.astChildren.isBlock.l
117116

118-
val List(assign) = block.astChildren.assignment.l;
117+
val assign = block.astChildren.assignment.head
119118
val List(lhs, rhs) = assign.argument.l
120119

121120
lhs.start.isIdentifier.name.l shouldBe List("<tmp-0>")
122-
rhs.start.isCall.code.l shouldBe List("[type, location]")
121+
rhs.start.isBlock.code.l shouldBe List("[type, location]") // array lowering
123122

124123
val headIf @ List(_) = block.astChildren.isControlStructure.l
125124
val ifStmts @ List(_, _, _) = headIf.repeat(_.astChildren.order(3).astChildren.isControlStructure)(_.emit).l;
@@ -165,11 +164,11 @@ class CaseTests extends RubyCode2CpgFixture {
165164

166165
val block @ List(_) = cpg.method.name("class_for").block.astChildren.isBlock.l
167166

168-
val List(assign, _, _) = block.astChildren.assignment.l;
169-
val List(lhs, rhs) = assign.argument.l
167+
val assign = block.astChildren.assignment.head
168+
val List(lhs, rhs) = assign.argument.l
170169

171170
lhs.start.isIdentifier.name.l shouldBe List("<tmp-0>")
172-
rhs.start.isCall.code.l shouldBe List("[type, location]")
171+
rhs.start.isBlock.code.l shouldBe List("[type, location]") // where the array lowering happens
173172

174173
val headIf @ List(_) = block.astChildren.isControlStructure.l
175174
val ifStmts @ List(_, _) = headIf.repeat(_.astChildren.order(3).astChildren.isControlStructure)(_.emit).l;

0 commit comments

Comments
 (0)