Skip to content

Commit 584dcc1

Browse files
refactoring j_wrap_rop
1 parent 12cf51d commit 584dcc1

File tree

1 file changed

+12
-36
lines changed

1 file changed

+12
-36
lines changed

thorin/pass/rw/auto_diff.cpp

Lines changed: 12 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2529,6 +2529,7 @@ const Def* AutoDiffer::j_wrap_rop(ROp op, const Def* a, const Def* b) {
25292529
// pullbacks of the arguments
25302530
auto apb = pullbacks_[a];
25312531
auto bpb = pullbacks_[b];
2532+
const Def* dst;
25322533
// compute the pullback for each operation
25332534
// general procedure:
25342535
// pb computes a*(...) continues in mid
@@ -2540,24 +2541,11 @@ const Def* AutoDiffer::j_wrap_rop(ROp op, const Def* a, const Def* b) {
25402541
switch (op) {
25412542
// ∇(a + b) = λz.∂a(z * (1 + 0)) + ∂b(z * (0 + 1))
25422543
case ROp::add: {
2543-
auto dst = world_.op(ROp::add, (nat_t)0, a, b);
2544+
dst = world_.op(ROp::add, (nat_t)0, a, b);
25442545
pb->set_dbg(world_.dbg(pb->name() + "+"));
25452546

25462547
pb->set_body(world_.app(apb, {pb->mem_var(), pb->var(1), middle}));
25472548
middle->set_body(world_.app(bpb, {middle->mem_var(), pb->var(1), end}));
2548-
// auto adiff = middle->var(1);
2549-
// auto bdiff = end->var(1);
2550-
auto adiff = world_.tuple(vars_without_mem_cont(world_,middle));
2551-
auto bdiff = world_.tuple(vars_without_mem_cont(world_,end));
2552-
2553-
2554-
// auto [smem, sum] = vec_add(world_, end->mem_var(), adiff, bdiff);
2555-
// end->set_body(world_.app(pb->ret_var(), flat_tuple({ smem, sum})));
2556-
auto sum_pb=vec_add(world_,adiff,bdiff,pb->ret_var());
2557-
end->set_body(world_.app(sum_pb, end->mem_var()));
2558-
pullbacks_[dst] = pb;
2559-
2560-
return dst;
25612549
}
25622550
// ∇(a - b) = λz.∂a(z * (0 + 1)) - ∂b(z * (0 + 1))
25632551
case ROp::sub: {
@@ -2569,19 +2557,13 @@ const Def* AutoDiffer::j_wrap_rop(ROp op, const Def* a, const Def* b) {
25692557
// ret(x+y)
25702558
//
25712559
// a*(z)+b*(-z)
2572-
auto dst = world_.op(ROp::sub, (nat_t)0, a, b);
2560+
dst = world_.op(ROp::sub, (nat_t)0, a, b);
25732561
pb->set_dbg(world_.dbg(pb->name() + "-"));
25742562

25752563
pb->set_body(world_.app(apb, {pb->mem_var(), pb->var(1), middle}));
25762564
auto [rmem,one] = ONE(world_,middle->mem_var(), o_type);
25772565
middle->set_body(world_.app(bpb, {rmem, world_.op(ROp::mul, (nat_t)0, pb->var(1), world_.op_rminus((nat_t)0, one)), end}));
25782566
// all args 1..n as tuple => vector for addition
2579-
auto adiff = world_.tuple(vars_without_mem_cont(world_,middle));
2580-
auto bdiff = world_.tuple(vars_without_mem_cont(world_,end));
2581-
auto sum_pb=vec_add(world_,adiff,bdiff,pb->ret_var());
2582-
end->set_body(world_.app(sum_pb, end->mem_var()));
2583-
pullbacks_[dst] = pb;
2584-
return dst;
25852567
}
25862568
// ∇(a * b) = λz.∂a(z * (1 * b + a * 0)) + ∂b(z * (0 * b + a * 1))
25872569
// potential opt: if ∂a = ∂b, do: ∂a(z * (a + b))
@@ -2597,41 +2579,35 @@ const Def* AutoDiffer::j_wrap_rop(ROp op, const Def* a, const Def* b) {
25972579
// ret(x+y)
25982580
//
25992581
// a*(zb)+b*(za)
2600-
auto dst = world_.op(ROp::mul, (nat_t)0, a, b);
2582+
dst = world_.op(ROp::mul, (nat_t)0, a, b);
26012583
pb->set_dbg(world_.dbg(pb->name() + "*"));
26022584

26032585
pb->set_body(world_.app(apb, {pb->mem_var(), world_.op(ROp::mul, (nat_t)0, pb->var(1), b), middle}));
26042586
middle->set_body(world_.app(bpb, {middle->mem_var(), world_.op(ROp::mul, (nat_t)0, pb->var(1), a), end}));
2605-
auto adiff = world_.tuple(vars_without_mem_cont(world_,middle));
2606-
auto bdiff = world_.tuple(vars_without_mem_cont(world_,end));
2607-
2608-
auto sum_pb=vec_add(world_,adiff,bdiff,pb->ret_var());
2609-
end->set_body(world_.app(sum_pb, end->mem_var()));
2610-
pullbacks_[dst] = pb;
2611-
return dst;
26122587
}
26132588
// ∇(a / b) = λz. (g* (z * h) - h* (z * g))/h²
26142589
case ROp::div: {
26152590
// a*(1/b * z) => a*(z/b)
26162591
// + b*(a * -b^(-2) * z) => b*(-z*a/(b*b))
2617-
auto dst = world_.op(ROp::div, (nat_t)0, a, b);
2592+
dst = world_.op(ROp::div, (nat_t)0, a, b);
26182593
pb->set_dbg(world_.dbg(pb->name() + "/"));
26192594

26202595
pb->set_body(world_.app(apb, {pb->mem_var(), world_.op(ROp::div, (nat_t)0, pb->var(1), b), middle}));
26212596
auto za=world_.op(ROp::mul, (nat_t)0, pb->var(1), a);
26222597
auto bsq=world_.op(ROp::mul, (nat_t)0, b, b);
26232598
middle->set_body(world_.app(bpb, {middle->mem_var(), world_.op_rminus((nat_t)0, world_.op(ROp::div, (nat_t)0, za, bsq)), end}));
2624-
auto adiff = world_.tuple(vars_without_mem_cont(world_,middle));
2625-
auto bdiff = world_.tuple(vars_without_mem_cont(world_,end));
2626-
auto sum_pb=vec_add(world_,adiff,bdiff,pb->ret_var());
2627-
end->set_body(world_.app(sum_pb, end->mem_var()));
2628-
pullbacks_[dst] = pb;
2629-
return dst;
26302599
}
26312600
default:
26322601
// only +, -, *, / are implemented as basic operations
26332602
THORIN_UNREACHABLE;
26342603
}
2604+
2605+
auto adiff = world_.tuple(vars_without_mem_cont(world_,middle));
2606+
auto bdiff = world_.tuple(vars_without_mem_cont(world_,end));
2607+
auto sum_pb=vec_add(world_,adiff,bdiff,pb->ret_var());
2608+
end->set_body(world_.app(sum_pb, end->mem_var()));
2609+
pullbacks_[dst] = pb;
2610+
return dst;
26352611
}
26362612

26372613
// seen is a simple lookup in the src_to_dst mapping

0 commit comments

Comments
 (0)