You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/fluid/inference/api/analysis_predictor_tester.cc

66 lines
2.0 KiB

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/fluid/inference/api/paddle_inference_api.h"
DEFINE_string(dirname, "", "dirname to tests.");

namespace paddle {
namespace inference {

using contrib::AnalysisConfig;

// Exercises the zero-copy path of AnalysisPredictor on the word2vec model:
// feed/fetch ops are disabled, the four int64 inputs are written directly into
// the predictor's input tensors, ZeroCopyRun() is invoked, and the output
// tensor "fc_1.tmp_2" is read back in place.
TEST(AnalysisPredictor, ZeroCopy) {
  // Rows fed into every input tensor; each input is shaped {kBatchSize, 1}.
  constexpr int kBatchSize = 4;
  // Input variable names of the saved word2vec model. "forthw" is spelled as
  // in the model itself and must not be "corrected".
  const char* const kInputNames[] = {"firstw", "secondw", "thirdw", "forthw"};

  AnalysisConfig config;
  config.model_dir = FLAGS_dirname + "/word2vec.inference.model";
  // With feed/fetch ops disabled, inputs and outputs are accessed through
  // GetInputTensor()/GetOutputTensor() and executed with ZeroCopyRun().
  config.use_feed_fetch_ops = false;

  auto predictor = CreatePaddlePredictor<AnalysisConfig>(config);
  ASSERT_TRUE(predictor != nullptr);

  // Fill every input with the ids 0 .. kBatchSize-1.
  for (const char* name : kInputNames) {
    auto input = predictor->GetInputTensor(name);
    ASSERT_TRUE(input != nullptr);
    input->Reshape({kBatchSize, 1});
    auto* input_data = input->mutable_data<int64_t>(PaddlePlace::kCPU);
    ASSERT_TRUE(input_data != nullptr);
    for (int i = 0; i < kBatchSize; ++i) {
      input_data[i] = i;
    }
  }

  predictor->ZeroCopyRun();

  auto out = predictor->GetOutputTensor("fc_1.tmp_2");
  ASSERT_TRUE(out != nullptr);
  PaddlePlace place{};
  int size = 0;
  auto* out_data = out->data<float>(&place, &size);
  ASSERT_TRUE(out_data != nullptr);
  // ZeroCopyTensor::data() reports the number of elements, not bytes, so the
  // value is logged as-is (dividing by sizeof(float) under-reported it 4x).
  ASSERT_GT(size, 0);
  LOG(INFO) << "output size: " << size;
  // Log an actual output value rather than the buffer address.
  LOG(INFO) << "output_data[0]: " << out_data[0];
}

}  // namespace inference
}  // namespace paddle