@@ -86,15 +86,25 @@ static const Def* to_fat_ptr(World& world, const Def* ptr, const Def* size){
86
86
auto dst_fat_ptr=world.tuple ({int_size, ptr});
87
87
}
88
88
89
- Array<const Def*> vars_without_mem_cont (World& world, Lam* lam) {
89
+ const Pi* isReturning (const Pi* pi){
90
+ if (pi->is_cn () && pi->num_doms () > 0 ) {
91
+ auto ret = pi->dom (pi->num_doms () - 1 );
92
+ if (auto ret_pi = ret->isa <Pi>(); ret_pi != nullptr && ret_pi->is_cn ()) return ret_pi;
93
+ }
94
+
95
+ return nullptr ;
96
+ }
97
+
98
+ DefArray vars_without_mem_cont (World& world, Lam* lam) {
90
99
type_dump (world," get vars of" ,lam);
91
100
dlog (world," has ret_var {}" ,lam->ret_var ());
92
101
// if(lam->ret_var())
93
- return Array< const Def*>(
94
- lam->num_vars ()-(lam->ret_var ()== nullptr ? 1 : 2 ),
102
+ return {
103
+ lam->num_vars ()-( isReturning ( lam->type ()) == nullptr ? 1 : 2 ),
95
104
[&](auto i) {
96
105
return lam->var (i+1 );
97
- });
106
+ }
107
+ };
98
108
}
99
109
100
110
@@ -527,10 +537,6 @@ std::pair<const Def*,const Def*> oneHot(World& world_, const Def* mem, const Def
527
537
}
528
538
}
529
539
530
-
531
-
532
-
533
-
534
540
namespace {
535
541
536
542
class AutoDiffer {
@@ -589,14 +595,6 @@ class AutoDiffer {
589
595
const Def* chain (const Def* a, const Def* b);
590
596
const Pi* createPbType (const Def* A, const Def* B);
591
597
const Def* extract_pb (const Def* j_extract, const Def* tuple);
592
- const Def* isReturning (const Pi* def){
593
- if (def->is_cn () && def->num_doms () > 0 ) {
594
- auto ret = def->dom (def->num_doms () - 1 );
595
- if (auto pi = ret->isa <Pi>(); pi != nullptr && pi->is_cn ()) return pi;
596
- }
597
-
598
- return nullptr ;
599
- }
600
598
601
599
World& world_;
602
600
Def2Def src_to_dst_; // mapping old def to new def
@@ -2253,7 +2251,8 @@ const Def* AutoDiffer::j_wrap(const Def* def) {
2253
2251
2254
2252
auto arg_pb = pullbacks_[d_arg]; // Lam
2255
2253
type_dump (world_," arg pb" ,arg_pb);
2256
- auto ret_pb = chained->ret_var (); // extract
2254
+
2255
+ auto ret_pb = chained->var (chained->num_vars () - 1 );
2257
2256
type_dump (world_," ret var pb" ,ret_pb);
2258
2257
auto chain_pb = chain (ret_pb,arg_pb);
2259
2258
type_dump (world_," chain pb" ,chain_pb);
0 commit comments