|
|
|
@ -70,7 +70,7 @@ void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
|
|
|
|
|
const framework::ProgramDesc& main_program,
|
|
|
|
|
const std::string& dirname,
|
|
|
|
|
const std::string& param_filename,
|
|
|
|
|
bool is_memory_load = false) {
|
|
|
|
|
bool model_from_memory = false) {
|
|
|
|
|
const framework::BlockDesc& global_block = main_program.Block(0);
|
|
|
|
|
|
|
|
|
|
framework::ProgramDesc* load_program = new framework::ProgramDesc();
|
|
|
|
@ -109,7 +109,7 @@ void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
|
|
|
|
|
op->SetType("load_combine");
|
|
|
|
|
op->SetOutput("Out", paramlist);
|
|
|
|
|
op->SetAttr("file_path", {param_filename});
|
|
|
|
|
op->SetAttr("is_memory_load", {is_memory_load});
|
|
|
|
|
op->SetAttr("model_from_memory", {model_from_memory});
|
|
|
|
|
op->CheckAttrs();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -132,23 +132,17 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
|
|
|
|
|
"model version %ld is not supported.",
|
|
|
|
|
main_program->Version());
|
|
|
|
|
|
|
|
|
|
// is_memory_load is false in seperate parameters.
|
|
|
|
|
// model_from_memory is false in seperate parameters.
|
|
|
|
|
LoadPersistables(executor, scope, *main_program, dirname, "",
|
|
|
|
|
false /* is_memory_load */);
|
|
|
|
|
false /* model_from_memory */);
|
|
|
|
|
return main_program;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
|
|
|
|
|
framework::Scope* scope,
|
|
|
|
|
const std::string& prog_filename,
|
|
|
|
|
const std::string& param_filename,
|
|
|
|
|
bool is_memory_load = false) {
|
|
|
|
|
std::unique_ptr<framework::ProgramDesc> Load(
|
|
|
|
|
framework::Executor* executor, framework::Scope* scope,
|
|
|
|
|
const std::string& prog_filename, const std::string& param_filename) {
|
|
|
|
|
std::string program_desc_str;
|
|
|
|
|
if (!is_memory_load) {
|
|
|
|
|
ReadBinaryFile(prog_filename, &program_desc_str);
|
|
|
|
|
} else {
|
|
|
|
|
program_desc_str = prog_filename;
|
|
|
|
|
}
|
|
|
|
|
ReadBinaryFile(prog_filename, &program_desc_str);
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<framework::ProgramDesc> main_program(
|
|
|
|
|
new framework::ProgramDesc(program_desc_str));
|
|
|
|
@ -157,15 +151,22 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
|
|
|
|
|
main_program->Version());
|
|
|
|
|
|
|
|
|
|
LoadPersistables(executor, scope, *main_program, "", param_filename,
|
|
|
|
|
is_memory_load);
|
|
|
|
|
false /* model_from_memory */);
|
|
|
|
|
return main_program;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<framework::ProgramDesc> Load(
|
|
|
|
|
std::unique_ptr<framework::ProgramDesc> LoadFromMemory(
|
|
|
|
|
framework::Executor* executor, framework::Scope* scope,
|
|
|
|
|
const std::string& prog_filename, const std::string& param_filename) {
|
|
|
|
|
return Load(executor, scope, prog_filename, param_filename,
|
|
|
|
|
false /* is_memory_load */);
|
|
|
|
|
const std::string& prog_buffer, const std::string& param_buffer) {
|
|
|
|
|
std::unique_ptr<framework::ProgramDesc> main_program(
|
|
|
|
|
new framework::ProgramDesc(prog_buffer));
|
|
|
|
|
PADDLE_ENFORCE(framework::IsProgramVersionSupported(main_program->Version()),
|
|
|
|
|
"model version %ld is not supported.",
|
|
|
|
|
main_program->Version());
|
|
|
|
|
|
|
|
|
|
LoadPersistables(executor, scope, *main_program, "", param_buffer,
|
|
|
|
|
true /* model_filename */);
|
|
|
|
|
return main_program;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SaveVars(const framework::Scope& scope,
|
|
|
|
|