You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
110 lines
3.2 KiB
110 lines
3.2 KiB
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
|
|
|
|
#include <stddef.h>
|
|
#include <stdint.h>
|
|
#include <stdio.h>
|
|
#include <fstream>
|
|
#include <iostream>
|
|
#include <string>
|
|
#include <typeinfo>
|
|
#include <vector>
|
|
#include "paddle/fluid/inference/capi/c_api.h"
|
|
#include "paddle/fluid/inference/tests/api/tester_helper.h"
|
|
|
|
namespace paddle {
|
|
namespace inference {
|
|
namespace analysis {
|
|
|
|
template <typename T>
|
|
void zero_copy_run() {
|
|
std::string model_dir = FLAGS_infer_model;
|
|
PD_AnalysisConfig *config = PD_NewAnalysisConfig();
|
|
PD_DisableGpu(config);
|
|
PD_SetCpuMathLibraryNumThreads(config, 10);
|
|
PD_SwitchUseFeedFetchOps(config, false);
|
|
PD_SwitchSpecifyInputNames(config, true);
|
|
PD_SwitchIrDebug(config, true);
|
|
PD_SetModel(config, model_dir.c_str()); //, params_file1.c_str());
|
|
bool use_feed_fetch = PD_UseFeedFetchOpsEnabled(config);
|
|
CHECK(!use_feed_fetch) << "NO";
|
|
bool specify_input_names = PD_SpecifyInputName(config);
|
|
CHECK(specify_input_names) << "NO";
|
|
|
|
const int batch_size = 1;
|
|
const int channels = 3;
|
|
const int height = 224;
|
|
const int width = 224;
|
|
T input[batch_size * channels * height * width] = {0};
|
|
int shape[4] = {batch_size, channels, height, width};
|
|
int shape_size = 4;
|
|
int in_size = 2;
|
|
int *out_size;
|
|
PD_ZeroCopyData *inputs = new PD_ZeroCopyData[2];
|
|
PD_ZeroCopyData *outputs = new PD_ZeroCopyData;
|
|
inputs[0].data = static_cast<void *>(input);
|
|
std::string nm = typeid(T).name();
|
|
if ("f" == nm) {
|
|
inputs[0].dtype = PD_FLOAT32;
|
|
} else if ("i" == nm) {
|
|
inputs[0].dtype = PD_INT32;
|
|
} else if ("x" == nm) {
|
|
inputs[0].dtype = PD_INT64;
|
|
} else if ("h" == nm) {
|
|
inputs[0].dtype = PD_UINT8;
|
|
} else {
|
|
CHECK(false) << "Unsupport dtype. ";
|
|
}
|
|
inputs[0].name = new char[6];
|
|
inputs[0].name[0] = 'i';
|
|
inputs[0].name[1] = 'm';
|
|
inputs[0].name[2] = 'a';
|
|
inputs[0].name[3] = 'g';
|
|
inputs[0].name[4] = 'e';
|
|
inputs[0].name[5] = '\0';
|
|
inputs[0].shape = shape;
|
|
inputs[0].shape_size = shape_size;
|
|
|
|
int *label = new int[1];
|
|
label[0] = 0;
|
|
inputs[1].data = static_cast<void *>(label);
|
|
inputs[1].dtype = PD_INT64;
|
|
inputs[1].name = new char[6];
|
|
inputs[1].name[0] = 'l';
|
|
inputs[1].name[1] = 'a';
|
|
inputs[1].name[2] = 'b';
|
|
inputs[1].name[3] = 'e';
|
|
inputs[1].name[4] = 'l';
|
|
inputs[1].name[5] = '\0';
|
|
int label_shape[2] = {1, 1};
|
|
int label_shape_size = 2;
|
|
inputs[1].shape = label_shape;
|
|
inputs[1].shape_size = label_shape_size;
|
|
|
|
PD_PredictorZeroCopyRun(config, inputs, in_size, &outputs, &out_size);
|
|
|
|
LOG(INFO) << outputs[0].name;
|
|
LOG(INFO) << outputs[0].shape_size;
|
|
}
|
|
|
|
TEST(PD_ZeroCopyRun, zero_copy_run) {
|
|
// zero_copy_run<int32_t>();
|
|
// zero_copy_run<int64_t>();
|
|
zero_copy_run<float>();
|
|
}
|
|
|
|
} // namespace analysis
|
|
} // namespace inference
|
|
} // namespace paddle
|