Add capi for fluid inference api (#20092)

* add C API for the fluid inference API, including AnalysisConfig, AnalysisPredictor, PaddleBuf, PaddleTensor, and ZeroCopyTensor
liu zhengxi 6 years ago committed by GitHub
parent a16e91bb89
commit 301eeb5bea
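To make the new surface easier to skim before the per-file diffs, here is a minimal usage sketch assembled from the testers added in this PR; it is not part of the diff. It builds an AnalysisConfig, fills a PaddleTensor through a PaddleBuf, and runs PD_PredictorRun. The model paths, the input name "image", and the 1x3x300x300 shape are placeholders, and error handling is omitted.

#include <vector>
#include "paddle/fluid/inference/capi/c_api.h"

int main() {
  PD_AnalysisConfig* config = PD_NewAnalysisConfig();
  PD_DisableGpu(config);
  PD_SetModel(config, "./mobilenet/__model__", "./mobilenet/__params__");

  // Caller-owned input data, wrapped by a PaddleBuf without copying.
  std::vector<float> data(1 * 3 * 300 * 300, 0.f);
  int shape[4] = {1, 3, 300, 300};
  PD_PaddleBuf* buf = PD_NewPaddleBuf();
  PD_PaddleBufReset(buf, data.data(), data.size() * sizeof(float));

  PD_Tensor* input = PD_NewPaddleTensor();
  char name[] = "image";
  PD_SetPaddleTensorName(input, name);
  PD_SetPaddleTensorDType(input, PD_FLOAT32);
  PD_SetPaddleTensorShape(input, shape, 4);
  PD_SetPaddleTensorData(input, buf);

  PD_Tensor* output = PD_NewPaddleTensor();
  int* out_size = nullptr;
  PD_PredictorRun(config, input, /*in_size=*/1, output, &out_size,
                  /*batch_size=*/1);

  PD_DeletePaddleTensor(output);
  PD_DeletePaddleTensor(input);
  PD_DeletePaddleBuf(buf);
  PD_DeleteAnalysisConfig(config);
  return 0;
}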

@@ -36,6 +36,7 @@ else(WIN32)
endif(WIN32)
add_subdirectory(api)
add_subdirectory(capi)
if(WITH_MKLDNN)
set(mkldnn_quantizer_src ${CMAKE_CURRENT_SOURCE_DIR}/api/mkldnn_quantizer.cc)

@@ -0,0 +1,10 @@
cc_library(pd_config SRCS pd_config.cc)
cc_library(pd_predictor SRCS pd_predictor.cc)
cc_library(pd_tensor SRCS pd_tensor.cc)
cc_library(pd_c_api SRCS c_api.cc)
cc_library(paddle_fluid_c SRCS c_api.cc DEPS paddle_fluid pd_config pd_predictor pd_tensor pd_c_api)
# TODO: build paddle_fluid_c as a shared library (dll)
# cc_library(paddle_fluid_c_shared SHARED SRCS c_api.cc DEPS paddle_fluid pd_config pd_predictor pd_tensor pd_c_api)
# set_target_properties(paddle_fluid_c_shared PROPERTIES OUTPUT_NAME paddle_fluid_c)

@@ -0,0 +1,97 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/capi/c_api.h"
#include <algorithm>
#include <vector>
#include "paddle/fluid/inference/capi/c_api_internal.h"
using paddle::ConvertToPaddleDType;
using paddle::ConvertToPDDataType;
using paddle::ConvertToACPrecision;
extern "C" {
PD_PaddleBuf* PD_NewPaddleBuf() { return new PD_PaddleBuf; }
void PD_DeletePaddleBuf(PD_PaddleBuf* buf) {
if (buf) {
delete buf;
buf = nullptr;
}
}
void PD_PaddleBufResize(PD_PaddleBuf* buf, size_t length) {
buf->buf.Resize(length);
}
void PD_PaddleBufReset(PD_PaddleBuf* buf, void* data, size_t length) {
buf->buf.Reset(data, length);
}
bool PD_PaddleBufEmpty(PD_PaddleBuf* buf) { return buf->buf.empty(); }
void* PD_PaddleBufData(PD_PaddleBuf* buf) { return buf->buf.data(); }
size_t PD_PaddleBufLength(PD_PaddleBuf* buf) { return buf->buf.length(); }
} // extern "C"
namespace paddle {
paddle::PaddleDType ConvertToPaddleDType(PD_DataType dtype) {
switch (dtype) {
case PD_FLOAT32:
return PD_PaddleDType::FLOAT32;
case PD_INT32:
return PD_PaddleDType::INT32;
case PD_INT64:
return PD_PaddleDType::INT64;
case PD_UINT8:
return PD_PaddleDType::UINT8;
default:
CHECK(false) << "Unsupport dtype.";
return PD_PaddleDType::FLOAT32;
}
}
PD_DataType ConvertToPDDataType(PD_PaddleDType dtype) {
switch (dtype) {
case PD_PaddleDType::FLOAT32:
return PD_DataType::PD_FLOAT32;
case PD_PaddleDType::INT32:
return PD_DataType::PD_INT32;
case PD_PaddleDType::INT64:
return PD_DataType::PD_INT64;
case PD_PaddleDType::UINT8:
return PD_DataType::PD_UINT8;
default:
CHECK(false) << "Unsupport dtype.";
return PD_DataType::PD_UNKDTYPE;
}
}
PD_ACPrecision ConvertToACPrecision(Precision dtype) {
switch (dtype) {
case Precision::kFloat32:
return PD_ACPrecision::kFloat32;
case Precision::kInt8:
return PD_ACPrecision::kInt8;
case Precision::kHalf:
return PD_ACPrecision::kHalf;
default:
CHECK(false) << "Unsupport precision.";
return PD_ACPrecision::kFloat32;
}
}
} // namespace paddle
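A short sketch of how the PaddleBuf wrappers above are meant to be used (not part of the diff). It assumes the usual paddle::PaddleBuf semantics: Resize allocates memory owned by the buffer, while Reset wraps caller-owned memory without taking ownership.

#include <cstring>
#include "paddle/fluid/inference/capi/c_api.h"

void paddle_buf_demo() {
  PD_PaddleBuf* buf = PD_NewPaddleBuf();
  bool was_empty = PD_PaddleBufEmpty(buf);  // true for a fresh buffer

  // Let the buffer own 16 bytes, then zero them out.
  PD_PaddleBufResize(buf, 16);
  std::memset(PD_PaddleBufData(buf), 0, PD_PaddleBufLength(buf));

  // Or point the buffer at memory the caller already owns.
  static float external[4] = {0.f, 1.f, 2.f, 3.f};
  PD_PaddleBufReset(buf, external, sizeof(external));

  PD_DeletePaddleBuf(buf);
  (void)was_empty;
}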

File diff suppressed because it is too large.

@@ -0,0 +1,43 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/fluid/inference/api/paddle_api.h"
#include "paddle/fluid/platform/enforce.h"
using PD_PaddleDType = paddle::PaddleDType;
using PD_ACPrecision = paddle::AnalysisConfig::Precision;
struct PD_AnalysisConfig {
paddle::AnalysisConfig config;
};
struct PD_Tensor {
paddle::PaddleTensor tensor;
};
struct PD_PaddleBuf {
paddle::PaddleBuf buf;
};
namespace paddle {
paddle::PaddleDType ConvertToPaddleDType(PD_DataType dtype);
PD_DataType ConvertToPDDataType(PD_PaddleDType dtype);
PD_ACPrecision ConvertToACPrecision(Precision dtype);
} // namespace paddle
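The wrapper structs above make each C entry point a thin shim: dereference the handle, operate on the wrapped C++ object, and convert enums with the declared helpers. As an illustration only, a hypothetical extra getter (PD_PaddleTensorByteSize is not part of this PR) would follow the same pattern:

#include <cstddef>
#include "paddle/fluid/inference/capi/c_api.h"
#include "paddle/fluid/inference/capi/c_api_internal.h"

// Hypothetical helper: PD_Tensor simply wraps a paddle::PaddleTensor,
// so the shim forwards to the underlying object's PaddleBuf length.
extern "C" size_t PD_PaddleTensorByteSize(const PD_Tensor* tensor) {
  return tensor->tensor.data.length();
}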

File diff suppressed because it is too large.

@@ -0,0 +1,136 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <map>
#include <numeric>
#include <vector>
#include "paddle/fluid/inference/capi/c_api.h"
#include "paddle/fluid/inference/capi/c_api_internal.h"
using paddle::ConvertToPaddleDType;
using paddle::ConvertToPDDataType;
using paddle::ConvertToACPrecision;
extern "C" {
bool PD_PredictorRun(const PD_AnalysisConfig* config, PD_Tensor* inputs,
int in_size, PD_Tensor* output_data, int** out_size,
int batch_size) {
auto predictor = paddle::CreatePaddlePredictor(config->config);
std::vector<paddle::PaddleTensor> in;
for (int i = 0; i < in_size; ++i) {
in.emplace_back(inputs->tensor);
}
std::vector<paddle::PaddleTensor> out;
if (predictor->Run(in, &out, batch_size)) {
int osize = out.size();
for (int i = 0; i < osize; ++i) {
output_data[i].tensor = out[i];
}
*out_size = &osize;
return true;
}
return false;
}
bool PD_PredictorZeroCopyRun(const PD_AnalysisConfig* config,
PD_ZeroCopyData* inputs, int in_size,
PD_ZeroCopyData* output, int** out_size) {
auto predictor = paddle::CreatePaddlePredictor(config->config);
auto input_names = predictor->GetInputNames();
PADDLE_ENFORCE_EQ(
input_names.size(), in_size,
"The number of input and the number of model's input must match. ");
for (int i = 0; i < in_size; ++i) {
auto input_t = predictor->GetInputTensor(inputs[i].name);
std::vector<int> tensor_shape;
tensor_shape.assign(inputs[i].shape,
inputs[i].shape + inputs[i].shape_size);
input_t->Reshape(tensor_shape);
switch (inputs[i].dtype) {
case PD_FLOAT32:
input_t->copy_from_cpu(static_cast<float*>(inputs[i].data));
break;
case PD_INT32:
input_t->copy_from_cpu(static_cast<int32_t*>(inputs[i].data));
break;
case PD_INT64:
input_t->copy_from_cpu(static_cast<int64_t*>(inputs[i].data));
break;
case PD_UINT8:
input_t->copy_from_cpu(static_cast<uint8_t*>(inputs[i].data));
break;
default:
CHECK(false) << "Unsupport data type.";
break;
}
}
CHECK(predictor->ZeroCopyRun());
auto output_names = predictor->GetOutputNames();
int osize = output_names.size();
*out_size = &osize;
output = new PD_ZeroCopyData[osize];
for (int i = 0; i < osize; ++i) {
LOG(INFO) << 1;
output[i].name = new char[output_names[i].length() + 1];
snprintf(output[i].name, output_names[i].length() + 1, "%s",
output_names[i].c_str());
auto output_t = predictor->GetOutputTensor(output_names[i]);
output[i].dtype = ConvertToPDDataType(output_t->type());
std::vector<int> output_shape = output_t->shape();
output[i].shape = new int[output_shape.size()];
output[i].shape = output_shape.data();
output[i].shape_size = output_shape.size();
switch (output[i].dtype) {
case PD_FLOAT32: {
std::vector<float> out_data;
int out_num = std::accumulate(output_shape.begin(), output_shape.end(),
1, std::multiplies<int>());
out_data.resize(out_num);
output_t->copy_to_cpu(out_data.data());
output[i].data = static_cast<void*>(out_data.data());
} break;
case PD_INT32: {
std::vector<int32_t> out_data;
int out_num = std::accumulate(output_shape.begin(), output_shape.end(),
1, std::multiplies<int>());
out_data.resize(out_num);
output_t->copy_to_cpu(out_data.data());
output[i].data = static_cast<void*>(out_data.data());
} break;
case PD_INT64: {
std::vector<int64_t> out_data;
int out_num = std::accumulate(output_shape.begin(), output_shape.end(),
1, std::multiplies<int>());
out_data.resize(out_num);
output_t->copy_to_cpu(out_data.data());
output[i].data = static_cast<void*>(out_data.data());
} break;
case PD_UINT8: {
std::vector<uint8_t> out_data;
int out_num = std::accumulate(output_shape.begin(), output_shape.end(),
1, std::multiplies<int>());
out_data.resize(out_num);
output_t->copy_to_cpu(out_data.data());
output[i].data = static_cast<void*>(out_data.data());
} break;
default:
CHECK(false) << "Unsupport data type.";
break;
}
}
return true;
}
} // extern "C"

@@ -0,0 +1,74 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <algorithm>
#include <vector>
#include "paddle/fluid/inference/capi/c_api.h"
#include "paddle/fluid/inference/capi/c_api_internal.h"
using paddle::ConvertToPaddleDType;
using paddle::ConvertToPDDataType;
using paddle::ConvertToACPrecision;
extern "C" {
// PaddleTensor
PD_Tensor* PD_NewPaddleTensor() { return new PD_Tensor; }
void PD_DeletePaddleTensor(PD_Tensor* tensor) {
if (tensor) {
delete tensor;
tensor = nullptr;
}
}
void PD_SetPaddleTensorName(PD_Tensor* tensor, char* name) {
tensor->tensor.name = std::string(name);
}
void PD_SetPaddleTensorDType(PD_Tensor* tensor, PD_DataType dtype) {
tensor->tensor.dtype = paddle::ConvertToPaddleDType(dtype);
}
void PD_SetPaddleTensorData(PD_Tensor* tensor, PD_PaddleBuf* buf) {
tensor->tensor.data = buf->buf;
}
void PD_SetPaddleTensorShape(PD_Tensor* tensor, int* shape, int size) {
tensor->tensor.shape.assign(shape, shape + size);
}
const char* PD_GetPaddleTensorName(const PD_Tensor* tensor) {
return tensor->tensor.name.c_str();
}
PD_DataType PD_GetPaddleTensorDType(const PD_Tensor* tensor) {
return ConvertToPDDataType(tensor->tensor.dtype);
}
PD_PaddleBuf* PD_GetPaddleTensorData(const PD_Tensor* tensor) {
PD_PaddleBuf* ret = PD_NewPaddleBuf();
ret->buf = tensor->tensor.data;
return ret;
}
int* PD_GetPaddleTensorShape(const PD_Tensor* tensor, int** size) {
std::vector<int> shape = tensor->tensor.shape;
int s = shape.size();
*size = &s;
return shape.data();
}
} // extern "C"

@@ -283,4 +283,29 @@ if(WITH_GPU AND TENSORRT_FOUND)
inference_analysis_test(trt_cascade_rcnn_test SRCS trt_cascade_rcnn_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
inference_analysis_test(test_analyzer_capi_gpu SRCS analyzer_capi_gpu_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} paddle_fluid_c
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
endif()
set(CAPI_MODEL_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/capi_tests_models")
if (NOT EXISTS ${CAPI_MODEL_INSTALL_DIR})
inference_download_and_uncompress(${CAPI_MODEL_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "trt_inference_test_models.tar.gz")
endif()
inference_analysis_test(test_analyzer_capi SRCS analyzer_capi_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} paddle_fluid_c
ARGS --infer_model=${CAPI_MODEL_INSTALL_DIR}/trt_inference_test_models)
set(CAPI_MODEL_INSTALL_PD_DIR "${INFERENCE_DEMO_INSTALL_DIR}/capi_mobilenet")
if (NOT EXISTS ${CAPI_MODEL_INSTALL_PD_DIR})
inference_download_and_uncompress(${CAPI_MODEL_INSTALL_PD_DIR} "http://paddlemodels.bj.bcebos.com/" "inference-vis-demos%2Fmobilenet.tar.gz")
endif()
inference_analysis_test(test_analyzer_capi_pd_tensor SRCS analyzer_capi_pd_tensor_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} paddle_fluid_c
ARGS --infer_model=${CAPI_MODEL_INSTALL_PD_DIR}/model)
if(WITH_MKLDNN)
inference_analysis_test(test_analyzer_capi_int SRCS analyzer_capi_int_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} paddle_fluid_c
ARGS --infer_model=${INT8_DATA_DIR}/resnet50/model)
endif()

@@ -0,0 +1,100 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/inference/capi/c_api.h"
#include "paddle/fluid/inference/tests/api/tester_helper.h"
namespace paddle {
namespace inference {
namespace analysis {
TEST(PD_AnalysisConfig, use_gpu) {
std::string model_dir = FLAGS_infer_model + "/mobilenet";
PD_AnalysisConfig *config = PD_NewAnalysisConfig();
PD_DisableGpu(config);
PD_SetCpuMathLibraryNumThreads(config, 10);
int num_thread = PD_CpuMathLibraryNumThreads(config);
CHECK(10 == num_thread) << "NO";
PD_SwitchUseFeedFetchOps(config, false);
PD_SwitchSpecifyInputNames(config, true);
PD_SwitchIrDebug(config, true);
PD_SetModel(config, model_dir.c_str());
PD_SetOptimCacheDir(config, (FLAGS_infer_model + "/OptimCacheDir").c_str());
const char *model_dir_ = PD_ModelDir(config);
LOG(INFO) << model_dir_;
PD_EnableUseGpu(config, 100, 0);
bool use_gpu = PD_UseGpu(config);
CHECK(use_gpu) << "NO";
int device = PD_GpuDeviceId(config);
CHECK(0 == device) << "NO";
int init_size = PD_MemoryPoolInitSizeMb(config);
CHECK(100 == init_size) << "NO";
float frac = PD_FractionOfGpuMemoryForPool(config);
LOG(INFO) << frac;
PD_EnableCUDNN(config);
bool cudnn = PD_CudnnEnabled(config);
CHECK(cudnn) << "NO";
PD_SwitchIrOptim(config, true);
bool ir_optim = PD_IrOptim(config);
CHECK(ir_optim) << "NO";
PD_EnableTensorRtEngine(config);
bool trt_enable = PD_TensorrtEngineEnabled(config);
CHECK(trt_enable) << "NO";
PD_EnableNgraph(config);
bool ngraph_enable = PD_NgraphEnabled(config);
LOG(INFO) << ngraph_enable << " Ngraph";
PD_EnableMemoryOptim(config);
bool memory_optim_enable = PD_MemoryOptimEnabled(config);
CHECK(memory_optim_enable) << "NO";
PD_EnableProfile(config);
bool profiler_enable = PD_ProfileEnabled(config);
CHECK(profiler_enable) << "NO";
PD_SetInValid(config);
bool is_valid = PD_IsValid(config);
CHECK(!is_valid) << "NO";
PD_DeleteAnalysisConfig(config);
}
TEST(PD_AnalysisConfig, trt_int8) {
std::string model_dir = FLAGS_infer_model + "/mobilenet";
PD_AnalysisConfig *config = PD_NewAnalysisConfig();
PD_EnableUseGpu(config, 100, 0);
PD_EnableTensorRtEngine(config, 1 << 20, 1, 3, Precision::kInt8, false, true);
bool trt_enable = PD_TensorrtEngineEnabled(config);
CHECK(trt_enable) << "NO";
PD_DeleteAnalysisConfig(config);
}
TEST(PD_AnalysisConfig, trt_fp16) {
std::string model_dir = FLAGS_infer_model + "/mobilenet";
PD_AnalysisConfig *config = PD_NewAnalysisConfig();
PD_EnableUseGpu(config, 100, 0);
PD_EnableTensorRtEngine(config, 1 << 20, 1, 3, Precision::kHalf, false,
false);
bool trt_enable = PD_TensorrtEngineEnabled(config);
CHECK(trt_enable) << "NO";
PD_DeleteAnalysisConfig(config);
}
} // namespace analysis
} // namespace inference
} // namespace paddle

@@ -0,0 +1,106 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <fstream>
#include <iostream>
#include <string>
#include <typeinfo>
#include <vector>
#include "paddle/fluid/inference/capi/c_api.h"
#include "paddle/fluid/inference/tests/api/tester_helper.h"
namespace paddle {
namespace inference {
namespace analysis {
template <typename T>
void zero_copy_run() {
std::string model_dir = FLAGS_infer_model;
PD_AnalysisConfig *config = PD_NewAnalysisConfig();
PD_DisableGpu(config);
PD_SetCpuMathLibraryNumThreads(config, 10);
PD_SwitchUseFeedFetchOps(config, false);
PD_SwitchSpecifyInputNames(config, true);
PD_SwitchIrDebug(config, true);
PD_SetModel(config, model_dir.c_str()); //, params_file1.c_str());
bool use_feed_fetch = PD_UseFeedFetchOpsEnabled(config);
CHECK(!use_feed_fetch) << "NO";
bool specify_input_names = PD_SpecifyInputName(config);
CHECK(specify_input_names) << "NO";
const int batch_size = 1;
const int channels = 3;
const int height = 224;
const int width = 224;
T input[batch_size * channels * height * width] = {0};
int shape[4] = {batch_size, channels, height, width};
int shape_size = 4;
int in_size = 2;
int *out_size;
PD_ZeroCopyData *inputs = new PD_ZeroCopyData[2];
PD_ZeroCopyData *outputs = new PD_ZeroCopyData;
inputs[0].data = static_cast<void *>(input);
std::string nm = typeid(T).name();
if ("f" == nm) {
inputs[0].dtype = PD_FLOAT32;
} else if ("i" == nm) {
inputs[0].dtype = PD_INT32;
} else if ("x" == nm) {
inputs[0].dtype = PD_INT64;
} else if ("h" == nm) {
inputs[0].dtype = PD_UINT8;
} else {
CHECK(false) << "Unsupport dtype. ";
}
inputs[0].name = new char[6];
inputs[0].name[0] = 'i';
inputs[0].name[1] = 'm';
inputs[0].name[2] = 'a';
inputs[0].name[3] = 'g';
inputs[0].name[4] = 'e';
inputs[0].name[5] = '\0';
inputs[0].shape = shape;
inputs[0].shape_size = shape_size;
int *label = new int[1];
label[0] = 0;
inputs[1].data = static_cast<void *>(label);
inputs[1].dtype = PD_INT64;
inputs[1].name = new char[6];
inputs[1].name[0] = 'l';
inputs[1].name[1] = 'a';
inputs[1].name[2] = 'b';
inputs[1].name[3] = 'e';
inputs[1].name[4] = 'l';
inputs[1].name[5] = '\0';
int label_shape[2] = {1, 1};
int label_shape_size = 2;
inputs[1].shape = label_shape;
inputs[1].shape_size = label_shape_size;
PD_PredictorZeroCopyRun(config, inputs, in_size, outputs, &out_size);
}
TEST(PD_ZeroCopyRun, zero_copy_run) {
// zero_copy_run<int32_t>();
// zero_copy_run<int64_t>();
zero_copy_run<float>();
}
} // namespace analysis
} // namespace inference
} // namespace paddle

@@ -0,0 +1,153 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include "paddle/fluid/inference/capi/c_api.h"
#include "paddle/fluid/inference/tests/api/tester_helper.h"
namespace paddle {
namespace inference {
namespace analysis {
void PD_run() {
PD_AnalysisConfig* config = PD_NewAnalysisConfig();
std::string prog_file = FLAGS_infer_model + "/__model__";
std::string params_file = FLAGS_infer_model + "/__params__";
PD_SetModel(config, prog_file.c_str(), params_file.c_str());
PD_SetProgFile(config, prog_file.c_str());
PD_SetParamsFile(config, params_file.c_str());
LOG(INFO) << PD_ProgFile(config);
LOG(INFO) << PD_ParamsFile(config);
PD_Tensor* input = PD_NewPaddleTensor();
PD_PaddleBuf* buf = PD_NewPaddleBuf();
LOG(INFO) << "PaddleBuf empty: " << PD_PaddleBufEmpty(buf);
int batch = 1;
int channel = 3;
int height = 300;
int width = 300;
int shape[4] = {batch, channel, height, width};
int shape_size = 4;
float* data = new float[batch * channel * height * width];
PD_PaddleBufReset(buf, static_cast<void*>(data),
sizeof(float) * (batch * channel * height * width));
char name[6] = {'i', 'm', 'a', 'g', 'e', '\0'};
PD_SetPaddleTensorName(input, name);
PD_SetPaddleTensorDType(input, PD_FLOAT32);
PD_SetPaddleTensorShape(input, shape, shape_size);
PD_SetPaddleTensorData(input, buf);
PD_Tensor* out_data = PD_NewPaddleTensor();
int* out_size;
PD_PredictorRun(config, input, 1, out_data, &out_size, 1);
LOG(INFO) << *out_size;
LOG(INFO) << PD_GetPaddleTensorName(out_data);
LOG(INFO) << PD_GetPaddleTensorDType(out_data);
PD_PaddleBuf* b = PD_GetPaddleTensorData(out_data);
LOG(INFO) << PD_PaddleBufLength(b);
float* result = static_cast<float*>(PD_PaddleBufData(b));
LOG(INFO) << *result;
PD_PaddleBufResize(b, 500);
PD_DeletePaddleTensor(input);
int* size;
PD_GetPaddleTensorShape(out_data, &size);
PD_DeletePaddleBuf(buf);
}
TEST(PD_Tensor, PD_run) { PD_run(); }
TEST(PD_Tensor, int32) {
PD_Tensor* input = PD_NewPaddleTensor();
PD_SetPaddleTensorDType(input, PD_INT32);
LOG(INFO) << PD_GetPaddleTensorDType(input);
}
TEST(PD_Tensor, int64) {
PD_Tensor* input = PD_NewPaddleTensor();
PD_SetPaddleTensorDType(input, PD_INT64);
LOG(INFO) << PD_GetPaddleTensorDType(input);
}
TEST(PD_Tensor, int8) {
PD_Tensor* input = PD_NewPaddleTensor();
PD_SetPaddleTensorDType(input, PD_UINT8);
LOG(INFO) << PD_GetPaddleTensorDType(input);
}
std::string read_file(std::string filename) {
std::ifstream file(filename);
return std::string((std::istreambuf_iterator<char>(file)),
std::istreambuf_iterator<char>());
}
void buffer_run() {
PD_AnalysisConfig* config = PD_NewAnalysisConfig();
std::string prog_file = FLAGS_infer_model + "/__model__";
std::string params_file = FLAGS_infer_model + "/__params__";
std::string prog_str = read_file(prog_file);
std::string params_str = read_file(params_file);
PD_SetModelBuffer(config, prog_str.c_str(), prog_str.size(),
params_str.c_str(), params_str.size());
LOG(INFO) << PD_ProgFile(config);
LOG(INFO) << PD_ParamsFile(config);
CHECK(PD_ModelFromMemory(config)) << "NO";
PD_Tensor* input = PD_NewPaddleTensor();
PD_PaddleBuf* buf = PD_NewPaddleBuf();
LOG(INFO) << "PaddleBuf empty: " << PD_PaddleBufEmpty(buf);
int batch = 1;
int channel = 3;
int height = 300;
int width = 300;
int shape[4] = {batch, channel, height, width};
int shape_size = 4;
float* data = new float[batch * channel * height * width];
PD_PaddleBufReset(buf, static_cast<void*>(data),
sizeof(float) * (batch * channel * height * width));
char name[6] = {'i', 'm', 'a', 'g', 'e', '\0'};
PD_SetPaddleTensorName(input, name);
PD_SetPaddleTensorDType(input, PD_FLOAT32);
PD_SetPaddleTensorShape(input, shape, shape_size);
PD_SetPaddleTensorData(input, buf);
PD_Tensor* out_data = PD_NewPaddleTensor();
int* out_size;
PD_PredictorRun(config, input, 1, out_data, &out_size, 1);
LOG(INFO) << *out_size;
LOG(INFO) << PD_GetPaddleTensorName(out_data);
LOG(INFO) << PD_GetPaddleTensorDType(out_data);
PD_PaddleBuf* b = PD_GetPaddleTensorData(out_data);
LOG(INFO) << PD_PaddleBufLength(b);
float* result = static_cast<float*>(PD_PaddleBufData(b));
LOG(INFO) << *result;
PD_PaddleBufResize(b, 500);
PD_DeletePaddleTensor(input);
PD_DeletePaddleBuf(buf);
}
TEST(SetModelBuffer, read) { buffer_run(); }
} // namespace analysis
} // namespace inference
} // namespace paddle

@@ -0,0 +1,108 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <fstream>
#include <iostream>
#include <string>
#include <typeinfo>
#include <vector>
#include "paddle/fluid/inference/capi/c_api.h"
#include "paddle/fluid/inference/tests/api/tester_helper.h"
namespace paddle {
namespace inference {
namespace analysis {
template <typename T>
void zero_copy_run() {
std::string model_dir = FLAGS_infer_model + "/mobilenet";
PD_AnalysisConfig *config = PD_NewAnalysisConfig();
PD_DisableGpu(config);
PD_SetCpuMathLibraryNumThreads(config, 10);
PD_SwitchUseFeedFetchOps(config, false);
PD_SwitchSpecifyInputNames(config, true);
PD_SwitchIrDebug(config, true);
PD_SetModel(config, model_dir.c_str()); //, params_file1.c_str());
bool use_feed_fetch = PD_UseFeedFetchOpsEnabled(config);
CHECK(!use_feed_fetch) << "NO";
bool specify_input_names = PD_SpecifyInputName(config);
CHECK(specify_input_names) << "NO";
const int batch_size = 1;
const int channels = 3;
const int height = 224;
const int width = 224;
T input[batch_size * channels * height * width] = {0};
int shape[4] = {batch_size, channels, height, width};
int shape_size = 4;
int in_size = 1;
int *out_size;
PD_ZeroCopyData *inputs = new PD_ZeroCopyData;
PD_ZeroCopyData *outputs = new PD_ZeroCopyData;
inputs->data = static_cast<void *>(input);
std::string nm = typeid(T).name();
if ("f" == nm) {
inputs->dtype = PD_FLOAT32;
} else if ("i" == nm) {
inputs->dtype = PD_INT32;
} else if ("x" == nm) {
inputs->dtype = PD_INT64;
} else if ("h" == nm) {
inputs->dtype = PD_UINT8;
} else {
CHECK(false) << "Unsupport dtype. ";
}
inputs->name = new char[2];
inputs->name[0] = 'x';
inputs->name[1] = '\0';
LOG(INFO) << inputs->name;
inputs->shape = shape;
inputs->shape_size = shape_size;
PD_PredictorZeroCopyRun(config, inputs, in_size, outputs, &out_size);
}
TEST(PD_ZeroCopyRun, zero_copy_run) { zero_copy_run<float>(); }
#ifdef PADDLE_WITH_MKLDNN
TEST(PD_AnalysisConfig, profile_mkldnn) {
std::string model_dir = FLAGS_infer_model + "/mobilenet";
PD_AnalysisConfig *config = PD_NewAnalysisConfig();
PD_DisableGpu(config);
PD_SetCpuMathLibraryNumThreads(config, 10);
PD_SwitchUseFeedFetchOps(config, false);
PD_SwitchSpecifyInputNames(config, true);
PD_SwitchIrDebug(config, true);
PD_EnableMKLDNN(config);
bool mkldnn_enable = PD_MkldnnEnabled(config);
CHECK(mkldnn_enable) << "NO";
PD_EnableMkldnnQuantizer(config);
bool quantizer_enable = PD_MkldnnQuantizerEnabled(config);
CHECK(quantizer_enable) << "NO";
PD_SetMkldnnCacheCapacity(config, 0);
PD_SetModel(config, model_dir.c_str());
PD_EnableAnakinEngine(config);
bool anakin_enable = PD_AnakinEngineEnabled(config);
LOG(INFO) << anakin_enable;
PD_DeleteAnalysisConfig(config);
}
#endif
} // namespace analysis
} // namespace inference
} // namespace paddle