@@ -589,6 +589,14 @@ class AutoDiffer {
589
589
const Def* chain (const Def* a, const Def* b);
590
590
const Pi* createPbType (const Def* A, const Def* B);
591
591
const Def* extract_pb (const Def* j_extract, const Def* tuple);
592
+ const Def* isReturning (const Pi* def){
593
+ if (def->is_cn () && def->num_doms () > 0 ) {
594
+ auto ret = def->dom (def->num_doms () - 1 );
595
+ if (auto pi = ret->isa <Pi>(); pi != nullptr && pi->is_cn ()) return pi;
596
+ }
597
+
598
+ return nullptr ;
599
+ }
592
600
593
601
World& world_;
594
602
Def2Def src_to_dst_; // mapping old def to new def
@@ -1507,15 +1515,7 @@ const Def* AutoDiffer::j_wrap(const Def* def) {
1507
1515
1508
1516
auto last_mem=current_mem;
1509
1517
1510
- // auto back_order=-1;//lam->type()->as<Pi>()->doms().back()->order();
1511
- // auto back_order = lam->type()->as<Pi>()->doms().back()->
1512
- // auto returning = back_order>0;
1513
- auto returning = lam->type ()->is_returning ();
1514
- dlog (world_," lam ret pi: {}" , lam->type ()->ret_pi () ? 1 : 0 );
1515
- // dlog(world_," lam returning2: {}", returning);
1516
- // dlog(world_," order: {}", back_order);
1517
- dlog (world_," back: {}" , lam->type ()->as <Pi>()->doms ().back ());
1518
- if (lam->type ()->ret_pi () || returning) {
1518
+ if ( isReturning (lam->type ())) {
1519
1519
auto dst = world_.op_rev_diff (lam);
1520
1520
type_dump (world_," new lam" ,dst);
1521
1521
// THORIN_UNREACHABLE;
@@ -2101,8 +2101,7 @@ const Def* AutoDiffer::j_wrap(const Def* def) {
2101
2101
2102
2102
// auto back_order=-1;//callee->type()->as<Pi>()->doms().back()->order();
2103
2103
// auto returning = back_order>0;
2104
- auto returning = callee->type ()->as <Pi>()->is_returning ();
2105
- if (callee->type ()->as <Pi>()->ret_pi () || returning) {
2104
+ if (isReturning (callee->type ()->as <Pi>())) {
2106
2105
dlog (world_," FYI returning callee" );
2107
2106
2108
2107
const Def* dst_callee;
@@ -2546,6 +2545,7 @@ const Def* AutoDiffer::j_wrap_rop(ROp op, const Def* a, const Def* b) {
2546
2545
2547
2546
pb->set_body (world_.app (apb, {pb->mem_var (), pb->var (1 ), middle}));
2548
2547
middle->set_body (world_.app (bpb, {middle->mem_var (), pb->var (1 ), end}));
2548
+ break ;
2549
2549
}
2550
2550
// ∇(a - b) = λz.∂a(z * (0 + 1)) - ∂b(z * (0 + 1))
2551
2551
case ROp::sub: {
@@ -2564,6 +2564,8 @@ const Def* AutoDiffer::j_wrap_rop(ROp op, const Def* a, const Def* b) {
2564
2564
auto [rmem,one] = ONE (world_,middle->mem_var (), o_type);
2565
2565
middle->set_body (world_.app (bpb, {rmem, world_.op (ROp::mul, (nat_t )0 , pb->var (1 ), world_.op_rminus ((nat_t )0 , one)), end}));
2566
2566
// all args 1..n as tuple => vector for addition
2567
+ break ;
2568
+
2567
2569
}
2568
2570
// ∇(a * b) = λz.∂a(z * (1 * b + a * 0)) + ∂b(z * (0 * b + a * 1))
2569
2571
// potential opt: if ∂a = ∂b, do: ∂a(z * (a + b))
@@ -2584,6 +2586,8 @@ const Def* AutoDiffer::j_wrap_rop(ROp op, const Def* a, const Def* b) {
2584
2586
2585
2587
pb->set_body (world_.app (apb, {pb->mem_var (), world_.op (ROp::mul, (nat_t )0 , pb->var (1 ), b), middle}));
2586
2588
middle->set_body (world_.app (bpb, {middle->mem_var (), world_.op (ROp::mul, (nat_t )0 , pb->var (1 ), a), end}));
2589
+ break ;
2590
+
2587
2591
}
2588
2592
// ∇(a / b) = λz. (g* (z * h) - h* (z * g))/h²
2589
2593
case ROp::div: {
@@ -2596,6 +2600,7 @@ const Def* AutoDiffer::j_wrap_rop(ROp op, const Def* a, const Def* b) {
2596
2600
auto za=world_.op (ROp::mul, (nat_t )0 , pb->var (1 ), a);
2597
2601
auto bsq=world_.op (ROp::mul, (nat_t )0 , b, b);
2598
2602
middle->set_body (world_.app (bpb, {middle->mem_var (), world_.op_rminus ((nat_t )0 , world_.op (ROp::div, (nat_t )0 , za, bsq)), end}));
2603
+ break ;
2599
2604
}
2600
2605
default :
2601
2606
// only +, -, *, / are implemented as basic operations
0 commit comments