@@ -2529,6 +2529,7 @@ const Def* AutoDiffer::j_wrap_rop(ROp op, const Def* a, const Def* b) {
2529
2529
// pullbacks of the arguments
2530
2530
auto apb = pullbacks_[a];
2531
2531
auto bpb = pullbacks_[b];
2532
+ const Def* dst;
2532
2533
// compute the pullback for each operation
2533
2534
// general procedure:
2534
2535
// pb computes a*(...) continues in mid
@@ -2540,24 +2541,11 @@ const Def* AutoDiffer::j_wrap_rop(ROp op, const Def* a, const Def* b) {
2540
2541
switch (op) {
2541
2542
// ∇(a + b) = λz.∂a(z * (1 + 0)) + ∂b(z * (0 + 1))
2542
2543
case ROp::add: {
2543
- auto dst = world_.op (ROp::add, (nat_t )0 , a, b);
2544
+ dst = world_.op (ROp::add, (nat_t )0 , a, b);
2544
2545
pb->set_dbg (world_.dbg (pb->name () + " +" ));
2545
2546
2546
2547
pb->set_body (world_.app (apb, {pb->mem_var (), pb->var (1 ), middle}));
2547
2548
middle->set_body (world_.app (bpb, {middle->mem_var (), pb->var (1 ), end}));
2548
- // auto adiff = middle->var(1);
2549
- // auto bdiff = end->var(1);
2550
- auto adiff = world_.tuple (vars_without_mem_cont (world_,middle));
2551
- auto bdiff = world_.tuple (vars_without_mem_cont (world_,end));
2552
-
2553
-
2554
- // auto [smem, sum] = vec_add(world_, end->mem_var(), adiff, bdiff);
2555
- // end->set_body(world_.app(pb->ret_var(), flat_tuple({ smem, sum})));
2556
- auto sum_pb=vec_add (world_,adiff,bdiff,pb->ret_var ());
2557
- end->set_body (world_.app (sum_pb, end->mem_var ()));
2558
- pullbacks_[dst] = pb;
2559
-
2560
- return dst;
2561
2549
}
2562
2550
// ∇(a - b) = λz.∂a(z * (1 - 0)) + ∂b(z * (0 - 1))
2563
2551
case ROp::sub: {
@@ -2569,19 +2557,13 @@ const Def* AutoDiffer::j_wrap_rop(ROp op, const Def* a, const Def* b) {
2569
2557
// ret(x+y)
2570
2558
//
2571
2559
// a*(z)+b*(-z)
2572
- auto dst = world_.op (ROp::sub, (nat_t )0 , a, b);
2560
+ dst = world_.op (ROp::sub, (nat_t )0 , a, b);
2573
2561
pb->set_dbg (world_.dbg (pb->name () + " -" ));
2574
2562
2575
2563
pb->set_body (world_.app (apb, {pb->mem_var (), pb->var (1 ), middle}));
2576
2564
auto [rmem,one] = ONE (world_,middle->mem_var (), o_type);
2577
2565
middle->set_body (world_.app (bpb, {rmem, world_.op (ROp::mul, (nat_t )0 , pb->var (1 ), world_.op_rminus ((nat_t )0 , one)), end}));
2578
2566
// all args 1..n as tuple => vector for addition
2579
- auto adiff = world_.tuple (vars_without_mem_cont (world_,middle));
2580
- auto bdiff = world_.tuple (vars_without_mem_cont (world_,end));
2581
- auto sum_pb=vec_add (world_,adiff,bdiff,pb->ret_var ());
2582
- end->set_body (world_.app (sum_pb, end->mem_var ()));
2583
- pullbacks_[dst] = pb;
2584
- return dst;
2585
2567
}
2586
2568
// ∇(a * b) = λz.∂a(z * (1 * b + a * 0)) + ∂b(z * (0 * b + a * 1))
2587
2569
// potential opt: if ∂a = ∂b, do: ∂a(z * (a + b))
@@ -2597,41 +2579,35 @@ const Def* AutoDiffer::j_wrap_rop(ROp op, const Def* a, const Def* b) {
2597
2579
// ret(x+y)
2598
2580
//
2599
2581
// a*(zb)+b*(za)
2600
- auto dst = world_.op (ROp::mul, (nat_t )0 , a, b);
2582
+ dst = world_.op (ROp::mul, (nat_t )0 , a, b);
2601
2583
pb->set_dbg (world_.dbg (pb->name () + " *" ));
2602
2584
2603
2585
pb->set_body (world_.app (apb, {pb->mem_var (), world_.op (ROp::mul, (nat_t )0 , pb->var (1 ), b), middle}));
2604
2586
middle->set_body (world_.app (bpb, {middle->mem_var (), world_.op (ROp::mul, (nat_t )0 , pb->var (1 ), a), end}));
2605
- auto adiff = world_.tuple (vars_without_mem_cont (world_,middle));
2606
- auto bdiff = world_.tuple (vars_without_mem_cont (world_,end));
2607
-
2608
- auto sum_pb=vec_add (world_,adiff,bdiff,pb->ret_var ());
2609
- end->set_body (world_.app (sum_pb, end->mem_var ()));
2610
- pullbacks_[dst] = pb;
2611
- return dst;
2612
2587
}
2613
2588
// ∇(a / b) = λz. (∂a(z * b) - ∂b(z * a))/b²
2614
2589
case ROp::div: {
2615
2590
// a*(1/b * z) => a*(z/b)
2616
2591
// + b*(a * -b^(-2) * z) => b*(-z*a/(b*b))
2617
- auto dst = world_.op (ROp::div, (nat_t )0 , a, b);
2592
+ dst = world_.op (ROp::div, (nat_t )0 , a, b);
2618
2593
pb->set_dbg (world_.dbg (pb->name () + " /" ));
2619
2594
2620
2595
pb->set_body (world_.app (apb, {pb->mem_var (), world_.op (ROp::div, (nat_t )0 , pb->var (1 ), b), middle}));
2621
2596
auto za=world_.op (ROp::mul, (nat_t )0 , pb->var (1 ), a);
2622
2597
auto bsq=world_.op (ROp::mul, (nat_t )0 , b, b);
2623
2598
middle->set_body (world_.app (bpb, {middle->mem_var (), world_.op_rminus ((nat_t )0 , world_.op (ROp::div, (nat_t )0 , za, bsq)), end}));
2624
- auto adiff = world_.tuple (vars_without_mem_cont (world_,middle));
2625
- auto bdiff = world_.tuple (vars_without_mem_cont (world_,end));
2626
- auto sum_pb=vec_add (world_,adiff,bdiff,pb->ret_var ());
2627
- end->set_body (world_.app (sum_pb, end->mem_var ()));
2628
- pullbacks_[dst] = pb;
2629
- return dst;
2630
2599
}
2631
2600
default :
2632
2601
// only +, -, *, / are implemented as basic operations
2633
2602
THORIN_UNREACHABLE;
2634
2603
}
2604
+
2605
+ auto adiff = world_.tuple (vars_without_mem_cont (world_,middle));
2606
+ auto bdiff = world_.tuple (vars_without_mem_cont (world_,end));
2607
+ auto sum_pb=vec_add (world_,adiff,bdiff,pb->ret_var ());
2608
+ end->set_body (world_.app (sum_pb, end->mem_var ()));
2609
+ pullbacks_[dst] = pb;
2610
+ return dst;
2635
2611
}
2636
2612
2637
2613
// seen is a simple lookup in the src_to_dst mapping
0 commit comments