|
|
|
@ -25,19 +25,37 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
|
|
|
|
|
void InferenceEngine::LoadInferenceModel(const std::string& dirname) {
|
|
|
|
|
std::string model_filename = dirname + "/__model__.dat";
|
|
|
|
|
LOG(INFO) << "loading model from " << model_filename;
|
|
|
|
|
std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
|
|
|
|
|
std::string program_desc_str;
|
|
|
|
|
inputfs.seekg(0, std::ios::end);
|
|
|
|
|
program_desc_str.resize(inputfs.tellg());
|
|
|
|
|
inputfs.seekg(0, std::ios::beg);
|
|
|
|
|
LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
|
|
|
|
|
inputfs.read(&program_desc_str[0], program_desc_str.size());
|
|
|
|
|
inputfs.close();
|
|
|
|
|
|
|
|
|
|
program_ = new framework::ProgramDesc(program_desc_str);
|
|
|
|
|
GenerateLoadProgram(dirname);
|
|
|
|
|
|
|
|
|
|
framework::BlockDesc* global_block = program_->MutableBlock(0);
|
|
|
|
|
feed_var_names_.clear();
|
|
|
|
|
fetch_var_names_.clear();
|
|
|
|
|
for (auto* op : global_block->AllOps()) {
|
|
|
|
|
if (op->Type() == "feed") {
|
|
|
|
|
feed_var_names_.insert(feed_var_names_.begin(), op->Output("Out")[0]);
|
|
|
|
|
} else if (op->Type() == "fetch") {
|
|
|
|
|
fetch_var_names_.push_back(op->Input("X")[0]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void InferenceEngine::LoadInferenceModel(
|
|
|
|
|
const std::string& dirname,
|
|
|
|
|
const std::vector<std::string>& feed_var_names,
|
|
|
|
|
const std::vector<std::string>& fetch_var_names) {
|
|
|
|
|
#ifdef PADDLE_USE_PTOOLS
|
|
|
|
|
std::string model_filename = dirname + "/__model__";
|
|
|
|
|
LOG(INFO) << "Using PicklingTools, loading model from " << model_filename;
|
|
|
|
|
Val v;
|
|
|
|
|
LoadValFromFile(model_filename.c_str(), v, SERIALIZE_P0);
|
|
|
|
|
std::string program_desc_str = v["program_desc_str"];
|
|
|
|
|
LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
|
|
|
|
|
// PicklingTools cannot parse the vector of strings correctly.
|
|
|
|
|
#else
|
|
|
|
|
std::string model_filename = dirname + "/__model__.dat";
|
|
|
|
|
LOG(INFO) << "loading model from " << model_filename;
|
|
|
|
|
std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
|
|
|
|
@ -48,7 +66,7 @@ void InferenceEngine::LoadInferenceModel(
|
|
|
|
|
LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
|
|
|
|
|
inputfs.read(&program_desc_str[0], program_desc_str.size());
|
|
|
|
|
inputfs.close();
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
program_ = new framework::ProgramDesc(program_desc_str);
|
|
|
|
|
GenerateLoadProgram(dirname);
|
|
|
|
|
|
|
|
|
@ -62,7 +80,7 @@ void InferenceEngine::LoadInferenceModel(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool InferenceEngine::IsParameter(const framework::VarDesc* var) {
|
|
|
|
|
if (var->Persistable()) {
|
|
|
|
|
if (var->Persistable() && var->Name() != "feed" && var->Name() != "fetch") {
|
|
|
|
|
// There are many unreachable variables in the program
|
|
|
|
|
for (size_t i = 0; i < program_->Size(); ++i) {
|
|
|
|
|
const framework::BlockDesc& block = program_->Block(i);
|
|
|
|
|