Skip to content

Commit 34b3ea6

Browse files
committed
Refactoring for ease of readabilit, formating and removing unneeded code.
1 parent 0af7d2c commit 34b3ea6

File tree

1 file changed

+100
-171
lines changed

1 file changed

+100
-171
lines changed

effekt/shared/src/main/scala/effekt/core/ArityRaising.scala

Lines changed: 100 additions & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,13 @@ object ArityRaising extends Phase[CoreTransformed, CoreTransformed] {
1313
override val phaseName: String = "arity raising"
1414

1515
override def run(input: CoreTransformed)(using C: Context): Option[CoreTransformed] = input match {
16-
case CoreTransformed(source, tree, mod, core) => {
16+
case CoreTransformed(source, tree, mod, core) =>
1717
implicit val pctx: DeclarationContext = new DeclarationContext(core.declarations, core.externs)
1818
Context.module = mod
1919
val main = C.ensureMainExists(mod)
2020
val res = Deadcode.remove(main, core)
21-
// println("Before")
22-
// println(PrettyPrinter.format(res))
2321
val transformed = Context.timed(phaseName, source.name) { transform(res) }
24-
// println("\n\n\n\nhello")
25-
// println(PrettyPrinter.format(transformed))
2622
Some(CoreTransformed(source, tree, mod, transformed))
27-
}
2823
}
2924

3025
def transform(decl: ModuleDecl)(using Context, DeclarationContext): ModuleDecl = decl match {
@@ -36,11 +31,10 @@ object ArityRaising extends Phase[CoreTransformed, CoreTransformed] {
3631
case Toplevel.Def(id, block) => Toplevel.Def(id, transform(block)(using C, DC, Set.empty))
3732
case Toplevel.Val(id, binding) => Toplevel.Val(id, transform(binding)(using C, DC, Set.empty))
3833
}
39-
40-
def transform(block: Block)(using C: Context, DC: DeclarationContext, bargs: Set[Id]): Block = block match {
34+
35+
def transform(block: Block)(using C: Context, DC: DeclarationContext, boundBlockParams: Set[Id]): Block = block match {
4136
case Block.BlockVar(id, annotatedTpe, annotatedCapt) => block
4237
case Block.BlockLit(tparams, cparams, vparams, bparams, body) =>
43-
val newBargs = bargs ++ bparams.map(_.id)
4438
def flattenParam(param: ValueParam): (List[ValueParam], List[(Id, Expr)]) = param match {
4539
case ValueParam(paramId, tpe @ ValueType.Data(name, targs)) =>
4640
DC.findData(name) match {
@@ -53,62 +47,70 @@ object ArityRaising extends Phase[CoreTransformed, CoreTransformed] {
5347

5448
val binding = (paramId, Make(tpe, ctor, List(), fieldVars))
5549
(flatParams.flatten, allBindings.flatten :+ binding)
56-
57-
case _ => (List(param), List())
50+
51+
case _ => (List(param), List())
5852
}
5953
case _ => (List(param), List())
6054
}
6155

6256
val flattened = vparams.map(flattenParam)
6357
val (allParams, allBindings) = flattened.unzip
64-
65-
val newBody = allBindings.flatten.foldRight(transform(body)(using C, DC, newBargs)) {
58+
59+
val newBody = allBindings.flatten.foldRight(transform(body)(using C, DC, boundBlockParams ++ bparams.map(_.id))) {
6660
case ((id, expr), body) => Let(id, expr, body)
6761
}
68-
69-
Block.BlockLit(tparams, cparams, allParams.flatten, bparams, newBody)
70-
case Block.Unbox(pure) => Block.Unbox(transform(pure))
71-
case Block.New(Implementation(interface, operations)) =>
62+
63+
Block.BlockLit(tparams, cparams, allParams.flatten, bparams, newBody)
64+
65+
case Block.Unbox(pure) =>
66+
Block.Unbox(transform(pure))
67+
68+
case Block.New(Implementation(interface, operations)) =>
7269
Block.New(Implementation(interface, operations.map {
7370
case Operation(name, tparams, cparams, vparams, bparams, body) =>
74-
val opBargs = bargs ++ bparams.map(_.id)
75-
Operation(name, tparams, cparams, vparams, bparams, transform(body)(using C, DC, opBargs))
76-
}))
71+
Operation(name, tparams, cparams, vparams, bparams, transform(body)(using C, DC, boundBlockParams ++ bparams.map(_.id)))
72+
}))
7773
}
78-
// Helper to check if a type needs flattening
79-
def needsFlattening(tpe: ValueType)(using DC:DeclarationContext): Boolean = tpe match {
80-
case ValueType.Data(name, _) =>
81-
DC.findData(name) match {
82-
case Some(Data(_, List(), List(Constructor(_, List(), _)))) => true
83-
case _ => false
84-
}
74+
75+
// Helper to check if a type needs flattening
76+
def needsFlattening(tpe: ValueType)(using DC: DeclarationContext): Boolean = tpe match {
77+
case ValueType.Data(name, _) =>
78+
DC.findData(name) match {
79+
case Some(Data(_, List(), List(Constructor(_, List(), _)))) => true
8580
case _ => false
8681
}
82+
case _ => false
83+
}
84+
85+
def wrapBlockVarIfNeeded(barg: BlockVar, annotatedTpe: BlockType)(using C: Context, DC: DeclarationContext, boundBlockParams: Set[Id]): Block =
86+
annotatedTpe match {
87+
case BlockType.Function(tparams, cparams, vparams, bparamTpes, result) if vparams.exists(needsFlattening) =>
88+
val values = vparams.map { tpe =>
89+
val freshId = Id("x")
90+
(ValueParam(freshId, tpe), ValueVar(freshId, tpe))
91+
}
92+
val blocks = bparamTpes.zip(cparams).map { case (tpe, capt) =>
93+
val freshId = Id("f")
94+
(BlockParam(freshId, tpe, Set(capt)), BlockVar(freshId, tpe, Set(capt)))
95+
}
96+
val call = Stmt.App(barg, List(), values.map(_._2), blocks.map(_._2))
97+
BlockLit(tparams, cparams, values.map(_._1), blocks.map(_._1), transform(call)(using C, DC, boundBlockParams ++ blocks.map(_._1.id)))
8798

88-
def transform(stmt: Stmt)(using C: Context, DC: DeclarationContext, bargs: Set[Id]): Stmt = stmt match {
89-
case Stmt.App(callee @ BlockVar(id, BlockType.Function(tparams, cparams, vparamsTypes, bparamTypes, returnTpe), annotatedCapt), targs, vargs, appBargs)
90-
if !bargs.contains(id) =>
99+
case _ => transform(barg)
100+
}
101+
102+
def transform(stmt: Stmt)(using C: Context, DC: DeclarationContext, boundBlockParams: Set[Id]): Stmt = stmt match {
103+
case Stmt.App(callee @ BlockVar(id, BlockType.Function(tparams, cparams, vparamsTypes, bparamTypes, returnTpe), annotatedCapt), targs, vargs, bargs) if !boundBlockParams.contains(id) =>
91104
def flattenArg(arg: Expr, argType: ValueType): (List[Expr], List[ValueType], List[(Expr, Id, List[ValueParam])]) = argType match {
92105
case ValueType.Data(name, targs) =>
93106
DC.findData(name) match {
94107
case Some(Data(_, List(), List(Constructor(ctor, List(), fields)))) =>
95-
val fieldInfo = fields.map { case Field(fieldName, fieldType) =>
96-
val freshId = Id(fieldName)
97-
val freshVar = ValueVar(freshId, fieldType)
98-
val freshParam = ValueParam(freshId, fieldType)
99-
100-
val (nestedVars, nestedTypes, nestedMatches) = flattenArg(freshVar, fieldType)
101-
(freshVar, freshParam, fieldType, nestedVars, nestedTypes, nestedMatches)
102-
}
103-
104-
val vars = fieldInfo.flatMap(_._4)
105-
val types = fieldInfo.flatMap(_._5)
106-
val params = fieldInfo.map(_._2)
107-
val nestedMatches = fieldInfo.flatMap(_._6)
108-
val thisMatch = (arg, ctor, params)
109-
110-
(vars, types, thisMatch :: nestedMatches)
111-
108+
val fieldParams = fields.map { case Field(name, tpe) => ValueParam(Id(name), tpe) }
109+
val nestedResults = fieldParams.map { param => flattenArg(ValueVar(param.id, param.tpe), param.tpe) }
110+
val (nestedVars, nestedTypes, nestedMatches) = nestedResults.unzip3
111+
val thisMatch = (arg, ctor, fieldParams)
112+
(nestedVars.flatten, nestedTypes.flatten, thisMatch :: nestedMatches.flatten)
113+
112114
case _ => (List(arg), List(argType), List())
113115
}
114116
case _ => (List(arg), List(argType), List())
@@ -117,183 +119,110 @@ object ArityRaising extends Phase[CoreTransformed, CoreTransformed] {
117119
val flattened = (vargs zip vparamsTypes).map { case (arg, tpe) => flattenArg(arg, tpe) }
118120
val (allArgs, allTypes, allMatches) = flattened.unzip3
119121

120-
121-
122-
val transformedBargs = appBargs.map { barg =>
122+
val transformedBargs = bargs.map { barg =>
123123
barg match {
124-
// This handles:
125-
// val res = myList.map {myFunc}
124+
// This handles:
125+
// val res = myList.map {myFunc}
126126
// by making it:
127127
// val res = myList.map {t => myFunc(t)}
128128
// but only if the arity of myFunc changes
129-
case BlockVar(id, annotatedTpe, annotatedCapt) =>
130-
annotatedTpe match {
131-
case BlockType.Function(tparams, cparams, vparams, bparamTpes, result)
132-
if vparams.exists(needsFlattening) =>
133-
val values = vparams.map { tpe =>
134-
val freshId = Id("x")
135-
(ValueParam(freshId, tpe), ValueVar(freshId, tpe))
136-
}
137-
val blocks = bparamTpes.zip(cparams).map { case (tpe, capt) =>
138-
val freshId = Id("f")
139-
(BlockParam(freshId, tpe, Set(capt)), BlockVar(freshId, tpe, Set(capt)))
140-
}
141-
142-
// Don't transform the call - the BlockVar keeps its original signature
143-
val call = Stmt.App(barg, List(), values.map(_._2), blocks.map(_._2))
144-
145-
val wrapperBargs = bargs ++ blocks.map(_._1.id)
146-
BlockLit(tparams, cparams, values.map(_._1), blocks.map(_._1), transform(call)(using C, DC, wrapperBargs))
147-
148-
case _ => transform(barg)
149-
}
129+
case bvar @ BlockVar(id, annotatedTpe, annotatedCapt) =>
130+
wrapBlockVarIfNeeded(bvar, annotatedTpe)
150131

151132
case BlockLit(btparams, bcparams, bvparams, bbparams, body) =>
152-
// Keep the signature unchanged
153-
val litBargs = bargs ++ bbparams.map(_.id)
154-
val transformedBody = transform(body)(using C, DC, litBargs)
155-
BlockLit(btparams, bcparams, bvparams, bbparams, transformedBody)
156-
157-
case _ =>
133+
BlockLit(btparams, bcparams, bvparams, bbparams, transform(body)(using C, DC, boundBlockParams ++ bbparams.map(_.id)))
134+
135+
case _ =>
158136
transform(barg)
159137
}
160138
}
161-
139+
162140
val newCalleTpe: BlockType.Function = BlockType.Function(tparams, cparams, allTypes.flatten, bparamTypes, returnTpe)
163141
val newCallee = BlockVar(id, newCalleTpe, annotatedCapt)
164142
val innerApp = Stmt.App(newCallee, targs, allArgs.flatten, transformedBargs)
165-
143+
166144
allMatches.flatten.foldRight(innerApp) {
167145
case ((scrutinee, ctor, params), body) =>
168-
val resultTpe = instantiate(newCalleTpe, targs, appBargs.map(_.capt)).result
146+
val resultTpe = instantiate(newCalleTpe, targs, bargs.map(_.capt)).result
169147
Stmt.Match(scrutinee, resultTpe, List((ctor, BlockLit(List(), List(), params, List(), body))), None)
170148
}
171149

172-
case Stmt.App(callee, targs, vargs, appBargs) =>
173-
Stmt.App(callee, targs, vargs map transform, appBargs map transform)
150+
case Stmt.App(callee, targs, vargs, bargs) =>
151+
Stmt.App(callee, targs, vargs map transform, bargs map transform)
152+
174153
case Stmt.Def(id, block, rest) =>
175154
Stmt.Def(id, transform(block), transform(rest))
155+
176156
case Stmt.Let(id, binding, rest) =>
177157
Stmt.Let(id, transform(binding), transform(rest))
158+
178159
case Stmt.Return(expr) =>
179160
Stmt.Return(transform(expr))
161+
180162
case Stmt.Val(id, binding, body) =>
181163
Stmt.Val(id, transform(binding), transform(body))
182-
case Stmt.Invoke(callee, method, methodTpe, targs, vargs, invokeBargs) =>
183-
Stmt.Invoke(transform(callee), method, methodTpe, targs, vargs map transform, invokeBargs map transform)
164+
165+
case Stmt.Invoke(callee, method, methodTpe, targs, vargs, bargs) =>
166+
Stmt.Invoke(transform(callee), method, methodTpe, targs, vargs map transform, bargs map transform)
167+
184168
case Stmt.If(cond, thn, els) =>
185169
Stmt.If(transform(cond), transform(thn), transform(els))
186170
case Stmt.Match(scrutinee, tpe, clauses, default) =>
187171
Stmt.Match(transform(scrutinee), tpe, clauses.map { case (id, BlockLit(tparams, cparams, vparams, bparams, body)) =>
188-
val clauseBargs = bargs ++ bparams.map(_.id)
189-
(id, BlockLit(tparams, cparams, vparams, bparams, transform(body)(using C, DC, clauseBargs)))
172+
(id, BlockLit(tparams, cparams, vparams, bparams, transform(body)(using C, DC, boundBlockParams ++ bparams.map(_.id))))
190173
}, default map transform)
191-
case Stmt.ImpureApp(id, callee, targs, vargs, impureBargs, body) =>
192-
Stmt.ImpureApp(id, callee, targs, vargs map transform, impureBargs map transform, transform(body))
174+
175+
case Stmt.ImpureApp(id, callee, targs, vargs, bargs, body) =>
176+
Stmt.ImpureApp(id, callee, targs, vargs map transform, bargs map transform, transform(body))
177+
193178
case Stmt.Region(BlockLit(tparams, cparams, vparams, bparams, body)) =>
194-
val regionBargs = bargs ++ bparams.map(_.id)
195-
Stmt.Region(BlockLit(tparams, cparams, vparams, bparams, transform(body)(using C, DC, regionBargs)))
179+
Stmt.Region(BlockLit(tparams, cparams, vparams, bparams, transform(body)(using C, DC, boundBlockParams ++ bparams.map(_.id))))
180+
196181
case Stmt.Alloc(id, init, region, body) =>
197182
Stmt.Alloc(id, transform(init), region, transform(body))
183+
198184
case Stmt.Var(ref, init, capture, body) =>
199185
Stmt.Var(ref, transform(init), capture, transform(body))
186+
200187
case Stmt.Get(id, annotatedTpe, ref, annotatedCapt, body) =>
201188
Stmt.Get(id, annotatedTpe, ref, annotatedCapt, transform(body))
189+
202190
case Stmt.Put(ref, annotatedCapt, value, body) =>
203191
Stmt.Put(ref, annotatedCapt, transform(value), transform(body))
192+
204193
case Stmt.Reset(BlockLit(tparams, cparams, vparams, bparams, body)) =>
205-
val resetBargs = bargs ++ bparams.map(_.id)
206-
Stmt.Reset(BlockLit(tparams, cparams, vparams, bparams, transform(body)(using C, DC, resetBargs)))
194+
Stmt.Reset(BlockLit(tparams, cparams, vparams, bparams, transform(body)(using C, DC, boundBlockParams ++ bparams.map(_.id))))
195+
207196
case Stmt.Shift(prompt, k, body) =>
208-
// k is a continuation (block param), so add it to bargs
209-
val shiftBargs = bargs + k.id
210-
Stmt.Shift(prompt, k, transform(body)(using C, DC, shiftBargs))
197+
// k is a continuation (block param), so add it to boundBlockParams
198+
Stmt.Shift(prompt, k, transform(body)(using C, DC, boundBlockParams + k.id))
199+
211200
case Stmt.Resume(k, body) =>
212201
Stmt.Resume(k, transform(body))
202+
213203
case Stmt.Hole(tpe, span) =>
214204
Stmt.Hole(tpe, span)
215205
}
216206

217-
def transform(pure: Expr)(using C: Context, DC: DeclarationContext, bargs: Set[Id]): Expr = pure match {
207+
def transform(pure: Expr)(using C: Context, DC: DeclarationContext, boundBlockParams: Set[Id]): Expr = pure match {
218208
case Expr.ValueVar(id, annotatedType) => pure
219-
case Expr.Literal(value, annotatedType) => pure
220-
case Expr.Box(barg @ BlockVar(id, annotatedTpe, annotatedCapt), annotatedCapture) =>
221-
annotatedTpe match {
222-
case BlockType.Function(tparams, cparams, vparams, bparamTpes, result)
223-
if vparams.exists(needsFlattening) =>
224-
val values = vparams.map { tpe =>
225-
val freshId = Id("x")
226-
(ValueParam(freshId, tpe), ValueVar(freshId, tpe))
227-
}
228-
val blocks = bparamTpes.zip(cparams).map { case (tpe, capt) =>
229-
val freshId = Id("f")
230-
(BlockParam(freshId, tpe, Set(capt)), BlockVar(freshId, tpe, Set(capt)))
231-
}
232-
233-
// Don't transform the call - the BlockVar keeps its original signature
234-
val call = Stmt.App(barg, List(), values.map(_._2), blocks.map(_._2))
235-
236-
val wrapperBargs = bargs ++ blocks.map(_._1.id)
237-
Expr.Box(BlockLit(tparams, cparams, values.map(_._1), blocks.map(_._1), transform(call)(using C, DC, wrapperBargs)), annotatedCapture)
238-
239-
case _ => Expr.Box(transform(barg), annotatedCapture)
240-
}
241-
242-
case Expr.Box(b, annotatedCapture) => Expr.Box(transform(b), annotatedCapture)
243-
case Expr.PureApp(b, targs, vargs) =>
244-
Expr.PureApp(b, targs, vargs map transform)
245-
case Expr.Make(data, tag, targs, vargs) =>
246-
Expr.Make(data, tag, targs, vargs map transform)
247-
}
248209

249-
def transform(valueType: ValueType.Data)(using C: Context, DC: DeclarationContext): ValueType.Data = valueType match {
250-
case ValueType.Data(symbol, targs) => valueType // trainsform
251-
}
252-
253-
def doIndentation(input: String): String = {
254-
val sb = new StringBuilder
255-
var indent = 0
256-
var i = 0
257-
258-
while (i < input.length) {
259-
input(i) match {
260-
case '(' =>
261-
// Look ahead to see if it's a short List(...) with no commas
262-
val closing = input.indexOf(')', i)
263-
val inside = if (closing > i) input.substring(i + 1, closing) else ""
264-
if (inside.contains(',') || inside.contains('(') || inside.contains(')')) {
265-
sb.append("(\n")
266-
indent += 1
267-
sb.append(" " * indent)
268-
} else {
269-
sb.append('(')
270-
}
271-
272-
case ')' =>
273-
val prev = if (i > 0) input(i - 1) else ' '
274-
if (prev == '(' || prev.isLetterOrDigit) {
275-
sb.append(')')
276-
} else {
277-
sb.append("\n")
278-
indent -= 1
279-
sb.append(" " * indent)
280-
sb.append(")")
281-
}
210+
case Expr.Literal(value, annotatedType) => pure
282211

283-
case ',' =>
284-
sb.append(",\n")
285-
sb.append(" " * indent)
212+
case Expr.Box(bvar @ BlockVar(id, annotatedTpe, annotatedCapt), annotatedCapture) =>
213+
Expr.Box(wrapBlockVarIfNeeded(bvar, annotatedTpe), annotatedCapture)
286214

287-
case c if c.isWhitespace =>
288-
// skip
215+
case Expr.Box(b, annotatedCapture) =>
216+
Expr.Box(transform(b), annotatedCapture)
289217

290-
case c =>
291-
sb.append(c)
292-
}
293-
i += 1
294-
}
218+
case Expr.PureApp(b, targs, vargs) =>
219+
Expr.PureApp(b, targs, vargs map transform)
295220

296-
sb.toString
221+
case Expr.Make(data, tag, targs, vargs) =>
222+
Expr.Make(data, tag, targs, vargs map transform)
297223
}
298224

225+
def transform(valueType: ValueType.Data)(using C: Context, DC: DeclarationContext): ValueType.Data = valueType match {
226+
case ValueType.Data(symbol, targs) => valueType
227+
}
299228
}

0 commit comments

Comments
 (0)