|
|
|
@ -14,16 +14,37 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include <time.h>
|
|
|
|
|
#include <iostream>
|
|
|
|
|
#include "gflags/gflags.h"
|
|
|
|
|
#include "paddle/inference/inference.h"
|
|
|
|
|
|
|
|
|
|
int main(int argc, char* argv[]) {
|
|
|
|
|
std::string dirname =
|
|
|
|
|
"/home/work/liuyiqun/PaddlePaddle/Paddle/paddle/inference/"
|
|
|
|
|
"recognize_digits_mlp.inference.model";
|
|
|
|
|
std::vector<std::string> feed_var_names = {"x"};
|
|
|
|
|
std::vector<std::string> fetch_var_names = {"fc_2.tmp_2"};
|
|
|
|
|
paddle::InferenceEngine* desc = new paddle::InferenceEngine();
|
|
|
|
|
desc->LoadInferenceModel(dirname, feed_var_names, fetch_var_names);
|
|
|
|
|
DEFINE_string(dirname, "", "Directory of the inference model.");
|
|
|
|
|
DEFINE_string(feed_var_names, "", "Names of feeding variables");
|
|
|
|
|
DEFINE_string(fetch_var_names, "", "Names of fetching variables");
|
|
|
|
|
|
|
|
|
|
int main(int argc, char** argv) {
|
|
|
|
|
google::ParseCommandLineFlags(&argc, &argv, true);
|
|
|
|
|
if (FLAGS_dirname.empty() || FLAGS_feed_var_names.empty() ||
|
|
|
|
|
FLAGS_fetch_var_names.empty()) {
|
|
|
|
|
// Example:
|
|
|
|
|
// ./example --dirname=recognize_digits_mlp.inference.model
|
|
|
|
|
// --feed_var_names="x"
|
|
|
|
|
// --fetch_var_names="fc_2.tmp_2"
|
|
|
|
|
std::cout << "Usage: ./example --dirname=path/to/your/model "
|
|
|
|
|
"--feed_var_names=x --fetch_var_names=y"
|
|
|
|
|
<< std::endl;
|
|
|
|
|
exit(1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::cout << "FLAGS_dirname: " << FLAGS_dirname << std::endl;
|
|
|
|
|
std::cout << "FLAGS_feed_var_names: " << FLAGS_feed_var_names << std::endl;
|
|
|
|
|
std::cout << "FLAGS_fetch_var_names: " << FLAGS_fetch_var_names << std::endl;
|
|
|
|
|
|
|
|
|
|
std::string dirname = FLAGS_dirname;
|
|
|
|
|
std::vector<std::string> feed_var_names = {FLAGS_feed_var_names};
|
|
|
|
|
std::vector<std::string> fetch_var_names = {FLAGS_fetch_var_names};
|
|
|
|
|
|
|
|
|
|
paddle::InferenceEngine* engine = new paddle::InferenceEngine();
|
|
|
|
|
engine->LoadInferenceModel(dirname, feed_var_names, fetch_var_names);
|
|
|
|
|
|
|
|
|
|
paddle::framework::LoDTensor input;
|
|
|
|
|
srand(time(0));
|
|
|
|
@ -36,7 +57,7 @@ int main(int argc, char* argv[]) {
|
|
|
|
|
std::vector<paddle::framework::LoDTensor> feeds;
|
|
|
|
|
feeds.push_back(input);
|
|
|
|
|
std::vector<paddle::framework::LoDTensor> fetchs;
|
|
|
|
|
desc->Execute(feeds, fetchs);
|
|
|
|
|
engine->Execute(feeds, fetchs);
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < fetchs.size(); ++i) {
|
|
|
|
|
auto dims_i = fetchs[i].dims();
|
|
|
|
@ -52,5 +73,7 @@ int main(int argc, char* argv[]) {
|
|
|
|
|
}
|
|
|
|
|
std::cout << std::endl;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
delete engine;
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|