fea/init tensorrt engine (#10003)
parent 64babc9aeb
commit 2d57158e2b
@ -0,0 +1,53 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/framework.pb.h"

namespace paddle {
namespace inference {

/*
 * EngineBase is the base class of all inference engines. An inference engine
 * takes a paddle program as input and outputs the result in fluid Tensor
 * format. It can be used to optimize the performance of computation
 * sub-blocks, for example, by breaking the original block down into sub-blocks
 * and executing each sub-block on a different engine.
 *
 * For example:
 * During inference, most of the ResNet50 model can be placed into a subgraph
 * and run on a TensorRT engine.
 *
 * There are several engines such as TensorRT and other frameworks, so
 * EngineBase is put forward to give a unified interface for all the different
 * engine implementations.
 */
class EngineBase {
 public:
  using DescType = ::paddle::framework::proto::BlockDesc;

  // Build the model and do some preparation, for example, in TensorRT, run
  // createInferBuilder, buildCudaEngine.
  virtual void Build(const DescType& paddle_model) = 0;

  // Execute the engine, that is, run the inference network.
  virtual void Execute(int batch_size) = 0;

  virtual ~EngineBase() {}
};  // class EngineBase

}  // namespace inference
}  // namespace paddle
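The interface above only fixes the Build/Execute pair and the BlockDesc-based DescType. As a minimal sketch of what a concrete engine has to provide, a hypothetical do-nothing subclass (the NoopEngine name and its behavior are illustrative, not part of this patch) could look like:

// Illustrative only: a do-nothing engine that satisfies the EngineBase
// interface declared in paddle/fluid/inference/engine.h.
#include "paddle/fluid/inference/engine.h"

namespace paddle {
namespace inference {

class NoopEngine : public EngineBase {
 public:
  // Accept a BlockDesc but perform no optimization or compilation.
  void Build(const DescType& paddle_model) override {}

  // Running the "network" is a no-op as well.
  void Execute(int batch_size) override {}

  ~NoopEngine() override = default;
};

}  // namespace inference
}  // namespace paddle

TensorRTEngine below follows exactly this shape, with the TensorRT builder, network, and execution context behind the same two entry points.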
@ -1 +1,4 @@
-nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
+if(WITH_TESTING)
+  nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
+  nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda)
+endif()
@ -0,0 +1,134 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/engine.h"

#include <NvInfer.h>
#include <cuda.h>
#include <glog/logging.h>
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace inference {
namespace tensorrt {

void TensorRTEngine::Build(const DescType& paddle_model) {
  PADDLE_ENFORCE(false, "not implemented");
}

void TensorRTEngine::Execute(int batch_size) {
  infer_context_->enqueue(batch_size, buffers_.data(), *stream_, nullptr);
  cudaStreamSynchronize(*stream_);
}

TensorRTEngine::~TensorRTEngine() {
  // clean buffers
  for (auto& buffer : buffers_) {
    if (buffer != nullptr) {
      PADDLE_ENFORCE_EQ(0, cudaFree(buffer));
      buffer = nullptr;
    }
  }
}

void TensorRTEngine::FreezeNetwork() {
  PADDLE_ENFORCE(infer_builder_ != nullptr,
                 "Call InitNetwork first to initialize network.");
  PADDLE_ENFORCE(infer_network_ != nullptr,
                 "Call InitNetwork first to initialize network.");
  // build engine.
  infer_builder_->setMaxBatchSize(max_batch_);
  infer_builder_->setMaxWorkspaceSize(max_workspace_);

  infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_));
  PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!");

  infer_context_.reset(infer_engine_->createExecutionContext());

  // allocate GPU buffers.
  buffers_.resize(buffer_sizes_.size(), nullptr);
  for (auto& item : buffer_sizes_) {
    if (item.second == 0) {
      auto slot_offset = infer_engine_->getBindingIndex(item.first.c_str());
      item.second = kDataTypeSize[static_cast<int>(
                        infer_engine_->getBindingDataType(slot_offset))] *
                    AccumDims(infer_engine_->getBindingDimensions(slot_offset));
    }
    PADDLE_ENFORCE_EQ(0, cudaMalloc(&buffer(item.first), item.second));
  }
}

nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name,
                                                nvinfer1::DataType dtype,
                                                const nvinfer1::Dims& dim) {
  PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate input name %s",
                    name);

  PADDLE_ENFORCE(infer_network_ != nullptr, "Call InitNetwork first.");
  auto* input = infer_network_->addInput(name.c_str(), dtype, dim);
  PADDLE_ENFORCE(input, "infer network add input %s failed", name);

  buffer_sizes_[name] = kDataTypeSize[static_cast<int>(dtype)] * AccumDims(dim);
  return input;
}

void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset,
                                   const std::string& name) {
  PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s",
                    name);

  auto* output = layer->getOutput(offset);
  PADDLE_ENFORCE(output != nullptr);
  output->setName(name.c_str());
  infer_network_->markOutput(*output);
  // An output buffer's size can only be determined later, so set it to zero
  // here as a marker; it will be reset in FreezeNetwork.
  buffer_sizes_[name] = 0;
}

void* TensorRTEngine::GetOutputInGPU(const std::string& name) {
  return buffer(name);
}

void TensorRTEngine::GetOutputInCPU(const std::string& name, void* dst,
                                    size_t max_size) {
  // determine data size
  auto it = buffer_sizes_.find(name);
  PADDLE_ENFORCE(it != buffer_sizes_.end());
  PADDLE_ENFORCE_GT(it->second, 0);
  PADDLE_ENFORCE_GE(max_size, it->second);

  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buffer(name), it->second,
                                       cudaMemcpyDeviceToHost, *stream_));
}

void*& TensorRTEngine::buffer(const std::string& name) {
  PADDLE_ENFORCE(infer_engine_ != nullptr, "Call FreezeNetwork first.");
  auto it = buffer_sizes_.find(name);
  PADDLE_ENFORCE(it != buffer_sizes_.end());
  auto slot_offset = infer_engine_->getBindingIndex(name.c_str());
  return buffers_[slot_offset];
}

void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data,
                                     size_t size) {
  void* buf = buffer(name);
  PADDLE_ENFORCE_EQ(
      0, cudaMemcpyAsync(buf, data, size, cudaMemcpyHostToDevice, *stream_));
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
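In FreezeNetwork above, every output binding whose size is still the zero placeholder left by DeclareOutput gets its real size from the frozen engine: the element size looked up in kDataTypeSize times the accumulated binding dimensions. A small standalone sketch of that per-binding arithmetic (the BindingByteSize helper is a made-up name for illustration; kDataTypeSize and AccumDims come from helper.h in this patch):

// Illustrative sketch: compute the byte size of one engine binding, the same
// arithmetic FreezeNetwork uses to size its cudaMalloc calls.
#include <NvInfer.h>
#include "paddle/fluid/inference/tensorrt/helper.h"

namespace paddle {
namespace inference {
namespace tensorrt {

// Hypothetical helper, not part of this patch.
static size_t BindingByteSize(const nvinfer1::ICudaEngine& engine,
                              const char* name) {
  const int slot = engine.getBindingIndex(name);
  // Bytes per element (kFLOAT=4, kHALF=2, kINT8=1, kINT32=4) ...
  const size_t elem_size =
      kDataTypeSize[static_cast<int>(engine.getBindingDataType(slot))];
  // ... times the element count implied by the binding dimensions.
  return elem_size * AccumDims(engine.getBindingDimensions(slot));
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle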
@ -0,0 +1,144 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <NvInfer.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"

namespace paddle {
namespace inference {
namespace tensorrt {

/*
 * TensorRT Engine.
 *
 * There are two alternative ways to use it: one is to build it from a paddle
 * protobuf model, the other is to manually construct the network.
 */
class TensorRTEngine : public EngineBase {
 public:
  // Weight is a model parameter.
  class Weight {
   public:
    Weight(nvinfer1::DataType dtype, void* value, int num_elem) {
      w_.type = dtype;
      w_.values = value;
      w_.count = num_elem;
    }
    const nvinfer1::Weights& get() { return w_; }

   private:
    nvinfer1::Weights w_;
  };

  TensorRTEngine(int max_batch, int max_workspace, cudaStream_t* stream,
                 nvinfer1::ILogger& logger = NaiveLogger::Global())
      : max_batch_(max_batch),
        max_workspace_(max_workspace),
        stream_(stream),
        logger_(logger) {}

  virtual ~TensorRTEngine();

  // TODO(Superjomn) implement it later when graph segmentation is supported.
  virtual void Build(const DescType& paddle_model) override;

  virtual void Execute(int batch_size) override;

  // Initialize the inference network, so that TensorRT layers can be added to
  // this network.
  void InitNetwork() {
    infer_builder_.reset(createInferBuilder(logger_));
    infer_network_.reset(infer_builder_->createNetwork());
  }
  // After finishing adding ops, freeze this network and create the execution
  // environment.
  void FreezeNetwork();

  // Add an input and set its name, data type and dimension.
  nvinfer1::ITensor* DeclareInput(const std::string& name,
                                  nvinfer1::DataType dtype,
                                  const nvinfer1::Dims& dim);
  // Set the offset-th output from a layer as the network's output, and set its
  // name.
  void DeclareOutput(const nvinfer1::ILayer* layer, int offset,
                     const std::string& name);

  // GPU memory address for an ITensor with a specific name. One can operate on
  // this memory directly for acceleration, for example, writing converted data
  // directly into the buffer to save the data-copy overhead.
  // NOTE this should be used after calling `FreezeNetwork`.
  void*& buffer(const std::string& name);

  // Fill an input from CPU memory with name and size.
  void SetInputFromCPU(const std::string& name, void* data, size_t size);
  // TODO(Superjomn) is this method necessary, given that buffer(xxx) can be
  // accessed directly? Fill an input from GPU memory with name and size.
  void SetInputFromGPU(const std::string& name, void* data, size_t size);
  // Get an output called name; TensorRT's output lives in GPU memory, so this
  // method just returns the output's GPU memory address.
  void* GetOutputInGPU(const std::string& name);
  // LOW EFFICIENCY! Get an output to CPU; this triggers a memory copy from GPU
  // to CPU.
  void GetOutputInCPU(const std::string& name, void* dst, size_t max_size);

  nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); }
  nvinfer1::INetworkDefinition* network() { return infer_network_.get(); }

 private:
  // the max batch size
  int max_batch_;
  // the max memory size the engine uses
  int max_workspace_;
  cudaStream_t* stream_;
  nvinfer1::ILogger& logger_;

  std::vector<void*> buffers_;
  // max data size for the buffers.
  std::unordered_map<std::string /*name*/, size_t /*max size*/> buffer_sizes_;

  // TensorRT related internal members
  template <typename T>
  struct Destroyer {
    void operator()(T* x) { x->destroy(); }
  };
  template <typename T>
  using infer_ptr = std::unique_ptr<T, Destroyer<T>>;
  infer_ptr<nvinfer1::IBuilder> infer_builder_;
  infer_ptr<nvinfer1::INetworkDefinition> infer_network_;
  infer_ptr<nvinfer1::ICudaEngine> infer_engine_;
  infer_ptr<nvinfer1::IExecutionContext> infer_context_;
};  // class TensorRTEngine

// Add a layer__ into engine__ with args ARGS.
// For example:
//   TRT_ENGINE_ADD_LAYER(xxx, FullyConnected, input, dim, weights, bias)
//
// Reference
// https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#charRNN_define_network
//
// will add a fully connected layer into the engine.
// TensorRT has too many layers, so it is not wise to add a member function for
// each of them; a macro like this is more extensible when the underlying
// TensorRT library adds support for new layers.
#define TRT_ENGINE_ADD_LAYER(engine__, layer__, ARGS...) \
  engine__->network()->add##layer__(ARGS);

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
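Since TRT_ENGINE_ADD_LAYER only pastes the layer name after network()->add, a call through the macro is shorthand for the corresponding INetworkDefinition method. A minimal sketch of the equivalence for a FullyConnected layer (the AddFc helper name is made up; the engine, input tensor, and weights are assumed to have been set up already, e.g. as in test_engine.cc below):

// Illustrative only: what TRT_ENGINE_ADD_LAYER expands to for FullyConnected.
#include "paddle/fluid/inference/tensorrt/engine.h"

namespace paddle {
namespace inference {
namespace tensorrt {

// Hypothetical helper, not part of this patch.
static nvinfer1::ILayer* AddFc(TensorRTEngine* engine, nvinfer1::ITensor* x,
                               int size, TensorRTEngine::Weight& w,
                               TensorRTEngine::Weight& b) {
  // TRT_ENGINE_ADD_LAYER(engine, FullyConnected, *x, size, w.get(), b.get())
  // pastes "FullyConnected" after network()->add, i.e. it is shorthand for:
  return engine->network()->addFullyConnected(*x, size, w.get(), b.get());
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle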
@ -0,0 +1,88 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <NvInfer.h>
#include <cuda.h>
#include <glog/logging.h>
#include "paddle/fluid/platform/dynload/tensorrt.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace inference {
namespace tensorrt {

namespace dy = paddle::platform::dynload;

static size_t AccumDims(nvinfer1::Dims dims) {
  size_t num = dims.nbDims == 0 ? 0 : 1;
  for (int i = 0; i < dims.nbDims; i++) {
    PADDLE_ENFORCE_GT(dims.d[i], 0);
    num *= dims.d[i];
  }
  return num;
}

// TensorRT data type to size
const int kDataTypeSize[] = {
    4,  // kFLOAT
    2,  // kHALF
    1,  // kINT8
    4   // kINT32
};

// The following two APIs are implemented in TensorRT's header file and
// therefore cannot be loaded from the dynamic library, so we provide our own
// implementations that call the corresponding symbols in the dynamic library
// directly.
static nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger& logger) {
  return static_cast<nvinfer1::IBuilder*>(
      dy::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION));
}
static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger& logger) {
  return static_cast<nvinfer1::IRuntime*>(
      dy::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION));
}

// A logger for creating the TensorRT infer builder.
class NaiveLogger : public nvinfer1::ILogger {
 public:
  void log(nvinfer1::ILogger::Severity severity, const char* msg) override {
    switch (severity) {
      case Severity::kINFO:
        LOG(INFO) << msg;
        break;
      case Severity::kWARNING:
        LOG(WARNING) << msg;
        break;
      case Severity::kINTERNAL_ERROR:
      case Severity::kERROR:
        LOG(ERROR) << msg;
        break;
      default:
        break;
    }
  }

  static nvinfer1::ILogger& Global() {
    static nvinfer1::ILogger* x = new NaiveLogger;
    return *x;
  }

  ~NaiveLogger() override {}
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
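These helpers can also be used on their own: createInferBuilder and createInferRuntime hand back raw TensorRT objects created through the dynloaded library, and NaiveLogger::Global() provides the shared logger. A hedged sketch of a standalone smoke check (the function name is illustrative; objects created this way must be released with destroy(), mirroring the Destroyer deleter in engine.h):

// Illustrative only: create and release a builder and a runtime directly
// through the dynload-backed helpers above.
#include "paddle/fluid/inference/tensorrt/helper.h"

namespace paddle {
namespace inference {
namespace tensorrt {

// Hypothetical function, not part of this patch.
static void BuilderAndRuntimeSmokeCheck() {
  nvinfer1::IBuilder* builder = createInferBuilder(NaiveLogger::Global());
  nvinfer1::IRuntime* runtime = createInferRuntime(NaiveLogger::Global());
  PADDLE_ENFORCE(builder != nullptr);
  PADDLE_ENFORCE(runtime != nullptr);
  // Both objects come from the TensorRT dynamic library and are released
  // through their destroy() methods rather than delete.
  builder->destroy();
  runtime->destroy();
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle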
@ -0,0 +1,83 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/engine.h"

#include <cuda.h>
#include <cuda_runtime_api.h>
#include <glog/logging.h>
#include <gtest/gtest.h>

#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace inference {
namespace tensorrt {

class TensorRTEngineTest : public ::testing::Test {
 protected:
  void SetUp() override {
    ASSERT_EQ(0, cudaStreamCreate(&stream_));
    engine_ = new TensorRTEngine(1, 1 << 10, &stream_);
    engine_->InitNetwork();
  }

  void TearDown() override {
    delete engine_;
    cudaStreamDestroy(stream_);
  }

 protected:
  TensorRTEngine* engine_;
  cudaStream_t stream_;
};

TEST_F(TensorRTEngineTest, add_layer) {
  const int size = 1;

  float raw_weight[size] = {2.};  // Weight in CPU memory.
  float raw_bias[size] = {3.};

  LOG(INFO) << "create weights";
  TensorRTEngine::Weight weight(nvinfer1::DataType::kFLOAT, raw_weight, size);
  TensorRTEngine::Weight bias(nvinfer1::DataType::kFLOAT, raw_bias, size);
  auto* x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT,
                                  nvinfer1::DimsCHW{1, 1, 1});
  auto* fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *x, size,
                                        weight.get(), bias.get());
  PADDLE_ENFORCE(fc_layer != nullptr);

  engine_->DeclareOutput(fc_layer, 0, "y");
  LOG(INFO) << "freeze network";
  engine_->FreezeNetwork();
  ASSERT_EQ(engine_->engine()->getNbBindings(), 2);

  // fill in real data
  float x_v = 1234;
  engine_->SetInputFromCPU("x", (void*)&x_v, 1 * sizeof(float));
  LOG(INFO) << "to execute";
  engine_->Execute(1);

  LOG(INFO) << "to get output";
  float y_cpu;
  engine_->GetOutputInCPU("y", &y_cpu, sizeof(float));

  LOG(INFO) << "to check output";
  ASSERT_EQ(y_cpu, x_v * 2 + 3);
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
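The final assertion follows from the single-element FullyConnected layer: y = w * x + b = 2 * 1234 + 3. When the copy to CPU is not wanted, the same result could be read through the GPU-side address instead; a hedged sketch of what that would look like inside the test body above (the explicit cudaMemcpyAsync and stream sync here are illustrative, not part of this patch):

  // Illustrative only: read the output via its GPU address rather than
  // GetOutputInCPU. Assumes the same engine_/stream_ members as the fixture.
  void* y_gpu = engine_->GetOutputInGPU("y");  // device pointer to binding "y"
  float y_from_gpu = 0;
  ASSERT_EQ(0, cudaMemcpyAsync(&y_from_gpu, y_gpu, sizeof(float),
                               cudaMemcpyDeviceToHost, stream_));
  ASSERT_EQ(0, cudaStreamSynchronize(stream_));
  ASSERT_EQ(y_from_gpu, 1234 * 2 + 3);  // same w * x + b result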