@@ -540,12 +540,25 @@ static void get_allparams(vector<vector<Params*>> &allparams, vector<Network> &n
540
540
}
541
541
}
542
542
543
- void share_deltas (vector<Network> networks) {
543
+ void distribute_weights (vector<Network> & networks, int from ) {
544
544
vector<vector<Params*>> allparams;
545
545
get_allparams (allparams, networks);
546
546
int n = allparams.size ();
547
547
int m = allparams[0 ].size ();
548
- for (int i=1 ; n; i++) {
548
+ for (int i=0 ; i<n; i++) {
549
+ if (i==from) continue ;
550
+ for (int j=0 ; j<m; j++) {
551
+ allparams[i][j]->V () = allparams[from][j]->V ();
552
+ }
553
+ }
554
+ }
555
+
556
+ void share_deltas (vector<Network> &networks) {
557
+ vector<vector<Params*>> allparams;
558
+ get_allparams (allparams, networks);
559
+ int n = allparams.size ();
560
+ int m = allparams[0 ].size ();
561
+ for (int i=1 ; i<n; i++) {
549
562
for (int j=0 ; j<m; j++) {
550
563
allparams[0 ][j]->D () += allparams[i][j]->D ();
551
564
}
@@ -555,7 +568,7 @@ void share_deltas(vector<Network> networks) {
555
568
}
556
569
}
557
570
558
- void average_weights (vector<Network> networks) {
571
+ void average_weights (vector<Network> & networks) {
559
572
vector<vector<Params*>> allparams;
560
573
get_allparams (allparams, networks);
561
574
int n = allparams.size ();
@@ -564,13 +577,11 @@ void average_weights(vector<Network> networks) {
564
577
for (int j=0 ; j<m; j++) {
565
578
allparams[0 ][j]->V () += allparams[i][j]->V ();
566
579
}
567
- for (int j=0 ; j<m; j++) {
568
- allparams[0 ][j]->V () = allparams[0 ][j]->V () * Float (1.0 /n);
569
- }
570
- for (int j=0 ; j<m; j++) {
571
- allparams[i][j]->V () = allparams[0 ][j]->V ();
572
- }
573
580
}
581
+ for (int j=0 ; j<m; j++) {
582
+ allparams[0 ][j]->V () = allparams[0 ][j]->V () * Float (1.0 /n);
583
+ }
584
+ distribute_weights (networks);
574
585
}
575
586
576
587
} // namespace ocropus
0 commit comments