You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Paddle/paddle/fluid/train/imdb_demo/demo_trainer.cc

185 lines
6.8 KiB

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <time.h>
#include <fstream>
#include "include/save_model.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/dataset_factory.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "gflags/gflags.h"
// Command-line flags for the demo trainer; parsed by gflags at the top of main.
// Text file listing the training data files, separated by whitespace; main
// aborts if it yields no file names.
DEFINE_string(filelist, "train_filelist.txt", "filelist for fluid dataset");
// File whose raw bytes are passed to Dataset::SetDataFeedDesc.
DEFINE_string(data_proto_desc, "data.proto", "data feed protobuf description");
// Serialized startup program, run once before training starts.
DEFINE_string(startup_program_file, "startup_program",
"startup program description");
// Serialized main program, run once per batch. The empty default cannot be
// opened, so this flag must be supplied on the command line.
DEFINE_string(main_program_file, "", "main program description");
// Scope variable whose first float element is read after each batch and
// averaged per epoch for logging.
DEFINE_string(loss_name, "mean_0.tmp_0",
"loss tensor name in the main program");
// Root directory; each epoch's model goes to <save_dir>/epoch<N>.model.
DEFINE_string(save_dir, "cnn_model", "directory to save trained models");
// Number of passes over the file list.
DEFINE_int32(epoch_num, 30, "number of epochs to run when training");
namespace paddle {
namespace train {

// Reads the whole file `filename` in binary mode into `*contents`, replacing
// whatever `*contents` held before.
// Aborts via PADDLE_ENFORCE if the file cannot be opened, its size cannot be
// determined, or the read does not deliver every byte. An empty file yields an
// empty string.
void ReadBinaryFile(const std::string& filename, std::string* contents) {
  std::ifstream fin(filename, std::ios::in | std::ios::binary);
  PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s", filename);
  // Seek to the end to learn the size; tellg() reports -1 on failure.
  fin.seekg(0, std::ios::end);
  const std::streamoff file_size = fin.tellg();
  PADDLE_ENFORCE(file_size >= 0, "Cannot determine the size of file %s",
                 filename);
  contents->clear();
  contents->resize(static_cast<std::string::size_type>(file_size));
  fin.seekg(0, std::ios::beg);
  // For an empty file there is nothing to read and no element whose address
  // could be taken (at(0) would throw std::out_of_range).
  if (!contents->empty()) {
    fin.read(&(*contents)[0], static_cast<std::streamsize>(contents->size()));
    PADDLE_ENFORCE(static_cast<bool>(fin), "Failed to read file %s", filename);
  }
  // `fin` is closed by its destructor.
}

// Loads a program description from `model_filename`.
// The file's raw bytes are passed to ProgramDesc's string constructor
// (presumably a binary-serialized proto::ProgramDesc -- confirm against the
// exporter). Aborts through ReadBinaryFile if the file cannot be read.
std::unique_ptr<paddle::framework::ProgramDesc> LoadProgramDesc(
    const std::string& model_filename) {
  VLOG(3) << "loading model from " << model_filename;
  std::string program_desc_str;
  ReadBinaryFile(model_filename, &program_desc_str);
  std::unique_ptr<paddle::framework::ProgramDesc> main_program(
      new paddle::framework::ProgramDesc(program_desc_str));
  return main_program;
}

// Returns true if `var` is marked persistable and its type is not
// FEED_MINIBATCH, FETCH_LIST or RAW. Those types are presumably runtime
// bookkeeping rather than trained parameters, so callers skip them when
// saving the model.
bool IsPersistable(const paddle::framework::VarDesc* var) {
  if (!var->Persistable()) {
    return false;
  }
  const auto type = var->GetType();
  return type != paddle::framework::proto::VarType::FEED_MINIBATCH &&
         type != paddle::framework::proto::VarType::FETCH_LIST &&
         type != paddle::framework::proto::VarType::RAW;
}

}  // namespace train
}  // namespace paddle
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
std::cerr << "filelist: " << FLAGS_filelist << std::endl;
std::cerr << "data_proto_desc: " << FLAGS_data_proto_desc << std::endl;
std::cerr << "startup_program_file: " << FLAGS_startup_program_file
<< std::endl;
std::cerr << "main_program_file: " << FLAGS_main_program_file << std::endl;
std::cerr << "loss_name: " << FLAGS_loss_name << std::endl;
std::cerr << "save_dir: " << FLAGS_save_dir << std::endl;
std::cerr << "epoch_num: " << FLAGS_epoch_num << std::endl;
std::string filelist = std::string(FLAGS_filelist);
std::vector<std::string> file_vec;
std::ifstream fin(filelist);
if (fin) {
std::string filename;
while (fin >> filename) {
file_vec.push_back(filename);
}
}
PADDLE_ENFORCE_GE(file_vec.size(), 1, "At least one file to train");
paddle::framework::InitDevices(false);
const auto cpu_place = paddle::platform::CPUPlace();
paddle::framework::Executor executor(cpu_place);
paddle::framework::Scope scope;
auto startup_program =
paddle::train::LoadProgramDesc(std::string(FLAGS_startup_program_file));
auto main_program =
paddle::train::LoadProgramDesc(std::string(FLAGS_main_program_file));
executor.Run(*startup_program, &scope, 0);
std::string data_feed_desc_str;
paddle::train::ReadBinaryFile(std::string(FLAGS_data_proto_desc),
&data_feed_desc_str);
VLOG(3) << "load data feed desc done.";
std::unique_ptr<paddle::framework::Dataset> dataset_ptr;
dataset_ptr =
paddle::framework::DatasetFactory::CreateDataset("MultiSlotDataset");
VLOG(3) << "initialize dataset ptr done";
// find all params
std::vector<std::string> param_names;
const paddle::framework::BlockDesc& global_block = main_program->Block(0);
for (auto* var : global_block.AllVars()) {
if (paddle::train::IsPersistable(var)) {
VLOG(3) << "persistable variable's name: " << var->Name();
param_names.push_back(var->Name());
}
}
int epoch_num = FLAGS_epoch_num;
std::string loss_name = FLAGS_loss_name;
auto loss_var = scope.Var(loss_name);
LOG(INFO) << "Start training...";
for (int epoch = 0; epoch < epoch_num; ++epoch) {
VLOG(3) << "Epoch:" << epoch;
// get reader
dataset_ptr->SetFileList(file_vec);
VLOG(3) << "set file list done";
dataset_ptr->SetThreadNum(1);
VLOG(3) << "set thread num done";
dataset_ptr->SetDataFeedDesc(data_feed_desc_str);
VLOG(3) << "set data feed desc done";
dataset_ptr->CreateReaders();
const std::vector<paddle::framework::DataFeed*> readers =
dataset_ptr->GetReaders();
PADDLE_ENFORCE_EQ(readers.size(), 1,
"readers num should be equal to thread num");
readers[0]->SetPlace(paddle::platform::CPUPlace());
const std::vector<std::string>& input_feed_names =
readers[0]->GetUseSlotAlias();
for (auto name : input_feed_names) {
readers[0]->AddFeedVar(scope.Var(name), name);
}
VLOG(3) << "get reader done";
readers[0]->Start();
VLOG(3) << "start a reader";
VLOG(3) << "readers size: " << readers.size();
int step = 0;
std::vector<float> loss_vec;
while (readers[0]->Next() > 0) {
executor.Run(*main_program, &scope, 0, false, true);
loss_vec.push_back(
loss_var->Get<paddle::framework::LoDTensor>().data<float>()[0]);
}
float average_loss =
accumulate(loss_vec.begin(), loss_vec.end(), 0.0) / loss_vec.size();
LOG(INFO) << "epoch: " << epoch << "; average loss: " << average_loss;
dataset_ptr->DestroyReaders();
// save model
std::string save_dir_root = FLAGS_save_dir;
std::string save_dir =
save_dir_root + "/epoch" + std::to_string(epoch) + ".model";
paddle::framework::save_model(main_program, &scope, param_names, save_dir,
false);
}
}