@@ -13,18 +13,13 @@ object ArityRaising extends Phase[CoreTransformed, CoreTransformed] {
1313 override val phaseName : String = " arity raising"
1414
1515 override def run (input : CoreTransformed )(using C : Context ): Option [CoreTransformed ] = input match {
16- case CoreTransformed (source, tree, mod, core) => {
16+ case CoreTransformed (source, tree, mod, core) =>
1717 implicit val pctx : DeclarationContext = new DeclarationContext (core.declarations, core.externs)
1818 Context .module = mod
1919 val main = C .ensureMainExists(mod)
2020 val res = Deadcode .remove(main, core)
21- // println("Before")
22- // println(PrettyPrinter.format(res))
2321 val transformed = Context .timed(phaseName, source.name) { transform(res) }
24- // println("\n\n\n\nhello")
25- // println(PrettyPrinter.format(transformed))
2622 Some (CoreTransformed (source, tree, mod, transformed))
27- }
2823 }
2924
3025 def transform (decl : ModuleDecl )(using Context , DeclarationContext ): ModuleDecl = decl match {
@@ -36,11 +31,10 @@ object ArityRaising extends Phase[CoreTransformed, CoreTransformed] {
3631 case Toplevel .Def (id, block) => Toplevel .Def (id, transform(block)(using C , DC , Set .empty))
3732 case Toplevel .Val (id, binding) => Toplevel .Val (id, transform(binding)(using C , DC , Set .empty))
3833 }
39-
40- def transform (block : Block )(using C : Context , DC : DeclarationContext , bargs : Set [Id ]): Block = block match {
34+
35+ def transform (block : Block )(using C : Context , DC : DeclarationContext , boundBlockParams : Set [Id ]): Block = block match {
4136 case Block .BlockVar (id, annotatedTpe, annotatedCapt) => block
4237 case Block .BlockLit (tparams, cparams, vparams, bparams, body) =>
43- val newBargs = bargs ++ bparams.map(_.id)
4438 def flattenParam (param : ValueParam ): (List [ValueParam ], List [(Id , Expr )]) = param match {
4539 case ValueParam (paramId, tpe @ ValueType .Data (name, targs)) =>
4640 DC .findData(name) match {
@@ -53,62 +47,70 @@ object ArityRaising extends Phase[CoreTransformed, CoreTransformed] {
5347
5448 val binding = (paramId, Make (tpe, ctor, List (), fieldVars))
5549 (flatParams.flatten, allBindings.flatten :+ binding)
56-
57- case _ => (List (param), List ())
50+
51+ case _ => (List (param), List ())
5852 }
5953 case _ => (List (param), List ())
6054 }
6155
6256 val flattened = vparams.map(flattenParam)
6357 val (allParams, allBindings) = flattened.unzip
64-
65- val newBody = allBindings.flatten.foldRight(transform(body)(using C , DC , newBargs )) {
58+
59+ val newBody = allBindings.flatten.foldRight(transform(body)(using C , DC , boundBlockParams ++ bparams.map(_.id) )) {
6660 case ((id, expr), body) => Let (id, expr, body)
6761 }
68-
69- Block .BlockLit (tparams, cparams, allParams.flatten, bparams, newBody)
70- case Block .Unbox (pure) => Block .Unbox (transform(pure))
71- case Block .New (Implementation (interface, operations)) =>
62+
63+ Block .BlockLit (tparams, cparams, allParams.flatten, bparams, newBody)
64+
65+ case Block .Unbox (pure) =>
66+ Block .Unbox (transform(pure))
67+
68+ case Block .New (Implementation (interface, operations)) =>
7269 Block .New (Implementation (interface, operations.map {
7370 case Operation (name, tparams, cparams, vparams, bparams, body) =>
74- val opBargs = bargs ++ bparams.map(_.id)
75- Operation (name, tparams, cparams, vparams, bparams, transform(body)(using C , DC , opBargs))
76- }))
71+ Operation (name, tparams, cparams, vparams, bparams, transform(body)(using C , DC , boundBlockParams ++ bparams.map(_.id)))
72+ }))
7773 }
78- // Helper to check if a type needs flattening
79- def needsFlattening (tpe : ValueType )(using DC : DeclarationContext ): Boolean = tpe match {
80- case ValueType .Data (name, _) =>
81- DC .findData(name) match {
82- case Some (Data (_, List (), List (Constructor (_, List (), _)))) => true
83- case _ => false
84- }
74+
75+ // Helper to check if a type needs flattening
76+ def needsFlattening (tpe : ValueType )(using DC : DeclarationContext ): Boolean = tpe match {
77+ case ValueType .Data (name, _) =>
78+ DC .findData(name) match {
79+ case Some (Data (_, List (), List (Constructor (_, List (), _)))) => true
8580 case _ => false
8681 }
82+ case _ => false
83+ }
84+
85+ def wrapBlockVarIfNeeded (barg : BlockVar , annotatedTpe : BlockType )(using C : Context , DC : DeclarationContext , boundBlockParams : Set [Id ]): Block =
86+ annotatedTpe match {
87+ case BlockType .Function (tparams, cparams, vparams, bparamTpes, result) if vparams.exists(needsFlattening) =>
88+ val values = vparams.map { tpe =>
89+ val freshId = Id (" x" )
90+ (ValueParam (freshId, tpe), ValueVar (freshId, tpe))
91+ }
92+ val blocks = bparamTpes.zip(cparams).map { case (tpe, capt) =>
93+ val freshId = Id (" f" )
94+ (BlockParam (freshId, tpe, Set (capt)), BlockVar (freshId, tpe, Set (capt)))
95+ }
96+ val call = Stmt .App (barg, List (), values.map(_._2), blocks.map(_._2))
97+ BlockLit (tparams, cparams, values.map(_._1), blocks.map(_._1), transform(call)(using C , DC , boundBlockParams ++ blocks.map(_._1.id)))
8798
88- def transform (stmt : Stmt )(using C : Context , DC : DeclarationContext , bargs : Set [Id ]): Stmt = stmt match {
89- case Stmt .App (callee @ BlockVar (id, BlockType .Function (tparams, cparams, vparamsTypes, bparamTypes, returnTpe), annotatedCapt), targs, vargs, appBargs)
90- if ! bargs.contains(id) =>
99+ case _ => transform(barg)
100+ }
101+
102+ def transform (stmt : Stmt )(using C : Context , DC : DeclarationContext , boundBlockParams : Set [Id ]): Stmt = stmt match {
103+ case Stmt .App (callee @ BlockVar (id, BlockType .Function (tparams, cparams, vparamsTypes, bparamTypes, returnTpe), annotatedCapt), targs, vargs, bargs) if ! boundBlockParams.contains(id) =>
91104 def flattenArg (arg : Expr , argType : ValueType ): (List [Expr ], List [ValueType ], List [(Expr , Id , List [ValueParam ])]) = argType match {
92105 case ValueType .Data (name, targs) =>
93106 DC .findData(name) match {
94107 case Some (Data (_, List (), List (Constructor (ctor, List (), fields)))) =>
95- val fieldInfo = fields.map { case Field (fieldName, fieldType) =>
96- val freshId = Id (fieldName)
97- val freshVar = ValueVar (freshId, fieldType)
98- val freshParam = ValueParam (freshId, fieldType)
99-
100- val (nestedVars, nestedTypes, nestedMatches) = flattenArg(freshVar, fieldType)
101- (freshVar, freshParam, fieldType, nestedVars, nestedTypes, nestedMatches)
102- }
103-
104- val vars = fieldInfo.flatMap(_._4)
105- val types = fieldInfo.flatMap(_._5)
106- val params = fieldInfo.map(_._2)
107- val nestedMatches = fieldInfo.flatMap(_._6)
108- val thisMatch = (arg, ctor, params)
109-
110- (vars, types, thisMatch :: nestedMatches)
111-
108+ val fieldParams = fields.map { case Field (name, tpe) => ValueParam (Id (name), tpe) }
109+ val nestedResults = fieldParams.map { param => flattenArg(ValueVar (param.id, param.tpe), param.tpe) }
110+ val (nestedVars, nestedTypes, nestedMatches) = nestedResults.unzip3
111+ val thisMatch = (arg, ctor, fieldParams)
112+ (nestedVars.flatten, nestedTypes.flatten, thisMatch :: nestedMatches.flatten)
113+
112114 case _ => (List (arg), List (argType), List ())
113115 }
114116 case _ => (List (arg), List (argType), List ())
@@ -117,183 +119,110 @@ object ArityRaising extends Phase[CoreTransformed, CoreTransformed] {
117119 val flattened = (vargs zip vparamsTypes).map { case (arg, tpe) => flattenArg(arg, tpe) }
118120 val (allArgs, allTypes, allMatches) = flattened.unzip3
119121
120-
121-
122- val transformedBargs = appBargs.map { barg =>
122+ val transformedBargs = bargs.map { barg =>
123123 barg match {
124- // This handles:
125- // val res = myList.map {myFunc}
124+ // This handles:
125+ // val res = myList.map {myFunc}
126126 // by making it:
127127 // val res = myList.map {t => myFunc(t)}
128128 // but only if the arity of myFunc changes
129- case BlockVar (id, annotatedTpe, annotatedCapt) =>
130- annotatedTpe match {
131- case BlockType .Function (tparams, cparams, vparams, bparamTpes, result)
132- if vparams.exists(needsFlattening) =>
133- val values = vparams.map { tpe =>
134- val freshId = Id (" x" )
135- (ValueParam (freshId, tpe), ValueVar (freshId, tpe))
136- }
137- val blocks = bparamTpes.zip(cparams).map { case (tpe, capt) =>
138- val freshId = Id (" f" )
139- (BlockParam (freshId, tpe, Set (capt)), BlockVar (freshId, tpe, Set (capt)))
140- }
141-
142- // Don't transform the call - the BlockVar keeps its original signature
143- val call = Stmt .App (barg, List (), values.map(_._2), blocks.map(_._2))
144-
145- val wrapperBargs = bargs ++ blocks.map(_._1.id)
146- BlockLit (tparams, cparams, values.map(_._1), blocks.map(_._1), transform(call)(using C , DC , wrapperBargs))
147-
148- case _ => transform(barg)
149- }
129+ case bvar @ BlockVar (id, annotatedTpe, annotatedCapt) =>
130+ wrapBlockVarIfNeeded(bvar, annotatedTpe)
150131
151132 case BlockLit (btparams, bcparams, bvparams, bbparams, body) =>
152- // Keep the signature unchanged
153- val litBargs = bargs ++ bbparams.map(_.id)
154- val transformedBody = transform(body)(using C , DC , litBargs)
155- BlockLit (btparams, bcparams, bvparams, bbparams, transformedBody)
156-
157- case _ =>
133+ BlockLit (btparams, bcparams, bvparams, bbparams, transform(body)(using C , DC , boundBlockParams ++ bbparams.map(_.id)))
134+
135+ case _ =>
158136 transform(barg)
159137 }
160138 }
161-
139+
162140 val newCalleTpe : BlockType .Function = BlockType .Function (tparams, cparams, allTypes.flatten, bparamTypes, returnTpe)
163141 val newCallee = BlockVar (id, newCalleTpe, annotatedCapt)
164142 val innerApp = Stmt .App (newCallee, targs, allArgs.flatten, transformedBargs)
165-
143+
166144 allMatches.flatten.foldRight(innerApp) {
167145 case ((scrutinee, ctor, params), body) =>
168- val resultTpe = instantiate(newCalleTpe, targs, appBargs .map(_.capt)).result
146+ val resultTpe = instantiate(newCalleTpe, targs, bargs .map(_.capt)).result
169147 Stmt .Match (scrutinee, resultTpe, List ((ctor, BlockLit (List (), List (), params, List (), body))), None )
170148 }
171149
172- case Stmt .App (callee, targs, vargs, appBargs) =>
173- Stmt .App (callee, targs, vargs map transform, appBargs map transform)
150+ case Stmt .App (callee, targs, vargs, bargs) =>
151+ Stmt .App (callee, targs, vargs map transform, bargs map transform)
152+
174153 case Stmt .Def (id, block, rest) =>
175154 Stmt .Def (id, transform(block), transform(rest))
155+
176156 case Stmt .Let (id, binding, rest) =>
177157 Stmt .Let (id, transform(binding), transform(rest))
158+
178159 case Stmt .Return (expr) =>
179160 Stmt .Return (transform(expr))
161+
180162 case Stmt .Val (id, binding, body) =>
181163 Stmt .Val (id, transform(binding), transform(body))
182- case Stmt .Invoke (callee, method, methodTpe, targs, vargs, invokeBargs) =>
183- Stmt .Invoke (transform(callee), method, methodTpe, targs, vargs map transform, invokeBargs map transform)
164+
165+ case Stmt .Invoke (callee, method, methodTpe, targs, vargs, bargs) =>
166+ Stmt .Invoke (transform(callee), method, methodTpe, targs, vargs map transform, bargs map transform)
167+
184168 case Stmt .If (cond, thn, els) =>
185169 Stmt .If (transform(cond), transform(thn), transform(els))
186170 case Stmt .Match (scrutinee, tpe, clauses, default) =>
187171 Stmt .Match (transform(scrutinee), tpe, clauses.map { case (id, BlockLit (tparams, cparams, vparams, bparams, body)) =>
188- val clauseBargs = bargs ++ bparams.map(_.id)
189- (id, BlockLit (tparams, cparams, vparams, bparams, transform(body)(using C , DC , clauseBargs)))
172+ (id, BlockLit (tparams, cparams, vparams, bparams, transform(body)(using C , DC , boundBlockParams ++ bparams.map(_.id))))
190173 }, default map transform)
191- case Stmt .ImpureApp (id, callee, targs, vargs, impureBargs, body) =>
192- Stmt .ImpureApp (id, callee, targs, vargs map transform, impureBargs map transform, transform(body))
174+
175+ case Stmt .ImpureApp (id, callee, targs, vargs, bargs, body) =>
176+ Stmt .ImpureApp (id, callee, targs, vargs map transform, bargs map transform, transform(body))
177+
193178 case Stmt .Region (BlockLit (tparams, cparams, vparams, bparams, body)) =>
194- val regionBargs = bargs ++ bparams.map(_.id)
195- Stmt . Region ( BlockLit (tparams, cparams, vparams, bparams, transform(body)( using C , DC , regionBargs)))
179+ Stmt . Region ( BlockLit (tparams, cparams, vparams, bparams, transform(body)( using C , DC , boundBlockParams ++ bparams.map(_.id))) )
180+
196181 case Stmt .Alloc (id, init, region, body) =>
197182 Stmt .Alloc (id, transform(init), region, transform(body))
183+
198184 case Stmt .Var (ref, init, capture, body) =>
199185 Stmt .Var (ref, transform(init), capture, transform(body))
186+
200187 case Stmt .Get (id, annotatedTpe, ref, annotatedCapt, body) =>
201188 Stmt .Get (id, annotatedTpe, ref, annotatedCapt, transform(body))
189+
202190 case Stmt .Put (ref, annotatedCapt, value, body) =>
203191 Stmt .Put (ref, annotatedCapt, transform(value), transform(body))
192+
204193 case Stmt .Reset (BlockLit (tparams, cparams, vparams, bparams, body)) =>
205- val resetBargs = bargs ++ bparams.map(_.id)
206- Stmt . Reset ( BlockLit (tparams, cparams, vparams, bparams, transform(body)( using C , DC , resetBargs)))
194+ Stmt . Reset ( BlockLit (tparams, cparams, vparams, bparams, transform(body)( using C , DC , boundBlockParams ++ bparams.map(_.id))) )
195+
207196 case Stmt .Shift (prompt, k, body) =>
208- // k is a continuation (block param), so add it to bargs
209- val shiftBargs = bargs + k.id
210- Stmt . Shift (prompt, k, transform(body)( using C , DC , shiftBargs))
197+ // k is a continuation (block param), so add it to boundBlockParams
198+ Stmt . Shift (prompt, k, transform(body)( using C , DC , boundBlockParams + k.id))
199+
211200 case Stmt .Resume (k, body) =>
212201 Stmt .Resume (k, transform(body))
202+
213203 case Stmt .Hole (tpe, span) =>
214204 Stmt .Hole (tpe, span)
215205 }
216206
217- def transform (pure : Expr )(using C : Context , DC : DeclarationContext , bargs : Set [Id ]): Expr = pure match {
207+ def transform (pure : Expr )(using C : Context , DC : DeclarationContext , boundBlockParams : Set [Id ]): Expr = pure match {
218208 case Expr .ValueVar (id, annotatedType) => pure
219- case Expr .Literal (value, annotatedType) => pure
220- case Expr .Box (barg @ BlockVar (id, annotatedTpe, annotatedCapt), annotatedCapture) =>
221- annotatedTpe match {
222- case BlockType .Function (tparams, cparams, vparams, bparamTpes, result)
223- if vparams.exists(needsFlattening) =>
224- val values = vparams.map { tpe =>
225- val freshId = Id (" x" )
226- (ValueParam (freshId, tpe), ValueVar (freshId, tpe))
227- }
228- val blocks = bparamTpes.zip(cparams).map { case (tpe, capt) =>
229- val freshId = Id (" f" )
230- (BlockParam (freshId, tpe, Set (capt)), BlockVar (freshId, tpe, Set (capt)))
231- }
232-
233- // Don't transform the call - the BlockVar keeps its original signature
234- val call = Stmt .App (barg, List (), values.map(_._2), blocks.map(_._2))
235-
236- val wrapperBargs = bargs ++ blocks.map(_._1.id)
237- Expr .Box (BlockLit (tparams, cparams, values.map(_._1), blocks.map(_._1), transform(call)(using C , DC , wrapperBargs)), annotatedCapture)
238-
239- case _ => Expr .Box (transform(barg), annotatedCapture)
240- }
241-
242- case Expr .Box (b, annotatedCapture) => Expr .Box (transform(b), annotatedCapture)
243- case Expr .PureApp (b, targs, vargs) =>
244- Expr .PureApp (b, targs, vargs map transform)
245- case Expr .Make (data, tag, targs, vargs) =>
246- Expr .Make (data, tag, targs, vargs map transform)
247- }
248209
249- def transform (valueType : ValueType .Data )(using C : Context , DC : DeclarationContext ): ValueType .Data = valueType match {
250- case ValueType .Data (symbol, targs) => valueType // trainsform
251- }
252-
253- def doIndentation (input : String ): String = {
254- val sb = new StringBuilder
255- var indent = 0
256- var i = 0
257-
258- while (i < input.length) {
259- input(i) match {
260- case '(' =>
261- // Look ahead to see if it's a short List(...) with no commas
262- val closing = input.indexOf(')' , i)
263- val inside = if (closing > i) input.substring(i + 1 , closing) else " "
264- if (inside.contains(',' ) || inside.contains('(' ) || inside.contains(')' )) {
265- sb.append(" (\n " )
266- indent += 1
267- sb.append(" " * indent)
268- } else {
269- sb.append('(' )
270- }
271-
272- case ')' =>
273- val prev = if (i > 0 ) input(i - 1 ) else ' '
274- if (prev == '(' || prev.isLetterOrDigit) {
275- sb.append(')' )
276- } else {
277- sb.append(" \n " )
278- indent -= 1
279- sb.append(" " * indent)
280- sb.append(" )" )
281- }
210+ case Expr .Literal (value, annotatedType) => pure
282211
283- case ',' =>
284- sb.append(" ,\n " )
285- sb.append(" " * indent)
212+ case Expr .Box (bvar @ BlockVar (id, annotatedTpe, annotatedCapt), annotatedCapture) =>
213+ Expr .Box (wrapBlockVarIfNeeded(bvar, annotatedTpe), annotatedCapture)
286214
287- case c if c.isWhitespace =>
288- // skip
215+ case Expr . Box (b, annotatedCapture) =>
216+ Expr . Box (transform(b), annotatedCapture)
289217
290- case c =>
291- sb.append(c)
292- }
293- i += 1
294- }
218+ case Expr .PureApp (b, targs, vargs) =>
219+ Expr .PureApp (b, targs, vargs map transform)
295220
296- sb.toString
221+ case Expr .Make (data, tag, targs, vargs) =>
222+ Expr .Make (data, tag, targs, vargs map transform)
297223 }
298224
225+ def transform (valueType : ValueType .Data )(using C : Context , DC : DeclarationContext ): ValueType .Data = valueType match {
226+ case ValueType .Data (symbol, targs) => valueType
227+ }
299228}
0 commit comments