Update trt5 for paddle-trt (#18645)

* Update paddle-trt:
  1. Fix bug: with batch > 2, the split plugin core-dumped.
  2. Add leaky_relu support for TensorRT 5.0 (YOLOv3 inference time from 65 ms to 42 ms).
  3. Add a new attribute to dropout.
  4. Add shuffle_channel, swish, and relu6 support.
  test=develop

* Fix CI
  test=develop
parent d8396281ef
commit 26ae6d49e4
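For context, the swish plugin introduced below computes output = x / (1 + exp(-beta * x)), i.e. x * sigmoid(beta * x). A minimal CPU reference sketch of that formula (illustrative only; the helper name is not part of this patch):

  #include <cmath>
  #include <vector>

  // Reference implementation of the swish activation, matching the CUDA
  // kernel added in swish_op_plugin.cu.
  std::vector<float> SwishRef(const std::vector<float> &x, float beta) {
    std::vector<float> y(x.size());
    for (size_t i = 0; i < x.size(); ++i) {
      y[i] = x[i] / (1.0f + std::exp(-beta * x[i]));
    }
    return y;
  }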
@@ -0,0 +1,57 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

namespace paddle {
namespace inference {
namespace tensorrt {

/*
 * ShuffleChannelOp: implemented with two TensorRT Shuffle layers
 * (reshape -> transpose -> reshape back).
 */
class ShuffleChannelOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    framework::OpDesc op_desc(op, nullptr);
    // Declare inputs
    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
    auto input_dims = input->getDimensions();
    PADDLE_ENFORCE(input_dims.nbDims == 3);
    int c = input_dims.d[0];
    int h = input_dims.d[1];
    int w = input_dims.d[2];
    int group = boost::get<int>(op_desc.GetAttr("group"));

    auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
    nvinfer1::Dims4 reshape_dim(group, c / group, h, w);
    layer->setReshapeDimensions(reshape_dim);
    layer->setSecondTranspose({1, 0, 2, 3});
    auto* output = layer->getOutput(0);

    auto* reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *output);
    nvinfer1::DimsCHW reshape_dim2(c, h, w);
    reshape_layer->setReshapeDimensions(reshape_dim2);

    auto output_name = op_desc.Output("Out")[0];
    RreplenishLayerAndOutput(reshape_layer, "shuffle_channel", {output_name},
                             test_mode);
  }
};

} // namespace tensorrt
} // namespace inference
} // namespace paddle

REGISTER_TRT_OP_CONVERTER(shuffle_channel, ShuffleChannelOpConverter);
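The channel-shuffle transform built above from two Shuffle layers maps input channel g0 * (c / group) + k to output channel k * group + g0. A minimal host-side sketch of the same permutation on a CHW buffer (hypothetical helper, not part of the patch):

  #include <vector>

  // Channel shuffle on a CHW float buffer: view channels as
  // (group, c / group), transpose the two group dimensions, flatten back.
  std::vector<float> ShuffleChannelsRef(const std::vector<float> &in, int c,
                                        int h, int w, int group) {
    std::vector<float> out(in.size());
    const int hw = h * w;
    const int c_per_group = c / group;  // the op assumes c % group == 0
    for (int ci = 0; ci < c; ++ci) {
      const int g0 = ci / c_per_group;  // group index after reshape
      const int k = ci % c_per_group;   // channel index within the group
      const int co = k * group + g0;    // destination channel after transpose
      for (int p = 0; p < hw; ++p) out[co * hw + p] = in[ci * hw + p];
    }
    return out;
  }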
@@ -0,0 +1,53 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {

class SwishOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(4) << "convert fluid swish op to tensorrt layer";

    framework::OpDesc op_desc(op, nullptr);
    // Declare inputs
    int input_num = op_desc.Input("X").size();
    PADDLE_ENFORCE(input_num == 1);
    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
    // Get output
    size_t output_num = op_desc.Output("Out").size();
    PADDLE_ENFORCE(output_num == 1);
    // Get attrs
    float beta = boost::get<float>(op_desc.GetAttr("beta"));

    plugin::SwishPlugin* plugin = new plugin::SwishPlugin(beta);

    nvinfer1::IPluginLayer* layer =
        engine_->AddPlugin(&input, input_num, plugin);

    auto output_name = op_desc.Output("Out")[0];
    RreplenishLayerAndOutput(layer, "swish", {output_name}, test_mode);
  }
};

} // namespace tensorrt
} // namespace inference
} // namespace paddle

REGISTER_TRT_OP_CONVERTER(swish, SwishOpConverter);
@@ -0,0 +1,48 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"

namespace paddle {
namespace inference {
namespace tensorrt {

TEST(shuffle_channel_op, test_shuffle_channel) {
  std::unordered_set<std::string> parameters;
  framework::Scope scope;
  TRTConvertValidation validator(10, parameters, scope, 1000);
  validator.DeclInputVar("sc_input", nvinfer1::DimsCHW(4, 2, 2));
  validator.DeclOutputVar("sc_out", nvinfer1::DimsCHW(4, 2, 2));

  // Prepare Op description
  framework::OpDesc desc;
  desc.SetType("shuffle_channel");
  desc.SetInput("X", {"sc_input"});
  desc.SetOutput("Out", {"sc_out"});
  int group = 2;
  desc.SetAttr("group", group);

  validator.SetOp(*desc.Proto());

  validator.Execute(1);
}

} // namespace tensorrt
} // namespace inference
} // namespace paddle

USE_OP(shuffle_channel);
@@ -0,0 +1,47 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"

namespace paddle {
namespace inference {
namespace tensorrt {

TEST(swish_op, test_swish) {
  std::unordered_set<std::string> parameters;
  framework::Scope scope;
  TRTConvertValidation validator(10, parameters, scope, 1000);
  validator.DeclInputVar("sw_input", nvinfer1::DimsCHW(3, 2, 2));
  validator.DeclOutputVar("sw_out", nvinfer1::DimsCHW(3, 2, 2));

  // Prepare Op description
  framework::OpDesc desc;
  desc.SetType("swish");
  desc.SetInput("X", {"sw_input"});
  desc.SetOutput("Out", {"sw_out"});

  desc.SetAttr("beta", 2.0f);

  validator.SetOp(*desc.Proto());

  validator.Execute(1);
}

} // namespace tensorrt
} // namespace inference
} // namespace paddle

USE_OP(swish);
@@ -1,5 +1,5 @@
 nv_library(tensorrt_plugin
   SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
   prelu_op_plugin.cu trt_plugin_factory.cc
-  avg_pool_op_plugin.cu
+  avg_pool_op_plugin.cu swish_op_plugin.cu
   DEPS enforce tensorrt_engine prelu)
@@ -0,0 +1,76 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>
#include <cassert>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/swish_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

SwishPlugin *CreateSwishPluginDeserialize(const void *buffer, size_t length) {
  return new SwishPlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("swish_plugin", CreateSwishPluginDeserialize);

int SwishPlugin::initialize() { return 0; }

nvinfer1::Dims SwishPlugin::getOutputDimensions(int index,
                                                const nvinfer1::Dims *inputDims,
                                                int nbInputs) {
  assert(nbInputs == 1);
  assert(index < this->getNbOutputs());
  nvinfer1::Dims const &input_dims = inputDims[0];
  nvinfer1::Dims output_dims = input_dims;
  return output_dims;
}

__global__ void swish_kernel(int num, const float *input, float *output,
                             float beta) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < num) {
#if __CUDA_ARCH__ >= 350
    output[index] =
        __ldg(input + index) / (1.0f + expf(-beta * __ldg(input + index)));
#else
    output[index] = input[index] / (1.0f + expf(-beta * input[index]));
#endif
  }
}

int SwishPlugin::enqueue(int batch_size, const void *const *inputs,
                         void **outputs, void *workspace, cudaStream_t stream) {
  // input dims is CHW.
  const auto &input_dims = this->getInputDims(0);
  const float *input = reinterpret_cast<const float *>(inputs[0]);
  float *output = reinterpret_cast<float **>(outputs)[0];
  int num = batch_size;
  for (int i = 0; i < input_dims.nbDims; i++) {
    num *= input_dims.d[i];
  }
  int threads = 1024;
  int blocks = (num + threads - 1) / threads;
  swish_kernel<<<blocks, threads, 0, stream>>>(num, input, output, beta_);

  return cudaGetLastError() != cudaSuccess;
}

} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
@@ -0,0 +1,72 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <string>
#include <vector>

#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

class SwishPlugin : public PluginTensorRT {
 private:
  float beta_;

 protected:
  size_t getSerializationSize() override {
    return getBaseSerializationSize() + SerializedSize(beta_);
  }

  // TensorRT calls this function when the plugin configuration needs to be
  // serialized. It should not be called by users.
  void serialize(void *buffer) override {
    SerializeValue(&buffer, getPluginType());
    serializeBase(buffer);
    SerializeValue(&buffer, beta_);
  }

 public:
  explicit SwishPlugin(const float beta) : beta_(beta) {}

  // Used for TensorRT deserialization. It should not be called by users.
  SwishPlugin(void const *serialData, size_t serialLength) {
    deserializeBase(serialData, serialLength);
    DeserializeValue(&serialData, &serialLength, &beta_);
  }
  ~SwishPlugin() {}
  int initialize() override;

  SwishPlugin *clone() const override { return new SwishPlugin(beta_); }

  const char *getPluginType() const override { return "swish_plugin"; }
  int getNbOutputs() const override { return 1; }
  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
                                     int nbInputDims) override;
  int enqueue(int batchSize, const void *const *inputs, void **outputs,
              void *workspace, cudaStream_t stream) override;
};

} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle