Skip to content

Commit b270066

Browse files
bugfix rop and implementation AutoDiff definition of isReturning
1 parent 584dcc1 commit b270066

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

thorin/pass/rw/auto_diff.cpp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,14 @@ class AutoDiffer {
589589
const Def* chain(const Def* a, const Def* b);
590590
const Pi* createPbType(const Def* A, const Def* B);
591591
const Def* extract_pb(const Def* j_extract, const Def* tuple);
592+
const Def* isReturning(const Pi* def){
593+
if (def->is_cn() && def->num_doms() > 0) {
594+
auto ret = def->dom(def->num_doms() - 1);
595+
if (auto pi = ret->isa<Pi>(); pi != nullptr && pi->is_cn()) return pi;
596+
}
597+
598+
return nullptr;
599+
}
592600

593601
World& world_;
594602
Def2Def src_to_dst_; // mapping old def to new def
@@ -1507,15 +1515,7 @@ const Def* AutoDiffer::j_wrap(const Def* def) {
15071515

15081516
auto last_mem=current_mem;
15091517

1510-
// auto back_order=-1;//lam->type()->as<Pi>()->doms().back()->order();
1511-
// auto back_order = lam->type()->as<Pi>()->doms().back()->
1512-
// auto returning = back_order>0;
1513-
auto returning = lam->type()->is_returning();
1514-
dlog(world_," lam ret pi: {}", lam->type()->ret_pi() ? 1 : 0);
1515-
// dlog(world_," lam returning2: {}", returning);
1516-
// dlog(world_," order: {}", back_order);
1517-
dlog(world_," back: {}", lam->type()->as<Pi>()->doms().back());
1518-
if(lam->type()->ret_pi() || returning) {
1518+
if( isReturning(lam->type())) {
15191519
auto dst = world_.op_rev_diff(lam);
15201520
type_dump(world_," new lam",dst);
15211521
// THORIN_UNREACHABLE;
@@ -2101,8 +2101,7 @@ const Def* AutoDiffer::j_wrap(const Def* def) {
21012101

21022102
// auto back_order=-1;//callee->type()->as<Pi>()->doms().back()->order();
21032103
// auto returning = back_order>0;
2104-
auto returning = callee->type()->as<Pi>()->is_returning();
2105-
if (callee->type()->as<Pi>()->ret_pi() || returning) {
2104+
if (isReturning(callee->type()->as<Pi>())) {
21062105
dlog(world_," FYI returning callee");
21072106

21082107
const Def* dst_callee;
@@ -2546,6 +2545,7 @@ const Def* AutoDiffer::j_wrap_rop(ROp op, const Def* a, const Def* b) {
25462545

25472546
pb->set_body(world_.app(apb, {pb->mem_var(), pb->var(1), middle}));
25482547
middle->set_body(world_.app(bpb, {middle->mem_var(), pb->var(1), end}));
2548+
break;
25492549
}
25502550
// ∇(a - b) = λz.∂a(z * (0 + 1)) - ∂b(z * (0 + 1))
25512551
case ROp::sub: {
@@ -2564,6 +2564,8 @@ const Def* AutoDiffer::j_wrap_rop(ROp op, const Def* a, const Def* b) {
25642564
auto [rmem,one] = ONE(world_,middle->mem_var(), o_type);
25652565
middle->set_body(world_.app(bpb, {rmem, world_.op(ROp::mul, (nat_t)0, pb->var(1), world_.op_rminus((nat_t)0, one)), end}));
25662566
// all args 1..n as tuple => vector for addition
2567+
break;
2568+
25672569
}
25682570
// ∇(a * b) = λz.∂a(z * (1 * b + a * 0)) + ∂b(z * (0 * b + a * 1))
25692571
// potential opt: if ∂a = ∂b, do: ∂a(z * (a + b))
@@ -2584,6 +2586,8 @@ const Def* AutoDiffer::j_wrap_rop(ROp op, const Def* a, const Def* b) {
25842586

25852587
pb->set_body(world_.app(apb, {pb->mem_var(), world_.op(ROp::mul, (nat_t)0, pb->var(1), b), middle}));
25862588
middle->set_body(world_.app(bpb, {middle->mem_var(), world_.op(ROp::mul, (nat_t)0, pb->var(1), a), end}));
2589+
break;
2590+
25872591
}
25882592
// ∇(a / b) = λz. (g* (z * h) - h* (z * g))/h²
25892593
case ROp::div: {
@@ -2596,6 +2600,7 @@ const Def* AutoDiffer::j_wrap_rop(ROp op, const Def* a, const Def* b) {
25962600
auto za=world_.op(ROp::mul, (nat_t)0, pb->var(1), a);
25972601
auto bsq=world_.op(ROp::mul, (nat_t)0, b, b);
25982602
middle->set_body(world_.app(bpb, {middle->mem_var(), world_.op_rminus((nat_t)0, world_.op(ROp::div, (nat_t)0, za, bsq)), end}));
2603+
break;
25992604
}
26002605
default:
26012606
// only +, -, *, / are implemented as basic operations

0 commit comments

Comments
 (0)