Skip to content

Commit 484108c

Browse files
committed
Added NDEBUG, functions for merging models.
1 parent fee039e commit 484108c

File tree

5 files changed

+35
-16
lines changed

5 files changed

+35
-16
lines changed

SConstruct

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ elif option("debug", 0)>0:
7777
env.Append(CCFLAGS="-g".split())
7878
env.Append(LINKFLAGS="-g".split())
7979
else:
80-
env.Append(CXXFLAGS="-g -O3 -finline".split())
80+
env.Append(CXXFLAGS="-g -O3 -DNDEBUG -finline".split())
8181
env.Append(CCFLAGS="-g".split())
8282

8383
# Extra layers (old layers or testing)

clstm.cc

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -540,12 +540,25 @@ static void get_allparams(vector<vector<Params*>> &allparams, vector<Network> &n
540540
}
541541
}
542542

543-
void share_deltas(vector<Network> networks) {
543+
void distribute_weights(vector<Network> &networks, int from) {
544544
vector<vector<Params*>> allparams;
545545
get_allparams(allparams, networks);
546546
int n = allparams.size();
547547
int m = allparams[0].size();
548-
for(int i=1; n; i++) {
548+
for(int i=0; i<n; i++) {
549+
if (i==from) continue;
550+
for(int j=0; j<m; j++) {
551+
allparams[i][j]->V() = allparams[from][j]->V();
552+
}
553+
}
554+
}
555+
556+
void share_deltas(vector<Network> &networks) {
557+
vector<vector<Params*>> allparams;
558+
get_allparams(allparams, networks);
559+
int n = allparams.size();
560+
int m = allparams[0].size();
561+
for(int i=1; i<n; i++) {
549562
for(int j=0; j<m; j++) {
550563
allparams[0][j]->D() += allparams[i][j]->D();
551564
}
@@ -555,7 +568,7 @@ void share_deltas(vector<Network> networks) {
555568
}
556569
}
557570

558-
void average_weights(vector<Network> networks) {
571+
void average_weights(vector<Network> &networks) {
559572
vector<vector<Params*>> allparams;
560573
get_allparams(allparams, networks);
561574
int n = allparams.size();
@@ -564,13 +577,11 @@ void average_weights(vector<Network> networks) {
564577
for(int j=0; j<m; j++) {
565578
allparams[0][j]->V() += allparams[i][j]->V();
566579
}
567-
for(int j=0; j<m; j++) {
568-
allparams[0][j]->V() = allparams[0][j]->V() * Float(1.0/n);
569-
}
570-
for(int j=0; j<m; j++) {
571-
allparams[i][j]->V() = allparams[0][j]->V();
572-
}
573580
}
581+
for(int j=0; j<m; j++) {
582+
allparams[0][j]->V() = allparams[0][j]->V() * Float(1.0/n);
583+
}
584+
distribute_weights(networks);
574585
}
575586

576587
} // namespace ocropus

clstm.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,9 @@ void trivial_decode(Classes &cs, Sequence &outputs, int batch,
205205
// single sequence training functions
206206
void mktargets(Tensor<float,2> &seq, Tensor<int,1> &targets, int ndim);
207207

208-
void share_deltas(vector<Network> networks);
209-
void average_weights(vector<Network> networks);
208+
void share_deltas(vector<Network> &networks);
209+
void average_weights(vector<Network> &networks);
210+
void distribute_weights(vector<Network> &networks, int from=0);
210211
}
211212

212213
namespace {

clstmhl.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ struct CLSTMText {
108108
};
109109

110110
struct CLSTMOCR {
111-
unique_ptr<INormalizer> normalizer;
111+
shared_ptr<INormalizer> normalizer;
112112
Network net;
113113
int target_height = 48;
114114
int nclasses = -1;
@@ -132,7 +132,7 @@ struct CLSTMOCR {
132132
normalizer.reset(make_CenterNormalizer());
133133
normalizer->target_height = target_height;
134134
}
135-
std::wstring train(Tensor<float,2> &raw, const std::wstring &target) {
135+
std::wstring fwdbwd(Tensor<float,2> &raw, const std::wstring &target) {
136136
normalizer->measure(raw);
137137
normalizer->normalize(image, raw);
138138
set_inputs(net, image);
@@ -144,11 +144,18 @@ struct CLSTMOCR {
144144
for (int t = 0; t < aligned.size(); t++)
145145
net->outputs[t].D() = aligned[t].V() - net->outputs[t].V();
146146
net->backward();
147-
sgd_update(net);
148147
Classes outputs;
149148
trivial_decode(outputs, net->outputs);
150149
return net->codec.decode(outputs);
151150
}
151+
void update() {
152+
sgd_update(net);
153+
}
154+
std::wstring train(Tensor<float,2> &raw, const std::wstring &target) {
155+
std::wstring result = fwdbwd(raw, target);
156+
update();
157+
return result;
158+
}
152159
std::string aligned_utf8() {
153160
Classes outputs;
154161
trivial_decode(outputs, aligned);

run-tests

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ export PS4='
44
>>>>>>> '
55
trap "echo TEST FAILED" EXIT
66
set -x
7-
export seed=0.1423
7+
export seed=0.7733
88
scons -s -c; rm -f *.o *.a
99
scons -j 4 clstmocrtrain clstmfiltertrain clstmfilter clstmocr test-lstm
1010
./test-lstm

0 commit comments

Comments
 (0)