|
|
|
@ -69,7 +69,8 @@ bool IsPersistable(const framework::VarDesc* var) {
|
|
|
|
|
void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
|
|
|
|
|
const framework::ProgramDesc& main_program,
|
|
|
|
|
const std::string& dirname,
|
|
|
|
|
const std::string& param_filename) {
|
|
|
|
|
const std::string& param_filename,
|
|
|
|
|
bool model_from_memory = false) {
|
|
|
|
|
const framework::BlockDesc& global_block = main_program.Block(0);
|
|
|
|
|
|
|
|
|
|
framework::ProgramDesc* load_program = new framework::ProgramDesc();
|
|
|
|
@ -108,6 +109,7 @@ void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
|
|
|
|
|
op->SetType("load_combine");
|
|
|
|
|
op->SetOutput("Out", paramlist);
|
|
|
|
|
op->SetAttr("file_path", {param_filename});
|
|
|
|
|
op->SetAttr("model_from_memory", {model_from_memory});
|
|
|
|
|
op->CheckAttrs();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -130,16 +132,17 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
|
|
|
|
|
"model version %ld is not supported.",
|
|
|
|
|
main_program->Version());
|
|
|
|
|
|
|
|
|
|
LoadPersistables(executor, scope, *main_program, dirname, "");
|
|
|
|
|
// model_from_memory is false in seperate parameters.
|
|
|
|
|
LoadPersistables(executor, scope, *main_program, dirname, "",
|
|
|
|
|
false /* model_from_memory */);
|
|
|
|
|
return main_program;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<framework::ProgramDesc> Load(
|
|
|
|
|
framework::Executor* executor, framework::Scope* scope,
|
|
|
|
|
const std::string& prog_filename, const std::string& param_filename) {
|
|
|
|
|
std::string model_filename = prog_filename;
|
|
|
|
|
std::string program_desc_str;
|
|
|
|
|
ReadBinaryFile(model_filename, &program_desc_str);
|
|
|
|
|
ReadBinaryFile(prog_filename, &program_desc_str);
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<framework::ProgramDesc> main_program(
|
|
|
|
|
new framework::ProgramDesc(program_desc_str));
|
|
|
|
@ -147,7 +150,22 @@ std::unique_ptr<framework::ProgramDesc> Load(
|
|
|
|
|
"model version %ld is not supported.",
|
|
|
|
|
main_program->Version());
|
|
|
|
|
|
|
|
|
|
LoadPersistables(executor, scope, *main_program, "", param_filename);
|
|
|
|
|
LoadPersistables(executor, scope, *main_program, "", param_filename,
|
|
|
|
|
false /* model_from_memory */);
|
|
|
|
|
return main_program;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<framework::ProgramDesc> LoadFromMemory(
|
|
|
|
|
framework::Executor* executor, framework::Scope* scope,
|
|
|
|
|
const std::string& prog_buffer, const std::string& param_buffer) {
|
|
|
|
|
std::unique_ptr<framework::ProgramDesc> main_program(
|
|
|
|
|
new framework::ProgramDesc(prog_buffer));
|
|
|
|
|
PADDLE_ENFORCE(framework::IsProgramVersionSupported(main_program->Version()),
|
|
|
|
|
"model version %ld is not supported.",
|
|
|
|
|
main_program->Version());
|
|
|
|
|
|
|
|
|
|
LoadPersistables(executor, scope, *main_program, "", param_buffer,
|
|
|
|
|
true /* model_filename */);
|
|
|
|
|
return main_program;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|