Merge pull request #14440 from hjchen2/develop
Add PRelu tensorRT plugin and Conv2d transpose op converterpanyx0718-patch-1
commit
2f27c048cc
@ -0,0 +1,80 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
|
||||||
|
#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace inference {
|
||||||
|
namespace tensorrt {
|
||||||
|
|
||||||
|
/*
|
||||||
|
* PRelu converter from fluid to tensorRT.
|
||||||
|
*/
|
||||||
|
class PReluOpConverter : public OpConverter {
|
||||||
|
public:
|
||||||
|
void operator()(const framework::proto::OpDesc& op,
|
||||||
|
const framework::Scope& scope, bool test_mode) override {
|
||||||
|
VLOG(4) << "convert fluid prelu op to tensorrt prelu layer";
|
||||||
|
|
||||||
|
framework::OpDesc op_desc(op, nullptr);
|
||||||
|
// Declare inputs
|
||||||
|
int input_num = op_desc.Input("X").size();
|
||||||
|
PADDLE_ENFORCE(input_num == 1);
|
||||||
|
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
|
||||||
|
// Get output
|
||||||
|
size_t output_num = op_desc.Output("Out").size();
|
||||||
|
PADDLE_ENFORCE(output_num == 1);
|
||||||
|
// Get attrs
|
||||||
|
std::string mode = boost::get<std::string>(op_desc.GetAttr("mode"));
|
||||||
|
//
|
||||||
|
auto* alpha_var = scope.FindVar(op_desc.Input("Alpha")[0]);
|
||||||
|
PADDLE_ENFORCE_NOT_NULL(alpha_var);
|
||||||
|
auto* alpha_tensor = alpha_var->GetMutable<framework::LoDTensor>();
|
||||||
|
|
||||||
|
platform::CUDAPlace place;
|
||||||
|
std::unique_ptr<framework::LoDTensor> alpha_tensor_device(
|
||||||
|
new framework::LoDTensor());
|
||||||
|
alpha_tensor_device->Resize(alpha_tensor->dims());
|
||||||
|
TensorCopySync(*alpha_tensor, place, alpha_tensor_device.get());
|
||||||
|
float* alpha_data = alpha_tensor_device->mutable_data<float>(place);
|
||||||
|
|
||||||
|
// Transform alpha to TensorRTEngine::Weight
|
||||||
|
TensorRTEngine::Weight alpha_rt(nvinfer1::DataType::kFLOAT,
|
||||||
|
static_cast<void*>(alpha_data),
|
||||||
|
alpha_tensor_device->numel());
|
||||||
|
PReluPlugin* plugin = new PReluPlugin(alpha_rt, mode);
|
||||||
|
nvinfer1::IPluginLayer* layer =
|
||||||
|
engine_->AddPlugin(&input, input_num, plugin);
|
||||||
|
// keep alpha tensor to avoid release it's memory
|
||||||
|
engine_->weight_map[op_desc.Input("Alpha")[0]] =
|
||||||
|
std::move(alpha_tensor_device);
|
||||||
|
|
||||||
|
std::string layer_name = "prelu (Output: ";
|
||||||
|
auto output_name = op_desc.Output("Out")[0];
|
||||||
|
layer->getOutput(0)->setName(output_name.c_str());
|
||||||
|
engine_->SetITensor(output_name, layer->getOutput(0));
|
||||||
|
layer_name += output_name;
|
||||||
|
if (test_mode) {
|
||||||
|
engine_->DeclareOutput(output_name);
|
||||||
|
}
|
||||||
|
layer->setName((layer_name + ")").c_str());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorrt
|
||||||
|
} // namespace inference
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
REGISTER_TRT_OP_CONVERTER(prelu, PReluOpConverter);
|
@ -0,0 +1,94 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
|
||||||
|
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace inference {
|
||||||
|
namespace tensorrt {
|
||||||
|
|
||||||
|
TEST(prelu_op, test_channel_wise) {
|
||||||
|
std::unordered_set<std::string> parameters({"prelu_alpha"});
|
||||||
|
framework::Scope scope;
|
||||||
|
TRTConvertValidation validator(10, parameters, scope, 1000);
|
||||||
|
validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2));
|
||||||
|
validator.DeclParamVar("prelu_alpha", nvinfer1::Dims3(3, 1, 1));
|
||||||
|
validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2));
|
||||||
|
|
||||||
|
// Prepare Op description
|
||||||
|
framework::OpDesc desc;
|
||||||
|
desc.SetType("prelu");
|
||||||
|
desc.SetInput("X", {"prelu_input"});
|
||||||
|
desc.SetInput("Alpha", {"prelu_alpha"});
|
||||||
|
desc.SetOutput("Out", {"prelu_out"});
|
||||||
|
|
||||||
|
desc.SetAttr("mode", std::string("channel"));
|
||||||
|
|
||||||
|
validator.SetOp(*desc.Proto());
|
||||||
|
|
||||||
|
validator.Execute(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(prelu_op, test_element_wise) {
|
||||||
|
std::unordered_set<std::string> parameters({"prelu_alpha"});
|
||||||
|
framework::Scope scope;
|
||||||
|
TRTConvertValidation validator(10, parameters, scope, 1000);
|
||||||
|
validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2));
|
||||||
|
validator.DeclParamVar("prelu_alpha", nvinfer1::Dims4(10, 3, 2, 2));
|
||||||
|
validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2));
|
||||||
|
|
||||||
|
// Prepare Op description
|
||||||
|
framework::OpDesc desc;
|
||||||
|
desc.SetType("prelu");
|
||||||
|
desc.SetInput("X", {"prelu_input"});
|
||||||
|
desc.SetInput("Alpha", {"prelu_alpha"});
|
||||||
|
desc.SetOutput("Out", {"prelu_out"});
|
||||||
|
|
||||||
|
desc.SetAttr("mode", std::string("element"));
|
||||||
|
|
||||||
|
validator.SetOp(*desc.Proto());
|
||||||
|
|
||||||
|
validator.Execute(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(prelu_op, test_scalar) {
|
||||||
|
std::unordered_set<std::string> parameters({"prelu_alpha"});
|
||||||
|
framework::Scope scope;
|
||||||
|
TRTConvertValidation validator(10, parameters, scope, 1000);
|
||||||
|
validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2));
|
||||||
|
validator.DeclParamVar("prelu_alpha", nvinfer1::Dims3(1, 1, 1));
|
||||||
|
validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2));
|
||||||
|
|
||||||
|
// Prepare Op description
|
||||||
|
framework::OpDesc desc;
|
||||||
|
desc.SetType("prelu");
|
||||||
|
desc.SetInput("X", {"prelu_input"});
|
||||||
|
desc.SetInput("Alpha", {"prelu_alpha"});
|
||||||
|
desc.SetOutput("Out", {"prelu_out"});
|
||||||
|
|
||||||
|
desc.SetAttr("mode", std::string("all"));
|
||||||
|
|
||||||
|
validator.SetOp(*desc.Proto());
|
||||||
|
|
||||||
|
validator.Execute(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorrt
|
||||||
|
} // namespace inference
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
// USE_OP(prelu);
|
||||||
|
USE_CPU_ONLY_OP(prelu);
|
@ -1 +1 @@
|
|||||||
nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu DEPS enforce)
|
nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu prelu_op_plugin.cu DEPS enforce)
|
||||||
|
@ -0,0 +1,131 @@
|
|||||||
|
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <cassert>
|
||||||
|
#include "glog/logging.h"
|
||||||
|
#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace inference {
|
||||||
|
namespace tensorrt {
|
||||||
|
|
||||||
|
static const int CUDA_NUM_THREADS = 1024;
|
||||||
|
static const int CUDA_MAX_NUM_BLOCKS = 65535;
|
||||||
|
inline static int GET_NUM_BLOCKS(const int N) {
|
||||||
|
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void PReluChannelWiseKernel(const float *input, const float *alpha,
|
||||||
|
float *output, int channel,
|
||||||
|
size_t spatial_size) {
|
||||||
|
size_t offset = blockIdx.x * spatial_size;
|
||||||
|
const float *in = input + offset;
|
||||||
|
float *out = output + offset;
|
||||||
|
float scale = alpha[blockIdx.x % channel];
|
||||||
|
|
||||||
|
for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
|
||||||
|
float x = in[i];
|
||||||
|
out[i] = (x > 0) ? x : scale * x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void PReluElementWiseKernel(const float *input, const float *alpha,
|
||||||
|
float *output, size_t spatial_size) {
|
||||||
|
size_t offset = blockIdx.x * spatial_size;
|
||||||
|
const float *in = input + offset;
|
||||||
|
const float *scale = alpha + offset;
|
||||||
|
float *out = output + offset;
|
||||||
|
|
||||||
|
for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
|
||||||
|
float x = in[i];
|
||||||
|
out[i] = (x > 0) ? x : scale[i] * x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void PReluScalarKernel(const float *input, const float *alpha,
|
||||||
|
float *output, size_t spatial_size) {
|
||||||
|
size_t offset = blockIdx.x * spatial_size;
|
||||||
|
const float *in = input + offset;
|
||||||
|
float scale = *alpha;
|
||||||
|
float *out = output + offset;
|
||||||
|
|
||||||
|
for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
|
||||||
|
float x = in[i];
|
||||||
|
out[i] = (x > 0) ? x : scale * x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void PReluChannelWise(cudaStream_t stream, const float *input,
|
||||||
|
const float *alpha, float *output,
|
||||||
|
int batch_size,
|
||||||
|
const nvinfer1::Dims &dims) {
|
||||||
|
size_t unroll = batch_size * dims.d[0];
|
||||||
|
size_t spatial_size = dims.d[1] * dims.d[2];
|
||||||
|
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
|
||||||
|
PReluChannelWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
|
||||||
|
input, alpha, output, dims.d[0], spatial_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void PReluElementWise(cudaStream_t stream, const float *input,
|
||||||
|
const float *alpha, float *output,
|
||||||
|
int batch_size,
|
||||||
|
const nvinfer1::Dims &dims) {
|
||||||
|
size_t unroll = batch_size * dims.d[0];
|
||||||
|
size_t spatial_size = dims.d[1] * dims.d[2];
|
||||||
|
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
|
||||||
|
PReluElementWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
|
||||||
|
input, alpha, output, spatial_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void PReluScalar(cudaStream_t stream, const float *input,
|
||||||
|
const float *alpha, float *output,
|
||||||
|
int batch_size, const nvinfer1::Dims &dims) {
|
||||||
|
size_t unroll = batch_size * dims.d[0];
|
||||||
|
size_t spatial_size = dims.d[1] * dims.d[2];
|
||||||
|
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
|
||||||
|
PReluScalarKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
|
||||||
|
input, alpha, output, spatial_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::Dims PReluPlugin::getOutputDimensions(int index,
|
||||||
|
const nvinfer1::Dims *inputDims,
|
||||||
|
int nbInputs) {
|
||||||
|
assert(nbInputs == 1);
|
||||||
|
assert(index < this->getNbOutputs());
|
||||||
|
nvinfer1::Dims const &input_dims = inputDims[0];
|
||||||
|
nvinfer1::Dims output_dims = input_dims;
|
||||||
|
return output_dims;
|
||||||
|
}
|
||||||
|
|
||||||
|
int PReluPlugin::enqueue(int batchSize, const void *const *inputs,
|
||||||
|
void **outputs, void *workspace, cudaStream_t stream) {
|
||||||
|
// input dims is CHW.
|
||||||
|
const auto &input_dims = this->getInputDims(0);
|
||||||
|
const float *input = reinterpret_cast<const float *>(inputs[0]);
|
||||||
|
const float *alpha = reinterpret_cast<const float *>(alpha_.get().values);
|
||||||
|
float *output = reinterpret_cast<float **>(outputs)[0];
|
||||||
|
if (mode_ == "channel") {
|
||||||
|
PReluChannelWise(stream, input, alpha, output, batchSize, input_dims);
|
||||||
|
} else if (mode_ == "element") {
|
||||||
|
PReluElementWise(stream, input, alpha, output, batchSize, input_dims);
|
||||||
|
} else {
|
||||||
|
PReluScalar(stream, input, alpha, output, batchSize, input_dims);
|
||||||
|
}
|
||||||
|
return cudaGetLastError() != cudaSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorrt
|
||||||
|
} // namespace inference
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,68 @@
|
|||||||
|
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include "paddle/fluid/inference/tensorrt/engine.h"
|
||||||
|
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace inference {
|
||||||
|
namespace tensorrt {
|
||||||
|
|
||||||
|
class PReluPlugin : public PluginTensorRT {
|
||||||
|
TensorRTEngine::Weight alpha_;
|
||||||
|
std::string mode_;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
size_t getSerializationSize() override {
|
||||||
|
// return getBaseSerializationSize(alpha_) + SerializedSize(mode_);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TRT will call this func when we need to serialize the configuration of
|
||||||
|
// tensorrt.
|
||||||
|
// It should not be called by users.
|
||||||
|
void serialize(void *buffer) override {
|
||||||
|
// serializeBase(buffer);
|
||||||
|
// SerializeValue(&buffer, alpha_);
|
||||||
|
// SerializeValue(&buffer, mode_);
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
PReluPlugin(TensorRTEngine::Weight const &alpha, std::string const &mode)
|
||||||
|
: alpha_(alpha), mode_(mode) {}
|
||||||
|
|
||||||
|
// It was used for tensorrt deserialization.
|
||||||
|
// It should not be called by users.
|
||||||
|
PReluPlugin(void const *serialData, size_t serialLength) {
|
||||||
|
// deserializeBase(serialData, serialLength);
|
||||||
|
// DeserializeValue(&serialData, &serialLength, &alpha_);
|
||||||
|
// DeserializeValue(&serialData, &serialLength, &mode_);
|
||||||
|
}
|
||||||
|
|
||||||
|
PReluPlugin *clone() const override { return new PReluPlugin(alpha_, mode_); }
|
||||||
|
|
||||||
|
const char *getPluginType() const override { return "prelu"; }
|
||||||
|
int getNbOutputs() const override { return 1; }
|
||||||
|
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
|
||||||
|
int nbInputDims) override;
|
||||||
|
int enqueue(int batchSize, const void *const *inputs, void **outputs,
|
||||||
|
void *workspace, cudaStream_t stream) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorrt
|
||||||
|
} // namespace inference
|
||||||
|
} // namespace paddle
|
Loading…
Reference in new issue