Skip to content

Commit 10dd7e5

Browse files
refactoring DefArray's
1 parent 1dec8a9 commit 10dd7e5

File tree

2 files changed

+71
-97
lines changed

2 files changed

+71
-97
lines changed

thorin/pass/rw/auto_diff.cpp

Lines changed: 55 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,8 @@ const Pi* isReturning(const Pi* pi){
6969
return nullptr;
7070
}
7171

72-
DefArray vars_without_mem_cont(World& world, Lam* lam) {
73-
return {
74-
lam->num_vars()-( isReturning(lam->type()) == nullptr ? 1 : 2),
75-
[&](auto i) {
76-
return lam->var(i+1);
77-
}
78-
};
72+
DefArray vars_without_mem_cont(Lam* lam) {
73+
return lam->vars().skip(1, isReturning(lam->type()) != nullptr);
7974
}
8075
// multidimensional addition of values
8176
// needed for operation differentiation
@@ -100,7 +95,7 @@ const Lam* vec_add(World& world, const Def* a, const Def* b, const Def* cont) {
10095
auto sum_cont = vec_add(world,a_v,b_v,res_cont);
10196
sum_pb->set_body(world.app(sum_cont, mem3));
10297
auto rmem=res_cont->mem_var();
103-
auto s_v= world.tuple(vars_without_mem_cont(world,res_cont));
98+
auto s_v= world.tuple(vars_without_mem_cont(res_cont));
10499
auto [rmem2, sum_ptr]=world.op_slot(ty,rmem,world.dbg("add_slot"))->projs<2>();
105100
auto rmem3 = world.op_store(rmem2,sum_ptr,s_v);
106101

@@ -151,7 +146,7 @@ const Lam* vec_add(World& world, const Def* a, const Def* b, const Def* cont) {
151146
auto elem_res_cont_type = world.cn_mem_flat(a_v->type());
152147
auto elem_res_cont = world.nom_filter_lam(elem_res_cont_type,world.dbg("tuple_add_cont"));
153148
auto element_sum_pb = vec_add(world,a_v,b_v,elem_res_cont);
154-
auto c_v = world.tuple(vars_without_mem_cont(world,elem_res_cont));
149+
auto c_v = world.tuple(vars_without_mem_cont(elem_res_cont));
155150
auto res_mem=elem_res_cont->mem_var();
156151
res_mem=world.op_store(res_mem,c_p,c_v);
157152

@@ -208,7 +203,7 @@ const Lam* vec_add(World& world, const Def* a, const Def* b, const Def* cont) {
208203
auto res_cont_type = world.cn_mem_flat(ai->type());
209204
auto res_cont = world.nom_filter_lam(res_cont_type,world.dbg("tuple_add_cont"));
210205
auto sum_call=vec_add(world,ai,bi,res_cont);
211-
ops[i]=world.tuple(vars_without_mem_cont(world,res_cont));
206+
ops[i]=world.tuple(vars_without_mem_cont(res_cont));
212207

213208
current_cont->set_body(world.app(
214209
sum_call,
@@ -278,22 +273,19 @@ std::pair<const Def*,const Def*> lit_of_type(World& world, const Def* mem, const
278273
litdef= world.lit_real(as_lit(real->arg()), lit);
279274
else if (auto a = type->isa<Arr>()) {
280275
auto dim = a->shape()->as<Lit>()->get<uint8_t>();
281-
DefArray ops{dim};
282-
for (size_t i = 0; i < dim; ++i) {
283-
auto [nmem, op]=lit_of_type(world,mem,a->body(),like,lit,dummy);
276+
DefArray ops{dim, [&](auto){
277+
auto [nmem, op] = lit_of_type(world,mem,a->body(),like,lit,dummy);
284278
mem=nmem;
285-
ops[i]=op;
286-
}
279+
return op;
280+
}};
287281
litdef= world.tuple(ops);
288282
}else if(auto sig = type->isa<Sigma>()) {
289-
std::vector<const Def*> zops;
290-
int idx=0;
291-
for (auto op : sig->ops()) {
292-
auto [nmem, zop]=lit_of_type(world,mem,op,like->proj(idx),lit,dummy);
283+
auto zops = sig->ops().map([&](auto op, auto index){
284+
auto [nmem, zop]=lit_of_type(world,mem,op,like->proj(index),lit,dummy);
293285
mem=nmem;
294-
zops.push_back(zop);
295-
idx++;
296-
}
286+
return zop;
287+
});
288+
297289
litdef= world.tuple(zops);
298290
}
299291
else litdef= dummy;
@@ -447,22 +439,12 @@ const Def* AutoDiffer::j_wrap_tuple(DefArray tuple) {
447439
// apply them with the component of the scalar from the tuple pullback
448440
// sum them up
449441

450-
size_t real_arg_num;
451-
if(isRetTuple)
452-
real_arg_num=tuple_dim-2;
453-
else if(isMemTuple)
454-
real_arg_num=tuple_dim-1;
455-
else
456-
real_arg_num=tuple_dim;
457-
458-
// const Def* trimmed_ty;
459-
// auto tuple_ty = tuple->type();
460-
auto trimmed_var_ty=DefArray(real_arg_num,
461-
[&] (auto i) {
462-
return tuple[isMemTuple ? i+1 : i]->type();
463-
});
464-
465-
auto trimmed_ty=world_.sigma(trimmed_var_ty);
442+
auto trimmed_tuple = tuple.skip(isMemTuple, isRetTuple);
443+
auto trimed_ops = ops.skip(isMemTuple, isRetTuple);
444+
445+
auto trimmed_ty=world_.sigma(
446+
trimmed_tuple.map( [] (auto* def, auto) { return def->type(); } )
447+
);
466448
auto pi = createPbType(A,trimmed_ty);
467449
auto pb = world_.nom_filter_lam(pi, world_.dbg("tuple_pb"));
468450
auto pbT = pi->as<Pi>()->doms().back()->as<Pi>();
@@ -472,27 +454,18 @@ const Def* AutoDiffer::j_wrap_tuple(DefArray tuple) {
472454
flat_tuple({
473455
pb->mem_var(),
474456
zero_grad
475-
}) ));
476-
477-
auto tuple_of_pb = world_.tuple(
478-
DefArray{real_arg_num, [&](auto i) { return pullbacks_[isMemTuple ? ops[i+1] : ops[i]]; }}
479-
);
457+
})
458+
));
480459

481460
/**
482461
* pb = \lambda mem scalars ret. sum_pb_0 (mem,0)
483462
* sum_pb_i = \lambda mem sum_i. pb_i (mem, s_i, res_pb_i)
484463
* res_pb_i = \lambda mem res_i. sum_cont (mem, sum_i, res_i, sum_pb_{i+1})
485464
* sum_pb_n = \lambda mem sum. ret (mem, sum)
486465
*/
487-
for (size_t i = 0; i < real_arg_num; ++i) {
488-
489-
const Def* op;
490-
if(isMemTuple) {
491-
op=ops[i+1];
492-
}else {
493-
op=ops[i];
494-
}
495-
auto op_pb=pullbacks_[op];
466+
for (size_t i = 0; i < trimed_ops.size(); ++i) {
467+
const Def* op = trimed_ops[i];
468+
auto op_pb = pullbacks_[op];
496469
auto scalar = pb->var(i+1, world_.dbg("s"));
497470

498471
auto res_pb = world_.nom_filter_lam(pbT, world_.dbg("res_pb"));
@@ -502,13 +475,14 @@ const Def* AutoDiffer::j_wrap_tuple(DefArray tuple) {
502475
current_sum_pb->mem_var(),
503476
scalar,
504477
res_pb
505-
})));
478+
})
479+
));
506480

507481
auto next_current_sum_pb = world_.nom_filter_lam(pbT, world_.dbg("tuple_sum_pb"));
508482

509483
auto sum_cont_pb = vec_add(world_,
510-
world_.tuple(vars_without_mem_cont(world_,current_sum_pb)),
511-
world_.tuple(vars_without_mem_cont(world_,res_pb)),
484+
world_.tuple(vars_without_mem_cont(current_sum_pb)),
485+
world_.tuple(vars_without_mem_cont(res_pb)),
512486
next_current_sum_pb);
513487
res_pb->set_body(world_.app(
514488
sum_cont_pb,
@@ -546,8 +520,8 @@ const Def* AutoDiffer::chain(const Def* a, const Def* b) {
546520
auto middlepi = world_.cn_mem_flat(B);
547521
auto middle = world_.nom_filter_lam(middlepi, world_.dbg("chain_2"));
548522

549-
toplevel->set_body(world_.app(a, flat_tuple({toplevel->mem_var(), world_.tuple(vars_without_mem_cont(world_,toplevel)), middle})));
550-
middle->set_body(world_.app(b, flat_tuple({middle->mem_var(), world_.tuple(vars_without_mem_cont(world_,middle)), toplevel->ret_var()})));
523+
toplevel->set_body(world_.app(a, flat_tuple({toplevel->mem_var(), world_.tuple(vars_without_mem_cont(toplevel)), middle})));
524+
middle->set_body(world_.app(b, flat_tuple({middle->mem_var(), world_.tuple(vars_without_mem_cont(middle)), toplevel->ret_var()})));
551525

552526
return toplevel;
553527
}
@@ -602,7 +576,7 @@ const Def* AutoDiffer::extract_pb(const Def* j_extract, const Def* tuple) {
602576
else if(i==dim-1) {
603577
args[i]=pb->ret_var();
604578
} else if(i==index_lit) {
605-
args[i]= world_.tuple(vars_without_mem_cont(world_,pb));
579+
args[i]= world_.tuple(vars_without_mem_cont(pb));
606580
}else {
607581
// TODO: correct index
608582
auto [nmem, v]=ZERO(world_,mem,pb_domain->op(i), tuple->proj(i));
@@ -612,7 +586,6 @@ const Def* AutoDiffer::extract_pb(const Def* j_extract, const Def* tuple) {
612586
}
613587
args[0]=mem;
614588
pb_args=args;
615-
616589
}else {
617590
auto [rmem, ohv] = oneHot(world_,pb->mem_var(), idx,world_.tangent_type(tuple_ty,false),nullptr,pb->var(1,world_.dbg("s")));
618591
pb_args=
@@ -625,7 +598,7 @@ const Def* AutoDiffer::extract_pb(const Def* j_extract, const Def* tuple) {
625598
pb->set_body(world_.app(
626599
tuple_pb,
627600
pb_args
628-
));
601+
));
629602
return pb;
630603
}
631604
// loads pb from shadow slot, updates pb for the ptr, returns, mem and pb for the loaded value
@@ -645,41 +618,23 @@ const Def* AutoDiffer::reverse_diff(Lam* src) {
645618
auto dst_var = src_to_dst_[src_var];
646619
auto var_sigma = src_var->type()->as<Sigma>();
647620

648-
auto size = var_sigma->num_ops() - 2;
649-
DefArray trimmed_var_ty(size);
650-
for (size_t i = 0; i < size; ++i) {
651-
trimmed_var_ty[i] = var_sigma->op(i+1);
652-
}
621+
DefArray trimmed_var_ty = var_sigma->ops().skip();
653622
auto trimmed_var_sigma = world_.sigma(trimmed_var_ty);
654623
auto idpi = createPbType(A,trimmed_var_sigma);
655624
auto idpb = world_.nom_filter_lam(idpi, world_.dbg("param_id"));
656-
auto real_params = DefArray(
657-
dst_lam->num_vars()-2,
658-
[&](auto i) {
659-
return dst_lam->var(i+1);
660-
});
625+
auto real_params = dst_lam->vars().skip();
661626
auto [current_mem_,zero_grad_] = ZERO(world_,current_mem,A,world_.tuple(real_params));
662627
current_mem=current_mem_;
663628
zero_grad=zero_grad_;
664629
// ret only resp. non-mem, non-cont
665-
auto args = DefArray(
666-
src->num_vars()-1,
667-
[&](auto i) {
668-
if(i==0)
669-
return idpb->mem_var();
670-
return idpb->var(i);
671-
});
630+
auto args = idpb->vars().skip_back();
672631
idpb->set_body(world_.app(idpb->ret_var(), args));
673632
pullbacks_[dst_var] = idpb;
674-
for(size_t i = 0, e = src->num_vars(); i < e; ++i) {
675-
auto dvar = dst_lam->var(i);
676-
if(dvar == dst_lam->ret_var() || dvar == dst_lam->mem_var()) {
677-
continue;
678-
}
679-
// solve the problem of inital array pb in extract pb
680-
pullbacks_[dvar]= extract_pb(dvar, dst_lam->var());
681-
initArg(dvar);
682-
}
633+
for(auto dvar : src->vars().skip()) {
634+
// solve the problem of initial array pb in extract pb
635+
pullbacks_[dvar]= extract_pb(dvar, dst_lam->var());
636+
initArg(dvar);
637+
}
683638
// translate the body => get correct applications of variables using pullbacks
684639
auto dst = j_wrap(src->body());
685640
return dst;
@@ -1340,12 +1295,9 @@ const Def* AutoDiffer::j_wrap_convert(const Def* def) {
13401295
auto m = d_arg->proj(0);
13411296
auto num_projs = d_arg->num_projs();
13421297
auto ret_arg = d_arg->proj(num_projs-1);
1343-
auto args=DefArray(
1344-
num_projs-2,
1345-
[&](auto i) {
1346-
return d_arg->proj(i+1);
1347-
});
1348-
auto arg= world_.tuple(args);
1298+
auto arg= world_.tuple(
1299+
d_arg->projs().skip()
1300+
);
13491301
auto pbT = dst_callee->type()->as<Pi>()->doms().back()->as<Pi>();
13501302
auto chained = world_.nom_filter_lam(pbT, world_.dbg("φchain"));
13511303
auto arg_pb = pullbacks_[d_arg]; // Lam
@@ -1356,7 +1308,7 @@ const Def* AutoDiffer::j_wrap_convert(const Def* def) {
13561308
ret_arg,
13571309
flat_tuple({
13581310
chained->mem_var(),
1359-
world_.tuple(vars_without_mem_cont(world_,chained)),
1311+
world_.tuple(vars_without_mem_cont(chained)),
13601312
chain_pb
13611313
})
13621314
));
@@ -1392,7 +1344,13 @@ const Def* AutoDiffer::j_wrap_convert(const Def* def) {
13921344
ad_args = world_.tuple(
13931345
DefArray(
13941346
count+1,
1395-
[&](auto i) {if (i<count) {return world_.extract(d_arg, (u64)i, world_.dbg("ad_arg"));} else {return pullbacks_[d_arg];}}
1347+
[&](auto i) {
1348+
if (i<count) {
1349+
return world_.extract(d_arg, (u64)i, world_.dbg("ad_arg"));
1350+
} else {
1351+
return pullbacks_[d_arg];
1352+
}
1353+
}
13961354
));
13971355
}else {
13981356
// var (lambda completely with all arguments) and other (non tuple)
@@ -1414,7 +1372,7 @@ const Def* AutoDiffer::j_wrap_convert(const Def* def) {
14141372
auto dim = as_lit(pack->type()->arity());
14151373
auto tup=DefArray(
14161374
dim,
1417-
[&](auto i) {
1375+
[&](auto) {
14181376
return pack->body();
14191377
});
14201378
return j_wrap_tuple(tup);
@@ -1555,8 +1513,8 @@ const Def* AutoDiffer::j_wrap_rop(ROp op, const Def* a, const Def* b) {
15551513
THORIN_UNREACHABLE;
15561514
}
15571515

1558-
auto adiff = world_.tuple(vars_without_mem_cont(world_,middle));
1559-
auto bdiff = world_.tuple(vars_without_mem_cont(world_,end));
1516+
auto adiff = world_.tuple(vars_without_mem_cont(middle));
1517+
auto bdiff = world_.tuple(vars_without_mem_cont(end));
15601518
auto sum_pb=vec_add(world_,adiff,bdiff,pb->ret_var());
15611519
end->set_body(world_.app(sum_pb, end->mem_var()));
15621520
pullbacks_[dst] = pb;

thorin/util/array.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ class ArrayRef {
112112

113113
/// @name slice
114114
///@{
115+
/// Returns an ArrayRef over the same storage without the first @p front and the
/// last @p back elements (both default to 1).
/// The new length is computed by plain subtraction; nothing in this method
/// checks that @p front + @p back fits into size().
ArrayRef<T> skip(size_t front = 1, size_t back = 1) const {
    const size_t remaining = size() - front - back;
    return ArrayRef<T>(remaining, ptr_ + front);
}
115116
/// Returns an ArrayRef over the same storage without the first @p num elements
/// (one by default). @p num is not range-checked in this method.
ArrayRef<T> skip_front(size_t num = 1) const {
    const size_t remaining = size() - num;
    return ArrayRef<T>(remaining, ptr_ + num);
}
116117
/// Returns an ArrayRef over the same storage without the last @p num elements
/// (one by default). @p num is not range-checked in this method.
ArrayRef<T> skip_back(size_t num = 1) const {
    const size_t remaining = size() - num;
    return ArrayRef<T>(remaining, ptr_);
}
117118
ArrayRef<T> get_front(size_t num = 1) const {
@@ -143,6 +144,20 @@ class ArrayRef {
143144
swap(a1.ptr_, a2.ptr_);
144145
}
145146

147+
template<typename Result = T >
148+
Array<Result> map(std::function<Result(T, size_t)> f){
149+
auto result = Array<Result>(size());
150+
151+
for (size_t i = 0; i < size(); ++i){
152+
result[i] = f((*this)[i], i);
153+
}
154+
155+
return result;
156+
}
157+
158+
Array<T> map(std::function<T(T, size_t)> f){
159+
return map<T>(f);
160+
}
146161
private:
147162
size_t size_;
148163
const T* ptr_;
@@ -349,6 +364,7 @@ class Array {
349364

350365
/// @name slice
351366
///@{
367+
/// Returns an ArrayRef over this array's data() without the first @p front and
/// the last @p back elements (both default to 1).
/// The new length is computed by plain subtraction; nothing in this method
/// checks that @p front + @p back fits into size().
ArrayRef<T> skip(size_t front = 1, size_t back = 1) const {
    const size_t remaining = size() - front - back;
    return ArrayRef<T>(remaining, data() + front);
}
352368
/// Returns an ArrayRef over this array's data() without the first @p num
/// elements (one by default). @p num is not range-checked in this method.
ArrayRef<T> skip_front(size_t num = 1) const {
    const size_t remaining = size() - num;
    return ArrayRef<T>(remaining, data() + num);
}
353369
/// Returns an ArrayRef over this array's data() without the last @p num
/// elements (one by default). @p num is not range-checked in this method.
ArrayRef<T> skip_back(size_t num = 1) const {
    const size_t remaining = size() - num;
    return ArrayRef<T>(remaining, data());
}
354370
ArrayRef<T> get_front(size_t num = 1) const {

0 commit comments

Comments
 (0)