Skip to content

Commit 3187840

Browse files
fix ret_var
1 parent b270066 commit 3187840

File tree

1 file changed

+16
-17
lines changed

1 file changed

+16
-17
lines changed

thorin/pass/rw/auto_diff.cpp

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,25 @@ static const Def* to_fat_ptr(World& world, const Def* ptr, const Def* size){
8686
auto dst_fat_ptr=world.tuple({int_size, ptr});
8787
}
8888

89-
Array<const Def*> vars_without_mem_cont(World& world, Lam* lam) {
89+
const Pi* isReturning(const Pi* pi){
90+
if (pi->is_cn() && pi->num_doms() > 0) {
91+
auto ret = pi->dom(pi->num_doms() - 1);
92+
if (auto ret_pi = ret->isa<Pi>(); ret_pi != nullptr && ret_pi->is_cn()) return ret_pi;
93+
}
94+
95+
return nullptr;
96+
}
97+
98+
DefArray vars_without_mem_cont(World& world, Lam* lam) {
9099
type_dump(world," get vars of",lam);
91100
dlog(world," has ret_var {}",lam->ret_var());
92101
// if(lam->ret_var())
93-
return Array<const Def*>(
94-
lam->num_vars()-(lam->ret_var()==nullptr ? 1 : 2),
102+
return {
103+
lam->num_vars()-( isReturning(lam->type()) == nullptr ? 1 : 2),
95104
[&](auto i) {
96105
return lam->var(i+1);
97-
});
106+
}
107+
};
98108
}
99109

100110

@@ -527,10 +537,6 @@ std::pair<const Def*,const Def*> oneHot(World& world_, const Def* mem, const Def
527537
}
528538
}
529539

530-
531-
532-
533-
534540
namespace {
535541

536542
class AutoDiffer {
@@ -589,14 +595,6 @@ class AutoDiffer {
589595
const Def* chain(const Def* a, const Def* b);
590596
const Pi* createPbType(const Def* A, const Def* B);
591597
const Def* extract_pb(const Def* j_extract, const Def* tuple);
592-
const Def* isReturning(const Pi* def){
593-
if (def->is_cn() && def->num_doms() > 0) {
594-
auto ret = def->dom(def->num_doms() - 1);
595-
if (auto pi = ret->isa<Pi>(); pi != nullptr && pi->is_cn()) return pi;
596-
}
597-
598-
return nullptr;
599-
}
600598

601599
World& world_;
602600
Def2Def src_to_dst_; // mapping old def to new def
@@ -2253,7 +2251,8 @@ const Def* AutoDiffer::j_wrap(const Def* def) {
22532251

22542252
auto arg_pb = pullbacks_[d_arg]; // Lam
22552253
type_dump(world_," arg pb",arg_pb);
2256-
auto ret_pb = chained->ret_var(); // extract
2254+
2255+
auto ret_pb = chained->var(chained->num_vars() - 1);
22572256
type_dump(world_," ret var pb",ret_pb);
22582257
auto chain_pb = chain(ret_pb,arg_pb);
22592258
type_dump(world_," chain pb",chain_pb);

0 commit comments

Comments
 (0)