[WIP] Add Imdb train demo (#18895)
* add train demo for imdb text classification task * make inference library release data_feed dataset dataset_factory data_feed_factory * add String Data Generator * new feature of train demo: save model params * New feature of train demo: set training config using gflags * change code style for CI * add readme and dataset for imdb demo trainerpadding_in_crf
parent
0b1025769c
commit
4ad7c9d5a7
@ -0,0 +1,78 @@
|
||||
cmake_minimum_required(VERSION 3.0)

project(cpp_imdb_train_demo CXX C)

# The demo sources are compiled as C++11.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")

# PADDLE_LIB has to be given on the command line; every include and link
# path below is derived from it.
if(NOT DEFINED PADDLE_LIB)
  message(FATAL_ERROR "please set PADDLE_LIB with -DPADDLE_LIB=/paddle/lib/dir")
endif()

option(WITH_MKLDNN "Compile PaddlePaddle with MKLDNN" OFF)
option(WITH_MKL "Compile PaddlePaddle with MKL support, default use openblas." OFF)

# Header search paths: the Paddle library root followed by its bundled
# third-party packages.
include_directories(
  "${PADDLE_LIB}"
  "${PADDLE_LIB}/third_party/install/protobuf/include"
  "${PADDLE_LIB}/third_party/install/glog/include"
  "${PADDLE_LIB}/third_party/install/gflags/include"
  "${PADDLE_LIB}/third_party/install/xxhash/include"
  "${PADDLE_LIB}/third_party/install/snappy/include"
  "${PADDLE_LIB}/third_party/install/snappystream/include"
  "${PADDLE_LIB}/third_party/install/zlib/include"
  "${PADDLE_LIB}/third_party/boost"
  "${PADDLE_LIB}/third_party/eigen3")

# Library search paths for the third-party libraries named in
# target_link_libraries() below.
link_directories(
  "${PADDLE_LIB}/third_party/install/snappy/lib"
  "${PADDLE_LIB}/third_party/install/snappystream/lib"
  "${PADDLE_LIB}/third_party/install/protobuf/lib"
  "${PADDLE_LIB}/third_party/install/glog/lib"
  "${PADDLE_LIB}/third_party/install/gflags/lib"
  "${PADDLE_LIB}/third_party/install/xxhash/lib"
  "${PADDLE_LIB}/third_party/install/zlib/lib")

add_executable(demo_trainer save_model.cc demo_trainer.cc)

# Optional MKLDNN: add its headers and pick the platform-specific library.
if(WITH_MKLDNN)
  include_directories("${PADDLE_LIB}/third_party/install/mkldnn/include")
  if(WIN32)
    set(MKLDNN_LIB ${PADDLE_LIB}/third_party/install/mkldnn/lib/mkldnn.lib)
  else()
    set(MKLDNN_LIB ${PADDLE_LIB}/third_party/install/mkldnn/lib/libmkldnn.so.0)
  endif()
endif()

# Math library: MKLML when WITH_MKL is ON, otherwise cblas on macOS or the
# bundled OpenBLAS elsewhere.
if(WITH_MKL)
  include_directories("${PADDLE_LIB}/third_party/install/mklml/include")
  if(WIN32)
    set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/mklml.lib)
  else()
    set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel.so)
  endif()
elseif(APPLE)
  set(MATH_LIB cblas)
elseif(WIN32)
  set(MATH_LIB ${PADDLE_LIB}/third_party/install/openblas/lib/libopenblas.lib)
else()
  set(MATH_LIB ${PADDLE_LIB}/third_party/install/openblas/lib/libopenblas.a)
endif()

# Platform-specific linker flags. Variables left unset on a platform expand
# to nothing in the link line below.
if(APPLE)
  set(MACOS_LD_FLAGS "-undefined dynamic_lookup -Wl,-all_load -framework CoreFoundation -framework Security")
else()
  set(ARCHIVE_START "-Wl,--whole-archive")
  set(ARCHIVE_END "-Wl,--no-whole-archive")
  set(EXTERNAL_LIB "-lrt -ldl -lpthread")
endif()

target_link_libraries(demo_trainer
  ${MACOS_LD_FLAGS}
  ${ARCHIVE_START}
  ${PADDLE_LIB}/paddle/fluid/inference/libpaddle_fluid.so
  ${ARCHIVE_END}
  ${MATH_LIB}
  ${MKLDNN_LIB}
  glog gflags protobuf snappystream snappy z xxhash
  ${EXTERNAL_LIB})
|
@ -0,0 +1,97 @@
|
||||
# Train with C++ inference API
|
||||
|
||||
What the C++ inference API is and how to install it:
|
||||
|
||||
see: [PaddlePaddle Fluid 提供了 C++ API 来支持模型的部署上线](https://paddlepaddle.org.cn/documentation/docs/zh/1.5/advanced_usage/deploy/inference/index_cn.html)
|
||||
|
||||
## IMDB task
|
||||
|
||||
see: [IMDB Dataset of 50K Movie Reviews | Kaggle](https://www.kaggle.com/lakshmi25npathi/imdb-dataset-of-50k-movie-reviews)
|
||||
|
||||
## Quick Start
|
||||
|
||||
### prepare data
|
||||
|
||||
```shell
|
||||
wget https://fleet.bj.bcebos.com/text_classification_data.tar.gz
|
||||
tar -zxvf text_classification_data.tar.gz
|
||||
```
|
||||
### build
|
||||
|
||||
```shell
|
||||
mkdir build
|
||||
cd build
|
||||
rm -rf *
|
||||
PADDLE_LIB=path/to/your/fluid_inference_install_dir/
|
||||
cmake .. -DPADDLE_LIB=$PADDLE_LIB -DWITH_MKLDNN=OFF -DWITH_MKL=OFF
|
||||
make
|
||||
```
|
||||
|
||||
### generate program description
|
||||
|
||||
```
|
||||
python generate_program.py bow
|
||||
```
|
||||
|
||||
### run
|
||||
|
||||
```shell
|
||||
# After editing train.cfg
|
||||
sh run.sh
|
||||
```
|
||||
|
||||
## results
|
||||
|
||||
Below are the training logs of the BOW model; the loss goes down as expected.
|
||||
|
||||
```
|
||||
WARNING: Logging before InitGoogleLogging() is written to STDERR
|
||||
I0731 22:39:06.974232 10965 demo_trainer.cc:130] Start training...
|
||||
I0731 22:39:57.395229 10965 demo_trainer.cc:164] epoch: 0; average loss: 0.405706
|
||||
I0731 22:40:50.262344 10965 demo_trainer.cc:164] epoch: 1; average loss: 0.110746
|
||||
I0731 22:41:49.731079 10965 demo_trainer.cc:164] epoch: 2; average loss: 0.0475805
|
||||
I0731 22:43:31.398355 10965 demo_trainer.cc:164] epoch: 3; average loss: 0.0233249
|
||||
I0731 22:44:58.744391 10965 demo_trainer.cc:164] epoch: 4; average loss: 0.00701507
|
||||
I0731 22:46:30.451735 10965 demo_trainer.cc:164] epoch: 5; average loss: 0.00258187
|
||||
I0731 22:48:14.396687 10965 demo_trainer.cc:164] epoch: 6; average loss: 0.00113157
|
||||
I0731 22:49:56.242744 10965 demo_trainer.cc:164] epoch: 7; average loss: 0.000698234
|
||||
I0731 22:51:11.585919 10965 demo_trainer.cc:164] epoch: 8; average loss: 0.000510136
|
||||
I0731 22:52:50.573947 10965 demo_trainer.cc:164] epoch: 9; average loss: 0.000400932
|
||||
I0731 22:54:02.686152 10965 demo_trainer.cc:164] epoch: 10; average loss: 0.000329259
|
||||
I0731 22:54:55.233342 10965 demo_trainer.cc:164] epoch: 11; average loss: 0.000278644
|
||||
I0731 22:56:15.496256 10965 demo_trainer.cc:164] epoch: 12; average loss: 0.000241055
|
||||
I0731 22:57:45.015926 10965 demo_trainer.cc:164] epoch: 13; average loss: 0.000212085
|
||||
I0731 22:59:18.419997 10965 demo_trainer.cc:164] epoch: 14; average loss: 0.000189109
|
||||
I0731 23:00:15.409077 10965 demo_trainer.cc:164] epoch: 15; average loss: 0.000170465
|
||||
I0731 23:01:38.795770 10965 demo_trainer.cc:164] epoch: 16; average loss: 0.000155051
|
||||
I0731 23:02:57.289487 10965 demo_trainer.cc:164] epoch: 17; average loss: 0.000142106
|
||||
I0731 23:03:48.032507 10965 demo_trainer.cc:164] epoch: 18; average loss: 0.000131089
|
||||
I0731 23:04:51.195230 10965 demo_trainer.cc:164] epoch: 19; average loss: 0.000121605
|
||||
I0731 23:06:27.008040 10965 demo_trainer.cc:164] epoch: 20; average loss: 0.00011336
|
||||
I0731 23:07:56.568284 10965 demo_trainer.cc:164] epoch: 21; average loss: 0.000106129
|
||||
I0731 23:09:23.948290 10965 demo_trainer.cc:164] epoch: 22; average loss: 9.97393e-05
|
||||
I0731 23:10:56.062590 10965 demo_trainer.cc:164] epoch: 23; average loss: 9.40532e-05
|
||||
I0731 23:12:23.014047 10965 demo_trainer.cc:164] epoch: 24; average loss: 8.89622e-05
|
||||
I0731 23:13:21.439818 10965 demo_trainer.cc:164] epoch: 25; average loss: 8.43784e-05
|
||||
I0731 23:14:56.171597 10965 demo_trainer.cc:164] epoch: 26; average loss: 8.02322e-05
|
||||
I0731 23:16:01.513542 10965 demo_trainer.cc:164] epoch: 27; average loss: 7.64629e-05
|
||||
I0731 23:17:18.709139 10965 demo_trainer.cc:164] epoch: 28; average loss: 7.30239e-05
|
||||
I0731 23:18:41.421555 10965 demo_trainer.cc:164] epoch: 29; average loss: 6.98716e-05
|
||||
```
|
||||
|
||||
I trained a BOW model and a CNN model on the IMDB dataset using the C++ trainer. For comparison, I also trained the same models using the traditional Python training method.
|
||||
Results show that the two methods achieve almost the same dev accuracy:
|
||||
|
||||
CNN:
|
||||
|
||||
<img src="https://user-images.githubusercontent.com/23031310/62356234-32217300-b543-11e9-89fd-a07614904a08.png" width="300">
|
||||
|
||||
BOW:
|
||||
|
||||
<img src="https://user-images.githubusercontent.com/23031310/62356253-39488100-b543-11e9-9fa2-a399fc1119d6.png" width="300">
|
||||
|
||||
I also recorded the training speed of the C++ trainer and of the Python training method. The C++ trainer is quicker on the CNN model:
|
||||
|
||||
<img src="https://user-images.githubusercontent.com/23031310/62356444-af4ce800-b543-11e9-88c8-f3bde1321ea1.png" width="300">
|
||||
|
||||
TODO(mapingshuo): find out why the C++ trainer is quicker than the Python method on the CNN model.
|
@ -0,0 +1,183 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <time.h>

#include <fstream>
#include <iostream>
#include <memory>
#include <numeric>
#include <string>
#include <vector>

#include "gflags/gflags.h"
#include "include/save_model.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/dataset_factory.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
|
||||
|
||||
// Command-line flags of the demo trainer. They are read in main() and can
// also be supplied together through a gflags --flagfile.

// Input data.
DEFINE_string(filelist, "train_filelist.txt", "filelist for fluid dataset");
DEFINE_string(data_proto_desc, "data.proto", "data feed protobuf description");

// Serialized program descriptions.
DEFINE_string(startup_program_file,
              "startup_program",
              "startup program description");
DEFINE_string(main_program_file, "", "main program description");

// Training control and output.
DEFINE_string(loss_name,
              "mean_0.tmp_0",
              "loss tensor name in the main program");
DEFINE_string(save_dir, "cnn_model", "directory to save trained models");
DEFINE_int32(epoch_num, 30, "number of epochs to run when training");
|
||||
|
||||
namespace paddle {
|
||||
namespace train {
|
||||
|
||||
// Reads the entire file `filename` into `*contents` as raw bytes.
//
// Any previous value of `*contents` is discarded. An empty file yields an
// empty string. Fails through PADDLE_ENFORCE when the file cannot be opened
// or when its size cannot be determined.
void ReadBinaryFile(const std::string& filename, std::string* contents) {
  std::ifstream fin(filename, std::ios::in | std::ios::binary);
  PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s", filename);

  // Seek to the end to learn the size; tellg() reports -1 on failure, which
  // must not be passed on to resize().
  fin.seekg(0, std::ios::end);
  const std::streamoff file_size = fin.tellg();
  PADDLE_ENFORCE(file_size >= 0, "Cannot get the size of file %s", filename);

  contents->clear();
  contents->resize(static_cast<size_t>(file_size));
  fin.seekg(0, std::ios::beg);

  // Guard the element access: an empty string has no first element to read
  // into.
  if (!contents->empty()) {
    fin.read(&(*contents)[0], static_cast<std::streamsize>(contents->size()));
  }
  fin.close();
}
|
||||
|
||||
std::unique_ptr<paddle::framework::ProgramDesc> LoadProgramDesc(
|
||||
const std::string& model_filename) {
|
||||
VLOG(3) << "loading model from " << model_filename;
|
||||
std::string program_desc_str;
|
||||
ReadBinaryFile(model_filename, &program_desc_str);
|
||||
std::unique_ptr<paddle::framework::ProgramDesc> main_program(
|
||||
new paddle::framework::ProgramDesc(program_desc_str));
|
||||
return main_program;
|
||||
}
|
||||
|
||||
// Returns true when `var` is marked persistable and its type is none of
// FEED_MINIBATCH, FETCH_LIST or RAW.
bool IsPersistable(const paddle::framework::VarDesc* var) {
  if (!var->Persistable()) {
    return false;
  }
  const auto var_type = var->GetType();
  const bool is_excluded_type =
      var_type == paddle::framework::proto::VarType::FEED_MINIBATCH ||
      var_type == paddle::framework::proto::VarType::FETCH_LIST ||
      var_type == paddle::framework::proto::VarType::RAW;
  return !is_excluded_type;
}
|
||||
|
||||
} // namespace train
|
||||
} // namespace paddle
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
|
||||
std::cerr << "filelist: " << FLAGS_filelist << std::endl;
|
||||
std::cerr << "data_proto_desc: " << FLAGS_data_proto_desc << std::endl;
|
||||
std::cerr << "startup_program_file: " << FLAGS_startup_program_file
|
||||
<< std::endl;
|
||||
std::cerr << "main_program_file: " << FLAGS_main_program_file << std::endl;
|
||||
std::cerr << "loss_name: " << FLAGS_loss_name << std::endl;
|
||||
std::cerr << "save_dir: " << FLAGS_save_dir << std::endl;
|
||||
std::cerr << "epoch_num: " << FLAGS_epoch_num << std::endl;
|
||||
|
||||
std::string filelist = std::string(FLAGS_filelist);
|
||||
std::vector<std::string> file_vec;
|
||||
std::ifstream fin(filelist);
|
||||
if (fin) {
|
||||
std::string filename;
|
||||
while (fin >> filename) {
|
||||
file_vec.push_back(filename);
|
||||
}
|
||||
}
|
||||
PADDLE_ENFORCE_GE(file_vec.size(), 1, "At least one file to train");
|
||||
paddle::framework::InitDevices(false);
|
||||
const auto cpu_place = paddle::platform::CPUPlace();
|
||||
paddle::framework::Executor executor(cpu_place);
|
||||
paddle::framework::Scope scope;
|
||||
auto startup_program =
|
||||
paddle::train::LoadProgramDesc(std::string(FLAGS_startup_program_file));
|
||||
auto main_program =
|
||||
paddle::train::LoadProgramDesc(std::string(FLAGS_main_program_file));
|
||||
|
||||
executor.Run(*startup_program, &scope, 0);
|
||||
|
||||
std::string data_feed_desc_str;
|
||||
paddle::train::ReadBinaryFile(std::string(FLAGS_data_proto_desc),
|
||||
&data_feed_desc_str);
|
||||
VLOG(3) << "load data feed desc done.";
|
||||
std::unique_ptr<paddle::framework::Dataset> dataset_ptr;
|
||||
dataset_ptr =
|
||||
paddle::framework::DatasetFactory::CreateDataset("MultiSlotDataset");
|
||||
VLOG(3) << "initialize dataset ptr done";
|
||||
|
||||
// find all params
|
||||
std::vector<std::string> param_names;
|
||||
const paddle::framework::BlockDesc& global_block = main_program->Block(0);
|
||||
for (auto* var : global_block.AllVars()) {
|
||||
if (paddle::train::IsPersistable(var)) {
|
||||
VLOG(3) << "persistable variable's name: " << var->Name();
|
||||
param_names.push_back(var->Name());
|
||||
}
|
||||
}
|
||||
|
||||
int epoch_num = FLAGS_epoch_num;
|
||||
std::string loss_name = FLAGS_loss_name;
|
||||
auto loss_var = scope.Var(loss_name);
|
||||
|
||||
LOG(INFO) << "Start training...";
|
||||
|
||||
for (int epoch = 0; epoch < epoch_num; ++epoch) {
|
||||
VLOG(3) << "Epoch:" << epoch;
|
||||
// get reader
|
||||
dataset_ptr->SetFileList(file_vec);
|
||||
VLOG(3) << "set file list done";
|
||||
dataset_ptr->SetThreadNum(1);
|
||||
VLOG(3) << "set thread num done";
|
||||
dataset_ptr->SetDataFeedDesc(data_feed_desc_str);
|
||||
VLOG(3) << "set data feed desc done";
|
||||
dataset_ptr->CreateReaders();
|
||||
const std::vector<paddle::framework::DataFeed*> readers =
|
||||
dataset_ptr->GetReaders();
|
||||
PADDLE_ENFORCE_EQ(readers.size(), 1,
|
||||
"readers num should be equal to thread num");
|
||||
const std::vector<std::string>& input_feed_names =
|
||||
readers[0]->GetUseSlotAlias();
|
||||
for (auto name : input_feed_names) {
|
||||
readers[0]->AddFeedVar(scope.Var(name), name);
|
||||
}
|
||||
VLOG(3) << "get reader done";
|
||||
readers[0]->Start();
|
||||
VLOG(3) << "start a reader";
|
||||
VLOG(3) << "readers size: " << readers.size();
|
||||
|
||||
int step = 0;
|
||||
std::vector<float> loss_vec;
|
||||
|
||||
while (readers[0]->Next() > 0) {
|
||||
executor.Run(*main_program, &scope, 0, false, true);
|
||||
loss_vec.push_back(
|
||||
loss_var->Get<paddle::framework::LoDTensor>().data<float>()[0]);
|
||||
}
|
||||
float average_loss =
|
||||
accumulate(loss_vec.begin(), loss_vec.end(), 0.0) / loss_vec.size();
|
||||
|
||||
LOG(INFO) << "epoch: " << epoch << "; average loss: " << average_loss;
|
||||
dataset_ptr->DestroyReaders();
|
||||
|
||||
// save model
|
||||
std::string save_dir_root = FLAGS_save_dir;
|
||||
std::string save_dir =
|
||||
save_dir_root + "/epoch" + std::to_string(epoch) + ".model";
|
||||
paddle::framework::save_model(main_program, &scope, param_names, save_dir,
|
||||
false);
|
||||
}
|
||||
}
|
@ -0,0 +1,72 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
import paddle
|
||||
import logging
|
||||
import paddle.fluid as fluid
|
||||
|
||||
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger("fluid")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def load_vocab(filename):
    """
    Build a word -> id dictionary from a vocabulary file.

    Each line of ``filename`` is stripped and mapped to its zero-based line
    number. Finally the key "<unk>" is set to the number of entries
    collected so far.
    """
    word_to_id = {}
    with open(filename) as vocab_file:
        for line_no, raw_line in enumerate(vocab_file):
            word_to_id[raw_line.strip()] = line_no
    word_to_id["<unk>"] = len(word_to_id)
    return word_to_id
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Usage: python generate_program.py <cnn|bow>
    #
    # Writes three files into the current directory:
    #   data.proto                    - the dataset description
    #   <model>_main_program          - serialized main program
    #   <model>_startup_program       - serialized startup program
    from nets import bow_net, cnn_net

    # Maps a model name to (label used in the log message, network builder).
    net_builders = {
        'cnn': ("CNN", cnn_net),
        'bow': ("BOW", bow_net),
    }

    # Validate the command line before anything is written, and signal
    # failure with a non-zero exit status.
    if len(sys.argv) < 2:
        logger.error("usage: python generate_program.py <cnn|bow>")
        sys.exit(1)
    model_name = sys.argv[1]
    if model_name not in net_builders:
        logger.error("no such model: " + model_name)
        sys.exit(1)

    vocab = load_vocab('imdb.vocab')
    dict_dim = len(vocab)

    data = fluid.layers.data(
        name="words", shape=[1], dtype="int64", lod_level=1)
    label = fluid.layers.data(name="label", shape=[1], dtype="int64")

    dataset = fluid.DatasetFactory().create_dataset()
    dataset.set_batch_size(128)
    dataset.set_pipe_command("python imdb_reader.py")
    dataset.set_use_var([data, label])

    with open("data.proto", "w") as f:
        f.write(dataset.desc())

    net_label, build_net = net_builders[model_name]
    logger.info("Generate program description of " + net_label + " net")
    avg_cost, acc, prediction = build_net(data, label, dict_dim)

    # optimizer = fluid.optimizer.SGD(learning_rate=0.01)
    optimizer = fluid.optimizer.Adagrad(learning_rate=0.01)
    optimizer.minimize(avg_cost)

    with open(model_name + "_main_program", "wb") as f:
        f.write(fluid.default_main_program().desc.serialize_to_string())

    with open(model_name + "_startup_program", "wb") as f:
        f.write(fluid.default_startup_program().desc.serialize_to_string())
|
@ -0,0 +1,75 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import os
|
||||
import paddle
|
||||
import re
|
||||
import paddle.fluid.incubate.data_generator as dg
|
||||
|
||||
|
||||
class IMDBDataset(dg.MultiSlotDataGenerator):
    """
    Data generator that turns raw lines of the form ``<text>|<label>`` into
    a "words" slot of vocabulary ids and a "label" slot.
    """

    def load_resource(self, dictfile):
        """
        Load the vocabulary and prepare the tokenizer.

        Each stripped line of ``dictfile`` is mapped to its zero-based line
        number. Words missing from the vocabulary are later mapped to
        ``self._unk_id``, which is the vocabulary size.
        """
        self._vocab = {}
        with open(dictfile) as f:
            for wid, line in enumerate(f):
                self._vocab[line.strip()] = wid
        self._unk_id = len(self._vocab)
        # The capturing group makes re.split() return the separators as
        # tokens too.
        self._pattern = re.compile(r'(;|,|\.|\?|!|\s|\(|\))')
        self.return_value = ("words", [1, 2, 3, 4, 5, 6]), ("label", [0])

    def get_words_and_label(self, line):
        """
        Split ``line`` into word ids and a label.

        Everything before the last '|' is the text: it is lower-cased,
        "<br />" is replaced by a space and it is tokenized with
        ``self._pattern`` (empty tokens and single spaces are dropped).
        The field after the last '|' is parsed as the integer label.

        Returns:
            (feas, label): list of word ids and a one-element label list.
        """
        fields = line.split('|')
        text = '|'.join(fields[:-1]).lower().replace("<br />", " ").strip()
        label = [int(fields[-1])]

        words = [x for x in self._pattern.split(text) if x and x != " "]
        feas = [self._vocab.get(x, self._unk_id) for x in words]
        return feas, label

    def infer_reader(self, infer_filelist, batch, buf_size):
        """
        Return a batched, shuffled reader over the lines of the files in
        ``infer_filelist``, built with paddle.reader.shuffle (buffer size
        ``buf_size``) and paddle.batch (batch size ``batch``).
        """

        def local_iter():
            for fname in infer_filelist:
                with open(fname, "r") as fin:
                    for line in fin:
                        feas, label = self.get_words_and_label(line)
                        yield feas, label

        batch_iter = paddle.batch(
            paddle.reader.shuffle(
                local_iter, buf_size=buf_size),
            batch_size=batch)
        return batch_iter

    def generate_sample(self, line):
        """
        Return a generator function that yields one sample for ``line`` as
        (("words", ids), ("label", [label])).
        """

        def data_iter():
            feas, label = self.get_words_and_label(line)
            yield ("words", feas), ("label", label)

        return data_iter
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Load the vocabulary from "imdb.vocab", then let the generator process
    # its standard input.
    generator = IMDBDataset()
    generator.load_resource("imdb.vocab")
    generator.run_from_stdin()
|
@ -0,0 +1,41 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <fcntl.h>
|
||||
#include <google/protobuf/io/zero_copy_stream_impl.h>
|
||||
#include <google/protobuf/message.h>
|
||||
#include <google/protobuf/text_format.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include <unistd.h>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "gflags/gflags.h"
|
||||
#include "paddle/fluid/framework/feed_fetch_method.h"
|
||||
#include "paddle/fluid/framework/feed_fetch_type.h"
|
||||
#include "paddle/fluid/framework/lod_rank_table.h"
|
||||
#include "paddle/fluid/framework/lod_tensor_array.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/framework/prune.h"
|
||||
#include "paddle/fluid/framework/reader.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
void save_model(const std::unique_ptr<ProgramDesc>& main_program, Scope* scope,
|
||||
const std::vector<std::string>& param_names,
|
||||
const std::string& model_name, bool save_combine);
|
||||
}
|
||||
}
|
@ -0,0 +1,140 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
|
||||
|
||||
def bow_net(data,
            label,
            dict_dim,
            emb_dim=128,
            hid_dim=128,
            hid_dim2=96,
            class_dim=2):
    """
    Bag-of-words classifier.

    Layers, in order: embedding -> sum sequence pool -> tanh ->
    fc(hid_dim, tanh) -> fc(hid_dim2, tanh) -> fc(class_dim, softmax).

    Returns:
        (avg_cost, acc, prediction): mean cross-entropy cost, accuracy and
        the softmax output.
    """
    embedded = fluid.layers.embedding(
        input=data, size=[dict_dim, emb_dim], is_sparse=True)
    pooled = fluid.layers.sequence_pool(input=embedded, pool_type='sum')
    pooled_tanh = fluid.layers.tanh(pooled)

    hidden_1 = fluid.layers.fc(input=pooled_tanh, size=hid_dim, act="tanh")
    hidden_2 = fluid.layers.fc(input=hidden_1, size=hid_dim2, act="tanh")
    prediction = fluid.layers.fc(
        input=[hidden_2], size=class_dim, act="softmax")

    per_sample_cost = fluid.layers.cross_entropy(
        input=prediction, label=label)
    avg_cost = fluid.layers.mean(x=per_sample_cost)
    acc = fluid.layers.accuracy(input=prediction, label=label)

    return avg_cost, acc, prediction
|
||||
|
||||
|
||||
def cnn_net(data,
            label,
            dict_dim,
            emb_dim=128,
            hid_dim=128,
            hid_dim2=96,
            class_dim=2,
            win_size=3):
    """
    Convolutional classifier.

    Layers, in order: embedding -> sequence_conv_pool (hid_dim filters of
    size win_size, tanh, max pooling) -> fc(hid_dim2) ->
    fc(class_dim, softmax).

    Returns:
        (avg_cost, acc, prediction): mean cross-entropy cost, accuracy and
        the softmax output.
    """
    embedded = fluid.layers.embedding(
        input=data, size=[dict_dim, emb_dim], is_sparse=True)
    conv_pooled = fluid.nets.sequence_conv_pool(
        input=embedded,
        num_filters=hid_dim,
        filter_size=win_size,
        act="tanh",
        pool_type="max")

    hidden = fluid.layers.fc(input=[conv_pooled], size=hid_dim2)
    prediction = fluid.layers.fc(
        input=[hidden], size=class_dim, act="softmax")

    per_sample_cost = fluid.layers.cross_entropy(
        input=prediction, label=label)
    avg_cost = fluid.layers.mean(x=per_sample_cost)
    acc = fluid.layers.accuracy(input=prediction, label=label)

    return avg_cost, acc, prediction
|
||||
|
||||
|
||||
def lstm_net(data,
             label,
             dict_dim,
             emb_dim=128,
             hid_dim=128,
             hid_dim2=96,
             class_dim=2,
             emb_lr=30.0):
    """
    LSTM classifier.

    Layers, in order: embedding (its parameter learning rate set to emb_lr)
    -> fc(hid_dim * 4) -> dynamic_lstm -> max sequence pool -> tanh ->
    fc(hid_dim2, tanh) -> fc(class_dim, softmax).

    Returns:
        (avg_cost, acc, prediction): mean cross-entropy cost, accuracy and
        the softmax output.
    """
    embedded = fluid.layers.embedding(
        input=data,
        size=[dict_dim, emb_dim],
        param_attr=fluid.ParamAttr(learning_rate=emb_lr),
        is_sparse=True)

    lstm_input = fluid.layers.fc(input=embedded, size=hid_dim * 4)
    # Only the first of the two outputs of dynamic_lstm is used below.
    lstm_hidden, _unused = fluid.layers.dynamic_lstm(
        input=lstm_input, size=hid_dim * 4, is_reverse=False)

    pooled = fluid.layers.sequence_pool(input=lstm_hidden, pool_type='max')
    pooled_tanh = fluid.layers.tanh(pooled)

    hidden = fluid.layers.fc(input=pooled_tanh, size=hid_dim2, act='tanh')
    prediction = fluid.layers.fc(input=hidden, size=class_dim, act='softmax')

    per_sample_cost = fluid.layers.cross_entropy(
        input=prediction, label=label)
    avg_cost = fluid.layers.mean(x=per_sample_cost)
    acc = fluid.layers.accuracy(input=prediction, label=label)

    return avg_cost, acc, prediction
|
||||
|
||||
|
||||
def gru_net(data,
            label,
            dict_dim,
            emb_dim=128,
            hid_dim=128,
            hid_dim2=96,
            class_dim=2,
            emb_lr=400.0):
    """
    GRU classifier.

    Layers, in order: embedding (its parameter learning rate set to emb_lr)
    -> fc(hid_dim * 3) -> dynamic_gru -> max sequence pool -> tanh ->
    fc(hid_dim2, tanh) -> fc(class_dim, softmax).

    Returns:
        (avg_cost, acc, prediction): mean cross-entropy cost, accuracy and
        the softmax output.
    """
    embedded = fluid.layers.embedding(
        input=data,
        size=[dict_dim, emb_dim],
        param_attr=fluid.ParamAttr(learning_rate=emb_lr))

    gru_input = fluid.layers.fc(input=embedded, size=hid_dim * 3)
    gru_hidden = fluid.layers.dynamic_gru(
        input=gru_input, size=hid_dim, is_reverse=False)

    pooled = fluid.layers.sequence_pool(input=gru_hidden, pool_type='max')
    pooled_tanh = fluid.layers.tanh(pooled)

    hidden = fluid.layers.fc(input=pooled_tanh, size=hid_dim2, act='tanh')
    prediction = fluid.layers.fc(input=hidden, size=class_dim, act='softmax')

    per_sample_cost = fluid.layers.cross_entropy(
        input=prediction, label=label)
    avg_cost = fluid.layers.mean(x=per_sample_cost)
    acc = fluid.layers.accuracy(input=prediction, label=label)

    return avg_cost, acc, prediction
|
@ -0,0 +1,3 @@
|
||||
|
||||
# Abort on the first failing command (-e), trace commands (-x) and treat
# unset variables as errors (-u).
set -e -x -u

# Run the trainer with the options listed in train.cfg.
build/demo_trainer --flagfile=train.cfg
|
@ -0,0 +1,77 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "include/save_model.h"
|
||||
#include <fcntl.h>
|
||||
#include <google/protobuf/io/zero_copy_stream_impl.h>
|
||||
#include <google/protobuf/message.h>
|
||||
#include <google/protobuf/text_format.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include <unistd.h>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include "gflags/gflags.h"
|
||||
#include "paddle/fluid/framework/feed_fetch_method.h"
|
||||
#include "paddle/fluid/framework/feed_fetch_type.h"
|
||||
#include "paddle/fluid/framework/lod_rank_table.h"
|
||||
#include "paddle/fluid/framework/lod_tensor_array.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/framework/program_desc.h"
|
||||
#include "paddle/fluid/framework/prune.h"
|
||||
#include "paddle/fluid/framework/reader.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
|
||||
using std::unique_ptr;
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
void save_model(const unique_ptr<ProgramDesc>& main_program, Scope* scope,
|
||||
const std::vector<std::string>& param_names,
|
||||
const std::string& model_name, bool save_combine) {
|
||||
auto place = platform::CPUPlace();
|
||||
const BlockDesc& global_block = main_program->Block(0);
|
||||
std::vector<std::string> paralist;
|
||||
for (auto* var : global_block.AllVars()) {
|
||||
bool is_model_param = false;
|
||||
for (auto param_name : param_names) {
|
||||
if (var->Name() == param_name) {
|
||||
is_model_param = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_model_param) continue;
|
||||
|
||||
if (!save_combine) {
|
||||
VLOG(3) << "model var name: %s" << var->Name().c_str();
|
||||
|
||||
paddle::framework::AttributeMap attrs;
|
||||
attrs.insert({"file_path", model_name + "/" + var->Name()});
|
||||
auto save_op = paddle::framework::OpRegistry::CreateOp(
|
||||
"save", {{"X", {var->Name()}}}, {}, attrs);
|
||||
|
||||
save_op->Run(*scope, place);
|
||||
} else {
|
||||
paralist.push_back(var->Name());
|
||||
}
|
||||
}
|
||||
if (save_combine) {
|
||||
std::sort(paralist.begin(), paralist.end());
|
||||
paddle::framework::AttributeMap attrs;
|
||||
attrs.insert({"file_path", model_name});
|
||||
auto save_op = paddle::framework::OpRegistry::CreateOp(
|
||||
"save_combine", {{"X", paralist}}, {}, attrs);
|
||||
save_op->Run(*scope, place);
|
||||
}
|
||||
}
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,7 @@
|
||||
--filelist=train_filelist.txt
|
||||
--data_proto_desc=data.proto
|
||||
--loss_name=mean_0.tmp_0
|
||||
--startup_program_file=bow_startup_program
|
||||
--main_program_file=bow_main_program
|
||||
--save_dir=bow_model
|
||||
--epoch_num=30
|
@ -0,0 +1,12 @@
|
||||
train_data/part-0
|
||||
train_data/part-1
|
||||
train_data/part-10
|
||||
train_data/part-11
|
||||
train_data/part-2
|
||||
train_data/part-3
|
||||
train_data/part-4
|
||||
train_data/part-5
|
||||
train_data/part-6
|
||||
train_data/part-7
|
||||
train_data/part-8
|
||||
train_data/part-9
|
Loading…
Reference in new issue