|
|
|
@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include <cuda.h>
|
|
|
|
|
#include <cuda_runtime_api.h>
|
|
|
|
|
#include <glog/logging.h>
|
|
|
|
|
#include <gtest/gtest.h>
|
|
|
|
|
#include "NvInfer.h"
|
|
|
|
|
#include "cuda.h"
|
|
|
|
|
#include "cuda_runtime_api.h"
|
|
|
|
|
#include "paddle/fluid/platform/dynload/tensorrt.h"
|
|
|
|
|
|
|
|
|
|
namespace dy = paddle::platform::dynload;
|
|
|
|
@ -43,7 +43,7 @@ class Logger : public nvinfer1::ILogger {
|
|
|
|
|
|
|
|
|
|
class ScopedWeights {
|
|
|
|
|
public:
|
|
|
|
|
ScopedWeights(float value) : value_(value) {
|
|
|
|
|
explicit ScopedWeights(float value) : value_(value) {
|
|
|
|
|
w.type = nvinfer1::DataType::kFLOAT;
|
|
|
|
|
w.values = &value_;
|
|
|
|
|
w.count = 1;
|
|
|
|
@ -58,13 +58,13 @@ class ScopedWeights {
|
|
|
|
|
// The following two API are implemented in TensorRT's header file, cannot load
|
|
|
|
|
// from the dynamic library. So create our own implementation and directly
|
|
|
|
|
// trigger the method from the dynamic library.
|
|
|
|
|
nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger& logger) {
|
|
|
|
|
nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger* logger) {
|
|
|
|
|
return static_cast<nvinfer1::IBuilder*>(
|
|
|
|
|
dy::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION));
|
|
|
|
|
dy::createInferBuilder_INTERNAL(logger, NV_TENSORRT_VERSION));
|
|
|
|
|
}
|
|
|
|
|
nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger& logger) {
|
|
|
|
|
nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger* logger) {
|
|
|
|
|
return static_cast<nvinfer1::IRuntime*>(
|
|
|
|
|
dy::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION));
|
|
|
|
|
dy::createInferRuntime_INTERNAL(logger, NV_TENSORRT_VERSION));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const char* kInputTensor = "input";
|
|
|
|
@ -74,7 +74,7 @@ const char* kOutputTensor = "output";
|
|
|
|
|
nvinfer1::IHostMemory* CreateNetwork() {
|
|
|
|
|
Logger logger;
|
|
|
|
|
// Create the engine.
|
|
|
|
|
nvinfer1::IBuilder* builder = createInferBuilder(logger);
|
|
|
|
|
nvinfer1::IBuilder* builder = createInferBuilder(&logger);
|
|
|
|
|
ScopedWeights weights(2.);
|
|
|
|
|
ScopedWeights bias(3.);
|
|
|
|
|
|
|
|
|
@ -103,9 +103,9 @@ nvinfer1::IHostMemory* CreateNetwork() {
|
|
|
|
|
return model;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Execute(nvinfer1::IExecutionContext& context, const float* input,
|
|
|
|
|
void Execute(nvinfer1::IExecutionContext* context, const float* input,
|
|
|
|
|
float* output) {
|
|
|
|
|
const nvinfer1::ICudaEngine& engine = context.getEngine();
|
|
|
|
|
const nvinfer1::ICudaEngine& engine = context->getEngine();
|
|
|
|
|
// Two binds, input and output
|
|
|
|
|
ASSERT_EQ(engine.getNbBindings(), 2);
|
|
|
|
|
const int input_index = engine.getBindingIndex(kInputTensor);
|
|
|
|
@ -119,7 +119,7 @@ void Execute(nvinfer1::IExecutionContext& context, const float* input,
|
|
|
|
|
// Copy the input to the GPU, execute the network, and copy the output back.
|
|
|
|
|
ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input, sizeof(float),
|
|
|
|
|
cudaMemcpyHostToDevice, stream));
|
|
|
|
|
context.enqueue(1, buffers, stream, nullptr);
|
|
|
|
|
context->enqueue(1, buffers, stream, nullptr);
|
|
|
|
|
ASSERT_EQ(0, cudaMemcpyAsync(output, buffers[output_index], sizeof(float),
|
|
|
|
|
cudaMemcpyDeviceToHost, stream));
|
|
|
|
|
cudaStreamSynchronize(stream);
|
|
|
|
@ -136,7 +136,7 @@ TEST(TensorrtTest, BasicFunction) {
|
|
|
|
|
|
|
|
|
|
// Use the model to create an engine and an execution context.
|
|
|
|
|
Logger logger;
|
|
|
|
|
nvinfer1::IRuntime* runtime = createInferRuntime(logger);
|
|
|
|
|
nvinfer1::IRuntime* runtime = createInferRuntime(&logger);
|
|
|
|
|
nvinfer1::ICudaEngine* engine =
|
|
|
|
|
runtime->deserializeCudaEngine(model->data(), model->size(), nullptr);
|
|
|
|
|
model->destroy();
|
|
|
|
@ -145,7 +145,7 @@ TEST(TensorrtTest, BasicFunction) {
|
|
|
|
|
// Execute the network.
|
|
|
|
|
float input = 1234;
|
|
|
|
|
float output;
|
|
|
|
|
Execute(*context, &input, &output);
|
|
|
|
|
Execute(context, &input, &output);
|
|
|
|
|
EXPECT_EQ(output, input * 2 + 3);
|
|
|
|
|
|
|
|
|
|
// Destroy the engine.
|
|
|
|
|