|
|
|
@ -17,36 +17,38 @@ limitations under the License. */
|
|
|
|
|
#include <fenv.h>
|
|
|
|
|
#include <stdio.h>
|
|
|
|
|
|
|
|
|
|
#include <iostream>
|
|
|
|
|
#include <iomanip>
|
|
|
|
|
#include <sstream>
|
|
|
|
|
#include <iostream>
|
|
|
|
|
#include <limits>
|
|
|
|
|
#include <sstream>
|
|
|
|
|
|
|
|
|
|
#include <google/protobuf/text_format.h>
|
|
|
|
|
|
|
|
|
|
#include "paddle/utils/Excepts.h"
|
|
|
|
|
#include "paddle/utils/GlobalConstants.h"
|
|
|
|
|
#include "paddle/utils/PythonUtil.h"
|
|
|
|
|
#include "paddle/utils/Stat.h"
|
|
|
|
|
#include "paddle/utils/Util.h"
|
|
|
|
|
#include "paddle/utils/Excepts.h"
|
|
|
|
|
#include "paddle/utils/GlobalConstants.h"
|
|
|
|
|
|
|
|
|
|
#include "paddle/gserver/gradientmachines/NeuralNetwork.h"
|
|
|
|
|
#include "paddle/gserver/gradientmachines/GradientMachineMode.h"
|
|
|
|
|
#include "paddle/gserver/layers/ValidationLayer.h"
|
|
|
|
|
#include "RemoteParameterUpdater.h"
|
|
|
|
|
#include "TesterConfig.h"
|
|
|
|
|
#include "ThreadParameterUpdater.h"
|
|
|
|
|
#include "RemoteParameterUpdater.h"
|
|
|
|
|
#include "TrainerConfigHelper.h"
|
|
|
|
|
#include "paddle/gserver/gradientmachines/GradientMachineMode.h"
|
|
|
|
|
#include "paddle/gserver/gradientmachines/NeuralNetwork.h"
|
|
|
|
|
#include "paddle/gserver/layers/ValidationLayer.h"
|
|
|
|
|
|
|
|
|
|
P_DEFINE_string(config, "", "Trainer config file");
|
|
|
|
|
|
|
|
|
|
P_DEFINE_int32(test_period, 0,
|
|
|
|
|
P_DEFINE_int32(test_period,
|
|
|
|
|
0,
|
|
|
|
|
"if equal 0, do test on all test data at the end of "
|
|
|
|
|
"each pass. While if equal non-zero, do test on all test "
|
|
|
|
|
"data every test_period batches");
|
|
|
|
|
P_DEFINE_bool(test_all_data_in_one_period, false,
|
|
|
|
|
"This option was deprecated, since we will always do "
|
|
|
|
|
"test on all test set ");
|
|
|
|
|
P_DEFINE_bool(test_all_data_in_one_period,
|
|
|
|
|
false,
|
|
|
|
|
"This option was deprecated, since we will always do "
|
|
|
|
|
"test on all test set ");
|
|
|
|
|
|
|
|
|
|
P_DEFINE_bool(local, true, "Train in local mode or not");
|
|
|
|
|
|
|
|
|
@ -392,10 +394,6 @@ void Trainer::startTrain() {
|
|
|
|
|
dataProvider_->reset();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (this->testDataProvider_) {
|
|
|
|
|
this->testDataProvider_->reset();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
trainerInternal_.getGradientMachine()->start(*config_, dataProvider_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -630,16 +628,14 @@ void Trainer::test() { tester_->test(); }
|
|
|
|
|
std::unique_ptr<TesterConfig> Trainer::createTesterConfig() {
|
|
|
|
|
TesterConfig* conf = new TesterConfig;
|
|
|
|
|
if (FLAGS_test_period) {
|
|
|
|
|
LOG(WARNING)
|
|
|
|
|
<< "The meaning of --test_period is changed: "
|
|
|
|
|
<< "if equal 0, do test on all test data at the end of "
|
|
|
|
|
<< "each pass. While if equal non-zero, do test on all test "
|
|
|
|
|
<< "data every test_period batches ";
|
|
|
|
|
LOG(WARNING) << "The meaning of --test_period is changed: "
|
|
|
|
|
<< "if equal 0, do test on all test data at the end of "
|
|
|
|
|
<< "each pass. While if equal non-zero, do test on all test "
|
|
|
|
|
<< "data every test_period batches ";
|
|
|
|
|
}
|
|
|
|
|
if (FLAGS_test_all_data_in_one_period) {
|
|
|
|
|
LOG(WARNING)
|
|
|
|
|
<< "--test_all_data_in_one_period was deprecated, since "
|
|
|
|
|
<< "we will always do test on all test set ";
|
|
|
|
|
LOG(WARNING) << "--test_all_data_in_one_period was deprecated, since "
|
|
|
|
|
<< "we will always do test on all test set ";
|
|
|
|
|
}
|
|
|
|
|
conf->testPeriod = FLAGS_test_period;
|
|
|
|
|
conf->prevBatchState = FLAGS_prev_batch_state;
|
|
|
|
|