@@ -69,13 +69,8 @@ const Pi* isReturning(const Pi* pi){
69
69
return nullptr ;
70
70
}
71
71
72
- DefArray vars_without_mem_cont (World& world, Lam* lam) {
73
- return {
74
- lam->num_vars ()-( isReturning (lam->type ()) == nullptr ? 1 : 2 ),
75
- [&](auto i) {
76
- return lam->var (i+1 );
77
- }
78
- };
72
+ DefArray vars_without_mem_cont (Lam* lam) {
73
+ return lam->vars ().skip (1 , isReturning (lam->type ()) != nullptr );
79
74
}
80
75
// multidimensional addition of values
81
76
// needed for operation differentiation
@@ -100,7 +95,7 @@ const Lam* vec_add(World& world, const Def* a, const Def* b, const Def* cont) {
100
95
auto sum_cont = vec_add (world,a_v,b_v,res_cont);
101
96
sum_pb->set_body (world.app (sum_cont, mem3));
102
97
auto rmem=res_cont->mem_var ();
103
- auto s_v= world.tuple (vars_without_mem_cont (world, res_cont));
98
+ auto s_v= world.tuple (vars_without_mem_cont (res_cont));
104
99
auto [rmem2, sum_ptr]=world.op_slot (ty,rmem,world.dbg (" add_slot" ))->projs <2 >();
105
100
auto rmem3 = world.op_store (rmem2,sum_ptr,s_v);
106
101
@@ -151,7 +146,7 @@ const Lam* vec_add(World& world, const Def* a, const Def* b, const Def* cont) {
151
146
auto elem_res_cont_type = world.cn_mem_flat (a_v->type ());
152
147
auto elem_res_cont = world.nom_filter_lam (elem_res_cont_type,world.dbg (" tuple_add_cont" ));
153
148
auto element_sum_pb = vec_add (world,a_v,b_v,elem_res_cont);
154
- auto c_v = world.tuple (vars_without_mem_cont (world, elem_res_cont));
149
+ auto c_v = world.tuple (vars_without_mem_cont (elem_res_cont));
155
150
auto res_mem=elem_res_cont->mem_var ();
156
151
res_mem=world.op_store (res_mem,c_p,c_v);
157
152
@@ -208,7 +203,7 @@ const Lam* vec_add(World& world, const Def* a, const Def* b, const Def* cont) {
208
203
auto res_cont_type = world.cn_mem_flat (ai->type ());
209
204
auto res_cont = world.nom_filter_lam (res_cont_type,world.dbg (" tuple_add_cont" ));
210
205
auto sum_call=vec_add (world,ai,bi,res_cont);
211
- ops[i]=world.tuple (vars_without_mem_cont (world, res_cont));
206
+ ops[i]=world.tuple (vars_without_mem_cont (res_cont));
212
207
213
208
current_cont->set_body (world.app (
214
209
sum_call,
@@ -278,22 +273,19 @@ std::pair<const Def*,const Def*> lit_of_type(World& world, const Def* mem, const
278
273
litdef= world.lit_real (as_lit (real->arg ()), lit);
279
274
else if (auto a = type->isa <Arr>()) {
280
275
auto dim = a->shape ()->as <Lit>()->get <uint8_t >();
281
- DefArray ops{dim};
282
- for (size_t i = 0 ; i < dim; ++i) {
283
- auto [nmem, op]=lit_of_type (world,mem,a->body (),like,lit,dummy);
276
+ DefArray ops{dim, [&](auto ){
277
+ auto [nmem, op] = lit_of_type (world,mem,a->body (),like,lit,dummy);
284
278
mem=nmem;
285
- ops[i]= op;
286
- }
279
+ return op;
280
+ }};
287
281
litdef= world.tuple (ops);
288
282
}else if (auto sig = type->isa <Sigma>()) {
289
- std::vector<const Def*> zops;
290
- int idx=0 ;
291
- for (auto op : sig->ops ()) {
292
- auto [nmem, zop]=lit_of_type (world,mem,op,like->proj (idx),lit,dummy);
283
+ auto zops = sig->ops ().map ([&](auto op, auto index){
284
+ auto [nmem, zop]=lit_of_type (world,mem,op,like->proj (index),lit,dummy);
293
285
mem=nmem;
294
- zops. push_back ( zop) ;
295
- idx++ ;
296
- }
286
+ return zop;
287
+ }) ;
288
+
297
289
litdef= world.tuple (zops);
298
290
}
299
291
else litdef= dummy;
@@ -447,22 +439,12 @@ const Def* AutoDiffer::j_wrap_tuple(DefArray tuple) {
447
439
// apply them with the component of the scalar from the tuple pullback
448
440
// sum them up
449
441
450
- size_t real_arg_num;
451
- if (isRetTuple)
452
- real_arg_num=tuple_dim-2 ;
453
- else if (isMemTuple)
454
- real_arg_num=tuple_dim-1 ;
455
- else
456
- real_arg_num=tuple_dim;
457
-
458
- // const Def* trimmed_ty;
459
- // auto tuple_ty = tuple->type();
460
- auto trimmed_var_ty=DefArray (real_arg_num,
461
- [&] (auto i) {
462
- return tuple[isMemTuple ? i+1 : i]->type ();
463
- });
464
-
465
- auto trimmed_ty=world_.sigma (trimmed_var_ty);
442
+ auto trimmed_tuple = tuple.skip (isMemTuple, isRetTuple);
443
+ auto trimed_ops = ops.skip (isMemTuple, isRetTuple);
444
+
445
+ auto trimmed_ty=world_.sigma (
446
+ trimmed_tuple.map ( [] (auto * def, auto ) { return def->type (); } )
447
+ );
466
448
auto pi = createPbType (A,trimmed_ty);
467
449
auto pb = world_.nom_filter_lam (pi, world_.dbg (" tuple_pb" ));
468
450
auto pbT = pi->as <Pi>()->doms ().back ()->as <Pi>();
@@ -472,27 +454,18 @@ const Def* AutoDiffer::j_wrap_tuple(DefArray tuple) {
472
454
flat_tuple ({
473
455
pb->mem_var (),
474
456
zero_grad
475
- }) ));
476
-
477
- auto tuple_of_pb = world_.tuple (
478
- DefArray{real_arg_num, [&](auto i) { return pullbacks_[isMemTuple ? ops[i+1 ] : ops[i]]; }}
479
- );
457
+ })
458
+ ));
480
459
481
460
/* *
482
461
* pb = \lambda mem scalars ret. sum_pb_0 (mem,0)
483
462
* sum_pb_i = \lambda mem sum_i. pb_i (mem, s_i, res_pb_i)
484
463
* res_pb_i = \lambda mem res_i. sum_cont (mem, sum_i, res_i, sum_pb_{i+1})
485
464
* sum_pb_n = \lambda mem sum. ret (mem, sum)
486
465
*/
487
- for (size_t i = 0 ; i < real_arg_num; ++i) {
488
-
489
- const Def* op;
490
- if (isMemTuple) {
491
- op=ops[i+1 ];
492
- }else {
493
- op=ops[i];
494
- }
495
- auto op_pb=pullbacks_[op];
466
+ for (size_t i = 0 ; i < trimed_ops.size (); ++i) {
467
+ const Def* op = trimed_ops[i];
468
+ auto op_pb = pullbacks_[op];
496
469
auto scalar = pb->var (i+1 , world_.dbg (" s" ));
497
470
498
471
auto res_pb = world_.nom_filter_lam (pbT, world_.dbg (" res_pb" ));
@@ -502,13 +475,14 @@ const Def* AutoDiffer::j_wrap_tuple(DefArray tuple) {
502
475
current_sum_pb->mem_var (),
503
476
scalar,
504
477
res_pb
505
- })));
478
+ })
479
+ ));
506
480
507
481
auto next_current_sum_pb = world_.nom_filter_lam (pbT, world_.dbg (" tuple_sum_pb" ));
508
482
509
483
auto sum_cont_pb = vec_add (world_,
510
- world_.tuple (vars_without_mem_cont (world_, current_sum_pb)),
511
- world_.tuple (vars_without_mem_cont (world_, res_pb)),
484
+ world_.tuple (vars_without_mem_cont (current_sum_pb)),
485
+ world_.tuple (vars_without_mem_cont (res_pb)),
512
486
next_current_sum_pb);
513
487
res_pb->set_body (world_.app (
514
488
sum_cont_pb,
@@ -546,8 +520,8 @@ const Def* AutoDiffer::chain(const Def* a, const Def* b) {
546
520
auto middlepi = world_.cn_mem_flat (B);
547
521
auto middle = world_.nom_filter_lam (middlepi, world_.dbg (" chain_2" ));
548
522
549
- toplevel->set_body (world_.app (a, flat_tuple ({toplevel->mem_var (), world_.tuple (vars_without_mem_cont (world_, toplevel)), middle})));
550
- middle->set_body (world_.app (b, flat_tuple ({middle->mem_var (), world_.tuple (vars_without_mem_cont (world_, middle)), toplevel->ret_var ()})));
523
+ toplevel->set_body (world_.app (a, flat_tuple ({toplevel->mem_var (), world_.tuple (vars_without_mem_cont (toplevel)), middle})));
524
+ middle->set_body (world_.app (b, flat_tuple ({middle->mem_var (), world_.tuple (vars_without_mem_cont (middle)), toplevel->ret_var ()})));
551
525
552
526
return toplevel;
553
527
}
@@ -602,7 +576,7 @@ const Def* AutoDiffer::extract_pb(const Def* j_extract, const Def* tuple) {
602
576
else if (i==dim-1 ) {
603
577
args[i]=pb->ret_var ();
604
578
} else if (i==index_lit) {
605
- args[i]= world_.tuple (vars_without_mem_cont (world_, pb));
579
+ args[i]= world_.tuple (vars_without_mem_cont (pb));
606
580
}else {
607
581
// TODO: correct index
608
582
auto [nmem, v]=ZERO (world_,mem,pb_domain->op (i), tuple->proj (i));
@@ -612,7 +586,6 @@ const Def* AutoDiffer::extract_pb(const Def* j_extract, const Def* tuple) {
612
586
}
613
587
args[0 ]=mem;
614
588
pb_args=args;
615
-
616
589
}else {
617
590
auto [rmem, ohv] = oneHot (world_,pb->mem_var (), idx,world_.tangent_type (tuple_ty,false ),nullptr ,pb->var (1 ,world_.dbg (" s" )));
618
591
pb_args=
@@ -625,7 +598,7 @@ const Def* AutoDiffer::extract_pb(const Def* j_extract, const Def* tuple) {
625
598
pb->set_body (world_.app (
626
599
tuple_pb,
627
600
pb_args
628
- ));
601
+ ));
629
602
return pb;
630
603
}
631
604
// loads pb from shadow slot, updates pb for the ptr, returns, mem and pb for the loaded value
@@ -645,41 +618,23 @@ const Def* AutoDiffer::reverse_diff(Lam* src) {
645
618
auto dst_var = src_to_dst_[src_var];
646
619
auto var_sigma = src_var->type ()->as <Sigma>();
647
620
648
- auto size = var_sigma->num_ops () - 2 ;
649
- DefArray trimmed_var_ty (size);
650
- for (size_t i = 0 ; i < size; ++i) {
651
- trimmed_var_ty[i] = var_sigma->op (i+1 );
652
- }
621
+ DefArray trimmed_var_ty = var_sigma->ops ().skip ();
653
622
auto trimmed_var_sigma = world_.sigma (trimmed_var_ty);
654
623
auto idpi = createPbType (A,trimmed_var_sigma);
655
624
auto idpb = world_.nom_filter_lam (idpi, world_.dbg (" param_id" ));
656
- auto real_params = DefArray (
657
- dst_lam->num_vars ()-2 ,
658
- [&](auto i) {
659
- return dst_lam->var (i+1 );
660
- });
625
+ auto real_params = dst_lam->vars ().skip ();
661
626
auto [current_mem_,zero_grad_] = ZERO (world_,current_mem,A,world_.tuple (real_params));
662
627
current_mem=current_mem_;
663
628
zero_grad=zero_grad_;
664
629
// ret only resp. non-mem, non-cont
665
- auto args = DefArray (
666
- src->num_vars ()-1 ,
667
- [&](auto i) {
668
- if (i==0 )
669
- return idpb->mem_var ();
670
- return idpb->var (i);
671
- });
630
+ auto args = idpb->vars ().skip_back ();
672
631
idpb->set_body (world_.app (idpb->ret_var (), args));
673
632
pullbacks_[dst_var] = idpb;
674
- for (size_t i = 0 , e = src->num_vars (); i < e; ++i) {
675
- auto dvar = dst_lam->var (i);
676
- if (dvar == dst_lam->ret_var () || dvar == dst_lam->mem_var ()) {
677
- continue ;
678
- }
679
- // solve the problem of inital array pb in extract pb
680
- pullbacks_[dvar]= extract_pb (dvar, dst_lam->var ());
681
- initArg (dvar);
682
- }
633
+ for (auto dvar : src->vars ().skip ()) {
634
+ // solve the problem of inital array pb in extract pb
635
+ pullbacks_[dvar]= extract_pb (dvar, dst_lam->var ());
636
+ initArg (dvar);
637
+ }
683
638
// translate the body => get correct applications of variables using pullbacks
684
639
auto dst = j_wrap (src->body ());
685
640
return dst;
@@ -1340,12 +1295,9 @@ const Def* AutoDiffer::j_wrap_convert(const Def* def) {
1340
1295
auto m = d_arg->proj (0 );
1341
1296
auto num_projs = d_arg->num_projs ();
1342
1297
auto ret_arg = d_arg->proj (num_projs-1 );
1343
- auto args=DefArray (
1344
- num_projs-2 ,
1345
- [&](auto i) {
1346
- return d_arg->proj (i+1 );
1347
- });
1348
- auto arg= world_.tuple (args);
1298
+ auto arg= world_.tuple (
1299
+ d_arg->projs ().skip ()
1300
+ );
1349
1301
auto pbT = dst_callee->type ()->as <Pi>()->doms ().back ()->as <Pi>();
1350
1302
auto chained = world_.nom_filter_lam (pbT, world_.dbg (" φchain" ));
1351
1303
auto arg_pb = pullbacks_[d_arg]; // Lam
@@ -1356,7 +1308,7 @@ const Def* AutoDiffer::j_wrap_convert(const Def* def) {
1356
1308
ret_arg,
1357
1309
flat_tuple ({
1358
1310
chained->mem_var (),
1359
- world_.tuple (vars_without_mem_cont (world_, chained)),
1311
+ world_.tuple (vars_without_mem_cont (chained)),
1360
1312
chain_pb
1361
1313
})
1362
1314
));
@@ -1392,7 +1344,13 @@ const Def* AutoDiffer::j_wrap_convert(const Def* def) {
1392
1344
ad_args = world_.tuple (
1393
1345
DefArray (
1394
1346
count+1 ,
1395
- [&](auto i) {if (i<count) {return world_.extract (d_arg, (u64 )i, world_.dbg (" ad_arg" ));} else {return pullbacks_[d_arg];}}
1347
+ [&](auto i) {
1348
+ if (i<count) {
1349
+ return world_.extract (d_arg, (u64 )i, world_.dbg (" ad_arg" ));
1350
+ } else {
1351
+ return pullbacks_[d_arg];
1352
+ }
1353
+ }
1396
1354
));
1397
1355
}else {
1398
1356
// var (lambda completely with all arguments) and other (non tuple)
@@ -1414,7 +1372,7 @@ const Def* AutoDiffer::j_wrap_convert(const Def* def) {
1414
1372
auto dim = as_lit (pack->type ()->arity ());
1415
1373
auto tup=DefArray (
1416
1374
dim,
1417
- [&](auto i ) {
1375
+ [&](auto ) {
1418
1376
return pack->body ();
1419
1377
});
1420
1378
return j_wrap_tuple (tup);
@@ -1555,8 +1513,8 @@ const Def* AutoDiffer::j_wrap_rop(ROp op, const Def* a, const Def* b) {
1555
1513
THORIN_UNREACHABLE;
1556
1514
}
1557
1515
1558
- auto adiff = world_.tuple (vars_without_mem_cont (world_, middle));
1559
- auto bdiff = world_.tuple (vars_without_mem_cont (world_, end));
1516
+ auto adiff = world_.tuple (vars_without_mem_cont (middle));
1517
+ auto bdiff = world_.tuple (vars_without_mem_cont (end));
1560
1518
auto sum_pb=vec_add (world_,adiff,bdiff,pb->ret_var ());
1561
1519
end->set_body (world_.app (sum_pb, end->mem_var ()));
1562
1520
pullbacks_[dst] = pb;
0 commit comments