diff --git a/NNDisParser/src/NNDisParser-synmlp-do/NNDisParser.cpp b/NNDisParser/src/NNDisParser-synmlp-do/NNDisParser.cpp index d6ac62b..0019f3d 100644 --- a/NNDisParser/src/NNDisParser-synmlp-do/NNDisParser.cpp +++ b/NNDisParser/src/NNDisParser-synmlp-do/NNDisParser.cpp @@ -350,10 +350,15 @@ void DisParser::test(const string &testFile, const string &outputFile, const str m_options.load(optionFile); m_options.showOptions(); m_driver._hyperparams.setRequared(m_options); + + /* + IMPORTANT: the model is loaded right at the beginning; otherwise we used to get a `Segmentation Fault` in the next initialization steps. + */ + loadModelFile(modelFile); vector testInsts; m_pipe.readInstances(testFile, testInsts, m_options.maxInstance); - getDepFeats(testInsts, m_options.conllFolder + path_separator + "test.conll.predict"); + getDepFeats(testInsts, testFile + ".conll"); int word_count = 0, max_size; @@ -364,16 +369,52 @@ void DisParser::test(const string &testFile, const string &outputFile, const str extern_nodes.resize(word_count * 10); node_count = 0; - string syn = "conll.dump.results"; - getSynFeats(testInsts, m_options.dumpFolder + path_separator + "test." + syn); + getSynFeats(testInsts, testFile + ".dump"); + + /* + NOTE (Mat-sipahi): The following lines (down to `m_driver.initial`) are copied from the `train` function + in order to avoid segmentation faults caused by uninitialized hyperparams in different parts of the code. + I've commented out the unnecessary steps, but I'm not sure whether any other redundant steps remain. 
+ */ + addTestAlpha(testInsts); + //createAlphabet(trainInsts); + //getGoldActions(trainInsts); + + //if(m_options.wordEmbFile == "") { + // m_driver._modelparams.edu_params.word_table.initial(&m_driver._hyperparams.wordAlpha, m_options.wordEmbSize, m_options.wordFineTune); + //} + //else + // m_driver._modelparams.edu_params.word_table.initial(&m_driver._hyperparams.wordAlpha, m_options.wordEmbFile, m_options.wordFineTune); + m_driver._hyperparams.wordDim = m_driver._modelparams.edu_params.word_table.nDim; + m_driver._modelparams.edu_params.tag_table.initial(&m_driver._hyperparams.tagAlpha, m_options.tagEmbSize, m_options.tagFineTune); + m_driver._hyperparams.tagDim = m_driver._modelparams.edu_params.tag_table.nDim; + + m_driver._hyperparams.wordConcatDim = m_driver._hyperparams.wordDim + m_driver._hyperparams.tagDim; + //m_driver._modelparams.etype_table.initial(&m_driver._hyperparams.etypeAlpha, m_options.etypeEmbSize, m_options.etypeFineTune); + m_driver._hyperparams.etypeDim = m_driver._modelparams.etype_table.nDim; + m_driver._hyperparams.eduConcatDim = m_driver._hyperparams.eduHiddenDim + m_driver._hyperparams.etypeDim; + //m_driver._hyperparams.eduConcatDim = m_driver._hyperparams.eduHiddenDim; m_driver._modelparams.edu_params.word_table.elems = &m_driver._hyperparams.wordAlpha; m_driver._modelparams.edu_params.tag_table.elems = &m_driver._hyperparams.tagAlpha; m_driver._modelparams.etype_table.elems = &m_driver._hyperparams.etypeAlpha; m_driver._modelparams.scored_action_table.elems = &m_driver._hyperparams.actionAlpha; + + //m_driver._hyperparams.actionAlpha.initial(m_driver._hyperparams.action_stat, 0); + m_driver._hyperparams.actionNum = m_driver._hyperparams.actionAlpha.size(); + //m_driver._hyperparams.etypeAlpha.initial(m_driver._hyperparams.etype_stat, 0); + m_driver.initial(); + + /* + End of the block copied from `train` (Mat-sipahi) + */ + + /* + IMPORTANT: model is loaded for the second time. 
The parser was predicting all relations as ENABLEMENT without it. + */ loadModelFile(modelFile); - getDepFeats(testInsts, m_options.conllFolder + path_separator + "test.conll.predict"); + vector decodeInstResults; int testNum = testInsts.size(); @@ -388,7 +429,9 @@ void DisParser::test(const string &testFile, const string &outputFile, const str test_nuclear.reset(); test_relation.reset(); test_full.reset(); + predict(testInsts, decodeInstResults); + for (int idx = 0; idx < testInsts.size(); idx++) { testInsts[idx].evaluate(decodeInstResults[idx], test_span, test_nuclear, test_relation, test_full); } @@ -404,9 +447,7 @@ void DisParser::test(const string &testFile, const string &outputFile, const str cout << "F: "; test_full.print(); - if (!m_options.outBest.empty()) { - m_pipe.outputAllInstances(testFile + m_options.outBest + ".test", decodeInstResults); - } + m_pipe.outputAllInstances(outputFile, decodeInstResults); } } @@ -558,7 +599,7 @@ void DisParser::getSynFeats(vector &vecInsts, const string &folder) { for (int i = 0; i < vec_size; i++) { vecInfo.clear(); split_bychar(vecLine1[i], vecInfo, ' '); - + assert(normalize_to_lowerwithdigit(vecInfo[0]).compare(cur_dep_feat.words[i]) == 0); } int cur_word_size = cur_dep_feat.words.size(), syn_offset; @@ -647,7 +688,7 @@ void DisParser::writeModelFile(const string &outputModelFile) { void DisParser::loadModelFile(const string &inputModelFile) { ifstream inf(inputModelFile.c_str()); m_driver._hyperparams.read(inf); - m_driver._modelparams.loadModel(inf); + m_driver._modelparams.loadModel(inf, m_driver._hyperparams); inf.close(); } @@ -669,7 +710,7 @@ void DisParser::predict(const vector &inputs, vector &outputs int main(int argc, char* argv[]) { - std::string trainFile = "", devFile = "", testFile = "", modelFile = ""; + std::string trainFile = "", devFile = "", testFile = "", modelFile = "./model.bin"; std::string optionFile = ""; std::string outputFile = ""; bool bTrain = false; diff --git 
a/NNDisParser/src/NNDisParser-synmlp-do/model/ModelParams.h b/NNDisParser/src/NNDisParser-synmlp-do/model/ModelParams.h index dfca1dd..b4f7dac 100644 --- a/NNDisParser/src/NNDisParser-synmlp-do/model/ModelParams.h +++ b/NNDisParser/src/NNDisParser-synmlp-do/model/ModelParams.h @@ -65,14 +65,15 @@ class ModelParams { edu_lstm_right_layer1_params.save(os); scored_action_table.save(os); } - - void loadModel(std::ifstream &is) { + + // NOTE: the model should read data from the initialized hyperparams + void loadModel(std::ifstream &is, HyperParams &opts) { edu_params.load(is); syn_params.load(is); - etype_table.load(is, etype_table.elems); + etype_table.load(is, &opts.etypeAlpha); edu_lstm_left_layer1_params.load(is); edu_lstm_right_layer1_params.load(is); - scored_action_table.load(is, scored_action_table.elems); + scored_action_table.load(is, &opts.actionAlpha); } }; diff --git a/NNDisParser/src/basic/Instance.h b/NNDisParser/src/basic/Instance.h index d9f825d..d867455 100644 --- a/NNDisParser/src/basic/Instance.h +++ b/NNDisParser/src/basic/Instance.h @@ -286,8 +286,9 @@ class Instance { EDU &edu = edus[idx]; assert(edu.start_index <= edu.end_index); assert(edu.start_index >= 0 && edu.end_index < total_text_size); - if (idx < edu_size - 1) + if (idx != edu_size - 1) { assert(edu.end_index + 1 == edus[idx + 1].start_index); + } for (int idy = edu.start_index; idy <= edu.end_index; idy++) { if(total_tags[idy] != nullkey){ edu.words.push_back(normalize_to_lowerwithdigit(total_text[idy]));