|
|
|
@ -93,15 +93,15 @@ void CheckError(const paddle::framework::LoDTensor& output1,
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<paddle::framework::ProgramDesc> InitProgram(
|
|
|
|
|
paddle::framework::Executor* executor, paddle::framework::Scope* scope,
|
|
|
|
|
const std::string& dirname, const bool is_combined = false) {
|
|
|
|
|
const std::string& dirname, const bool is_combined = false,
|
|
|
|
|
const std::string& prog_filename = "__model_combined__",
|
|
|
|
|
const std::string& param_filename = "__params_combined__") {
|
|
|
|
|
std::unique_ptr<paddle::framework::ProgramDesc> inference_program;
|
|
|
|
|
if (is_combined) {
|
|
|
|
|
// All parameters are saved in a single file.
|
|
|
|
|
// Hard-coding the file names of program and parameters in unittest.
|
|
|
|
|
// The file names should be consistent with that used in Python API
|
|
|
|
|
// `fluid.io.save_inference_model`.
|
|
|
|
|
std::string prog_filename = "model";
|
|
|
|
|
std::string param_filename = "params";
|
|
|
|
|
inference_program =
|
|
|
|
|
paddle::inference::Load(executor, scope, dirname + "/" + prog_filename,
|
|
|
|
|
dirname + "/" + param_filename);
|
|
|
|
@ -114,12 +114,15 @@ std::unique_ptr<paddle::framework::ProgramDesc> InitProgram(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::vector<int64_t>> GetFeedTargetShapes(
|
|
|
|
|
const std::string& dirname, const bool is_combined = false) {
|
|
|
|
|
const std::string& dirname, const bool is_combined = false,
|
|
|
|
|
const std::string& prog_filename = "__model_combined__",
|
|
|
|
|
const std::string& param_filename = "__params_combined__") {
|
|
|
|
|
auto place = paddle::platform::CPUPlace();
|
|
|
|
|
auto executor = paddle::framework::Executor(place);
|
|
|
|
|
auto* scope = new paddle::framework::Scope();
|
|
|
|
|
|
|
|
|
|
auto inference_program = InitProgram(&executor, scope, dirname, is_combined);
|
|
|
|
|
auto inference_program = InitProgram(&executor, scope, dirname, is_combined,
|
|
|
|
|
prog_filename, param_filename);
|
|
|
|
|
auto& global_block = inference_program->Block(0);
|
|
|
|
|
|
|
|
|
|
const std::vector<std::string>& feed_target_names =
|
|
|
|
|