|
|
|
@ -13,9 +13,22 @@
|
|
|
|
|
// limitations under the License.
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/inference/api/api_anakin_engine.h"
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
#include <cuda.h>
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#include <mkl_service.h>
|
|
|
|
|
#include <omp.h>
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
#include "framework/core/net/net.h"
|
|
|
|
|
#include "framework/operators/ops.h"
|
|
|
|
|
#include "saber/funcs/timer.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
|
|
|
|
|
template <typename Target>
|
|
|
|
@ -23,16 +36,24 @@ PaddleInferenceAnakinPredictor<Target>::PaddleInferenceAnakinPredictor(
|
|
|
|
|
const AnakinConfig &config) {
|
|
|
|
|
CHECK(Init(config));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
PaddleInferenceAnakinPredictor<anakin::X86>::PaddleInferenceAnakinPredictor(
|
|
|
|
|
const AnakinConfig &config) {
|
|
|
|
|
omp_set_dynamic(0);
|
|
|
|
|
omp_set_num_threads(1);
|
|
|
|
|
mkl_set_num_threads(1);
|
|
|
|
|
CHECK(Init(config));
|
|
|
|
|
}
|
|
|
|
|
template <typename Target>
|
|
|
|
|
bool PaddleInferenceAnakinPredictor<Target>::Init(const AnakinConfig &config) {
|
|
|
|
|
if (!(graph_.load(config.model_file))) {
|
|
|
|
|
LOG(FATAL) << "fail to load graph from " << config.model_file;
|
|
|
|
|
VLOG(3) << "fail to load graph from " << config.model_file;
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto inputs = graph_.get_ins();
|
|
|
|
|
for (auto &input_str : inputs) {
|
|
|
|
|
graph_.ResetBatchSize(input_str, config.max_batch_size);
|
|
|
|
|
max_batch_size_ = config.max_batch_size;
|
|
|
|
|
}
|
|
|
|
|
// optimization for graph
|
|
|
|
|
if (!(graph_.Optimize())) {
|
|
|
|
@ -52,15 +73,15 @@ bool PaddleInferenceAnakinPredictor<Target>::Run(
|
|
|
|
|
std::vector<PaddleTensor> *output_data, int batch_size) {
|
|
|
|
|
for (const auto &input : inputs) {
|
|
|
|
|
if (input.dtype != PaddleDType::FLOAT32) {
|
|
|
|
|
LOG(ERROR) << "Only support float type inputs. " << input.name
|
|
|
|
|
<< "'s type is not float";
|
|
|
|
|
VLOG(3) << "Only support float type inputs. " << input.name
|
|
|
|
|
<< "'s type is not float";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto d_tensor_in_p = executor_p_->get_in(input.name);
|
|
|
|
|
auto net_shape = d_tensor_in_p->valid_shape();
|
|
|
|
|
auto net_shape = d_tensor_in_p->shape();
|
|
|
|
|
if (net_shape.size() != input.shape.size()) {
|
|
|
|
|
LOG(ERROR) << " input " << input.name
|
|
|
|
|
<< "'s shape size should be equal to that of net";
|
|
|
|
|
VLOG(3) << " input " << input.name
|
|
|
|
|
<< "'s shape size should be equal to that of net";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
int sum = 1;
|
|
|
|
@ -79,21 +100,45 @@ bool PaddleInferenceAnakinPredictor<Target>::Run(
|
|
|
|
|
}
|
|
|
|
|
d_tensor_in_p->reshape(tmp_shape);
|
|
|
|
|
|
|
|
|
|
if (input.lod.size() > 0) {
|
|
|
|
|
if (input.lod.size() > 1) {
|
|
|
|
|
VLOG(3) << " input lod first dim should <=1, but you set "
|
|
|
|
|
<< input.lod.size();
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
std::vector<int> offset(input.lod[0].begin(), input.lod[0].end());
|
|
|
|
|
d_tensor_in_p->set_seq_offset(offset);
|
|
|
|
|
VLOG(3) << "offset.size(): " << offset.size();
|
|
|
|
|
for (int i = 0; i < offset.size(); i++) {
|
|
|
|
|
VLOG(3) << offset[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
float *d_data_p = d_tensor_in_p->mutable_data();
|
|
|
|
|
if (cudaMemcpy(d_data_p, static_cast<float *>(input.data.data()),
|
|
|
|
|
d_tensor_in_p->valid_size() * sizeof(float),
|
|
|
|
|
cudaMemcpyHostToDevice) != 0) {
|
|
|
|
|
LOG(ERROR) << "copy data from CPU to GPU error";
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
if (std::is_same<anakin::NV, Target>::value) {
|
|
|
|
|
if (cudaMemcpy(d_data_p, static_cast<float *>(input.data.data()),
|
|
|
|
|
d_tensor_in_p->valid_size() * sizeof(float),
|
|
|
|
|
cudaMemcpyHostToDevice) != 0) {
|
|
|
|
|
VLOG(3) << "copy data from CPU to GPU error";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
if (std::is_same<anakin::X86, Target>::value) {
|
|
|
|
|
memcpy(d_data_p, static_cast<float *>(input.data.data()),
|
|
|
|
|
d_tensor_in_p->valid_size() * sizeof(float));
|
|
|
|
|
}
|
|
|
|
|
cudaStreamSynchronize(NULL);
|
|
|
|
|
}
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
cudaDeviceSynchronize();
|
|
|
|
|
executor_p_->prediction();
|
|
|
|
|
cudaDeviceSynchronize();
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
if (output_data->empty()) {
|
|
|
|
|
LOG(ERROR) << "At least one output should be set with tensors' names.";
|
|
|
|
|
VLOG(3) << "At least one output should be set with tensors' names.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
for (auto &output : *output_data) {
|
|
|
|
@ -102,14 +147,22 @@ bool PaddleInferenceAnakinPredictor<Target>::Run(
|
|
|
|
|
if (output.data.length() < tensor->valid_size() * sizeof(float)) {
|
|
|
|
|
output.data.Resize(tensor->valid_size() * sizeof(float));
|
|
|
|
|
}
|
|
|
|
|
// Copy data from GPU -> CPU
|
|
|
|
|
if (cudaMemcpy(output.data.data(), tensor->mutable_data(),
|
|
|
|
|
tensor->valid_size() * sizeof(float),
|
|
|
|
|
cudaMemcpyDeviceToHost) != 0) {
|
|
|
|
|
LOG(ERROR) << "copy data from GPU to CPU error";
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
|
|
#if PADDLE_WITH_CUDA
|
|
|
|
|
if (std::is_same<anakin::NV, Target>::value) {
|
|
|
|
|
// Copy data from GPU -> CPU
|
|
|
|
|
if (cudaMemcpy(output.data.data(), tensor->mutable_data(),
|
|
|
|
|
tensor->valid_size() * sizeof(float),
|
|
|
|
|
cudaMemcpyDeviceToHost) != 0) {
|
|
|
|
|
VLOG(3) << "copy data from GPU to CPU error";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
if (std::is_same<anakin::X86, Target>::value) {
|
|
|
|
|
memcpy(output.data.data(), tensor->mutable_data(),
|
|
|
|
|
tensor->valid_size() * sizeof(float));
|
|
|
|
|
}
|
|
|
|
|
cudaStreamSynchronize(NULL);
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
@ -132,7 +185,7 @@ PaddleInferenceAnakinPredictor<Target>::Clone() {
|
|
|
|
|
auto anakin_predictor_p =
|
|
|
|
|
dynamic_cast<PaddleInferenceAnakinPredictor<Target> *>(cls.get());
|
|
|
|
|
if (!anakin_predictor_p) {
|
|
|
|
|
LOG(ERROR) << "fail to call Init";
|
|
|
|
|
VLOG(3) << "fail to call Init";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
anakin_predictor_p->get_executer().init(graph_);
|
|
|
|
@ -162,6 +215,44 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
|
|
|
|
|
VLOG(3) << "Anakin Predictor create on unknown platform.";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_ANAKIN_ENABLE_OP_TIMER
|
|
|
|
|
template <typename Target>
|
|
|
|
|
using executor_t =
|
|
|
|
|
anakin::Net<Target, anakin::saber::AK_FLOAT, anakin::Precision::FP32>;
|
|
|
|
|
|
|
|
|
|
template <typename Target>
|
|
|
|
|
void DisplayOpTimer(executor_t<Target> *net_executor, int epoch) {
|
|
|
|
|
std::vector<float> op_time = net_executor->get_op_time();
|
|
|
|
|
auto exec_funcs = net_executor->get_exec_funcs();
|
|
|
|
|
auto op_param = net_executor->get_op_param();
|
|
|
|
|
for (int i = 0; i < op_time.size(); i++) {
|
|
|
|
|
LOG(INFO) << "name: " << exec_funcs[i].name
|
|
|
|
|
<< " op_type: " << exec_funcs[i].op_name
|
|
|
|
|
<< " op_param: " << op_param[i] << " time " << op_time[i] / epoch;
|
|
|
|
|
}
|
|
|
|
|
std::map<std::string, float> op_map;
|
|
|
|
|
for (int i = 0; i < op_time.size(); i++) {
|
|
|
|
|
auto it = op_map.find(op_param[i]);
|
|
|
|
|
if (it != op_map.end())
|
|
|
|
|
op_map[op_param[i]] += op_time[i];
|
|
|
|
|
else
|
|
|
|
|
op_map.insert(std::pair<std::string, float>(op_param[i], op_time[i]));
|
|
|
|
|
}
|
|
|
|
|
for (auto it = op_map.begin(); it != op_map.end(); ++it) {
|
|
|
|
|
LOG(INFO) << it->first << " " << (it->second) / epoch << " ms";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
template <typename Target>
|
|
|
|
|
PaddleInferenceAnakinPredictor<Target>::~PaddleInferenceAnakinPredictor() {
|
|
|
|
|
#ifdef PADDLE_ANAKIN_ENABLE_OP_TIMER
|
|
|
|
|
DisplayOpTimer<Target>(executor_p_, max_batch_size_);
|
|
|
|
|
#endif
|
|
|
|
|
delete executor_p_;
|
|
|
|
|
executor_p_ = nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace paddle
|
|
|
|
|