|
|
|
@ -40,31 +40,28 @@ limitations under the License. */
|
|
|
|
|
#include "TrainerConfigHelper.h"
|
|
|
|
|
|
|
|
|
|
P_DEFINE_string(config, "", "Trainer config file");
|
|
|
|
|
P_DEFINE_int32(test_period, 0,
|
|
|
|
|
"Run test every so many train batches."
|
|
|
|
|
" 0 for testing after each pass."
|
|
|
|
|
" If not 0, test log_period batches."
|
|
|
|
|
" If 0, test on all test data");
|
|
|
|
|
|
|
|
|
|
P_DEFINE_int32(test_batches_while_training, 0,
|
|
|
|
|
P_DEFINE_int32(test_period, 0,
|
|
|
|
|
"This option was deprecated, use test_period_while_training "
|
|
|
|
|
" instead. ");
|
|
|
|
|
P_DEFINE_int32(test_period_while_training, 0,
|
|
|
|
|
"Run test every so many train batches."
|
|
|
|
|
" 0 for testing after each pass."
|
|
|
|
|
" If not 0, test log_period batches."
|
|
|
|
|
" If 0, test nothing.");
|
|
|
|
|
P_DEFINE_int32(test_batches_while_training, 1000,
|
|
|
|
|
"test test_batches_while_training batches if test_period != 0."
|
|
|
|
|
" If 0, test on all test data");
|
|
|
|
|
|
|
|
|
|
P_DEFINE_int32(test_batches_while_end, 0,
|
|
|
|
|
"Run test every so many train batches."
|
|
|
|
|
" 0 for testing after each pass."
|
|
|
|
|
" If not 0, test log_period batches."
|
|
|
|
|
" If 0, test on all test data");
|
|
|
|
|
"test test_batches_while_end batches at pass end."
|
|
|
|
|
" Always run test at pass end."
|
|
|
|
|
" If not 0, test test_batches_while_end batches."
|
|
|
|
|
" If 0, test on all test data.");
|
|
|
|
|
P_DEFINE_bool(test_all_data_in_one_period, false,
|
|
|
|
|
"This option was deprecated, use test_batches_while_training "
|
|
|
|
|
"and test_batches_while_end instead");
|
|
|
|
|
|
|
|
|
|
P_DEFINE_bool(local, true, "Train in local mode or not");
|
|
|
|
|
|
|
|
|
|
P_DEFINE_bool(
|
|
|
|
|
test_all_data_in_one_period, false,
|
|
|
|
|
"true will test all data in one test peroid."
|
|
|
|
|
"Otherwise test (batch_size * log_peroid) data in one test period.");
|
|
|
|
|
|
|
|
|
|
P_DEFINE_int32(average_test_period, 0,
|
|
|
|
|
"Do test on average parameter every so"
|
|
|
|
|
" many batches. MUST be devided by FLAGS_log_period."
|
|
|
|
@ -469,9 +466,9 @@ void Trainer::trainOneDataBatch(DataBatch& dataBatch) {
|
|
|
|
|
FOR_TIMING(globalStat.reset());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (testDataProvider_ && FLAGS_test_period > 0 &&
|
|
|
|
|
trainPassContext_.batchId % FLAGS_test_period == 0) {
|
|
|
|
|
tester_->testOnePeriod();
|
|
|
|
|
if (testDataProvider_ && FLAGS_test_period_while_training > 0 &&
|
|
|
|
|
trainPassContext_.batchId % FLAGS_test_period_while_training == 0) {
|
|
|
|
|
tester_->testOnePeriod(false);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (FLAGS_saving_period_by_batches > 0 &&
|
|
|
|
@ -480,7 +477,7 @@ void Trainer::trainOneDataBatch(DataBatch& dataBatch) {
|
|
|
|
|
0 == FLAGS_trainer_id) {
|
|
|
|
|
trainerInternal_.getParameterUpdater()->catchUpWith();
|
|
|
|
|
if (testDataProvider_) {
|
|
|
|
|
tester_->testOnePeriod();
|
|
|
|
|
tester_->testOnePeriod(false);
|
|
|
|
|
}
|
|
|
|
|
paramUtil_->saveParametersOnePass(
|
|
|
|
|
trainPassContext_.passId, trainPassContext_.passInnerId);
|
|
|
|
@ -636,8 +633,19 @@ void Trainer::test() {
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<TesterConfig> Trainer::createTesterConfig() {
|
|
|
|
|
TesterConfig* conf = new TesterConfig;
|
|
|
|
|
conf->testPeriod = FLAGS_test_period;
|
|
|
|
|
conf->testAllDataInOnePeriod = FLAGS_test_all_data_in_one_period;
|
|
|
|
|
if (FLAGS_test_period) {
|
|
|
|
|
LOG(WARNING)
|
|
|
|
|
<< "--test_period was deprecated, use --test_period_while_training"
|
|
|
|
|
<< "--test_batches_while_training --test_batches_while_end instead.";
|
|
|
|
|
}
|
|
|
|
|
if (FLAGS_test_all_data_in_one_period) {
|
|
|
|
|
LOG(WARNING)
|
|
|
|
|
<< "--test_all_data_in_one_period was deprecated, use"
|
|
|
|
|
<< " --test_batches_while_training and --test_batches_while_end instead";
|
|
|
|
|
}
|
|
|
|
|
conf->testPeriodWhileTraining = FLAGS_test_period_while_training;
|
|
|
|
|
conf->testBatchesWhileTraining = FLAGS_test_batches_while_training;
|
|
|
|
|
conf->testBatchesWhileEnd = FLAGS_test_batches_while_end;
|
|
|
|
|
conf->prevBatchState = FLAGS_prev_batch_state;
|
|
|
|
|
conf->logPeriod = FLAGS_log_period;
|
|
|
|
|
conf->loadsaveParametersInPserver = FLAGS_loadsave_parameters_in_pserver;
|
|
|
|
|