|
|
|
@ -30,15 +30,28 @@ void IrGraphBuildPass::RunImpl(Argument *argument) {
|
|
|
|
|
if (!argument->scope_valid()) {
|
|
|
|
|
argument->SetScope(new framework::Scope);
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE(argument->use_gpu_valid());
|
|
|
|
|
|
|
|
|
|
// The load program should run on the same device with the inference program,
|
|
|
|
|
// so that the parameters will on the same device, or they will keep copying
|
|
|
|
|
// between difference devices.
|
|
|
|
|
platform::Place place;
|
|
|
|
|
if (argument->use_gpu()) {
|
|
|
|
|
PADDLE_ENFORCE(argument->gpu_device_id_valid());
|
|
|
|
|
place = platform::CUDAPlace(argument->gpu_device_id());
|
|
|
|
|
} else {
|
|
|
|
|
place = platform::CPUPlace();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (argument->model_dir_valid()) {
|
|
|
|
|
auto program = LoadModel(argument->model_dir(), argument->scope_ptr());
|
|
|
|
|
auto program =
|
|
|
|
|
LoadModel(argument->model_dir(), argument->scope_ptr(), place);
|
|
|
|
|
argument->SetMainProgram(program.release());
|
|
|
|
|
} else if (argument->model_program_path_valid() &&
|
|
|
|
|
argument->model_params_path_valid()) {
|
|
|
|
|
auto program =
|
|
|
|
|
LoadModel(argument->model_program_path(), argument->model_params_path(),
|
|
|
|
|
argument->scope_ptr());
|
|
|
|
|
argument->scope_ptr(), place);
|
|
|
|
|
argument->SetMainProgram(program.release());
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
@ -52,16 +65,15 @@ void IrGraphBuildPass::RunImpl(Argument *argument) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel(
|
|
|
|
|
const std::string &path, framework::Scope *scope) {
|
|
|
|
|
platform::CPUPlace place;
|
|
|
|
|
const std::string &path, framework::Scope *scope,
|
|
|
|
|
const platform::Place &place) {
|
|
|
|
|
framework::Executor exe(place);
|
|
|
|
|
return Load(&exe, scope, path);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel(
|
|
|
|
|
const std::string &program_path, const std::string ¶ms_path,
|
|
|
|
|
framework::Scope *scope) {
|
|
|
|
|
platform::CPUPlace place;
|
|
|
|
|
framework::Scope *scope, const platform::Place &place) {
|
|
|
|
|
framework::Executor exe(place);
|
|
|
|
|
return Load(&exe, scope, program_path, params_path);
|
|
|
|
|
}
|
|
|
|
|