Add TRT support for BERT (#21135)
* add gelu plugin
* align trt bert with gpu
* add support for fused fc with relu
* add unittest for bert trt
parent b0b27ff699
commit 0a51098a71
@@ -0,0 +1,61 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {

/*
 * Gelu converter from fluid to tensorRT.
 */
class GeluOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(4) << "convert fluid gelu op to tensorrt gelu layer";

    framework::OpDesc op_desc(op, nullptr);
    // Declare inputs
    int input_num = op_desc.Input("X").size();
    PADDLE_ENFORCE_EQ(input_num, 1,
                      platform::errors::InvalidArgument(
                          "gelu op has only 1 input, but got %d", input_num));
    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
    // Get output
    size_t output_num = op_desc.Output("Out").size();
    PADDLE_ENFORCE_EQ(output_num, 1,
                      platform::errors::InvalidArgument(
                          "gelu op has only 1 output, but got %d", output_num));
    // Get input shape and volume
    nvinfer1::Dims input_shape = input->getDimensions();
    size_t input_volume = 1;
    for (int i = 0; i < input_shape.nbDims; i++) {
      input_volume *= input_shape.d[i];
    }
    plugin::GeluPlugin* plugin = new plugin::GeluPlugin(input_volume);
    nvinfer1::IPluginLayer* layer =
        engine_->AddPlugin(&input, input_num, plugin);
    auto output_name = op_desc.Output("Out")[0];
    RreplenishLayerAndOutput(layer, "gelu", {output_name}, test_mode);
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

REGISTER_TRT_OP_CONVERTER(gelu, GeluOpConverter);
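One subtlety worth noting: under TensorRT's implicit-batch mode, `getDimensions()` returns per-sample dimensions without the batch axis, so `input_volume` above is an element count per sample; the plugin's `enqueue` (later in this patch) multiplies it by `batchSize`. A minimal standalone sketch of the same reduction, with illustrative `Dims`/`Volume` stand-ins that are not part of the patch:

```cpp
#include <cstddef>

// Illustrative stand-in for nvinfer1::Dims (not part of the patch).
struct Dims {
  int nbDims;  // number of per-sample dimensions (batch axis excluded)
  int d[8];    // extent of each dimension
};

// Product of all per-sample extents, mirroring the loop in the converter.
size_t Volume(const Dims& dims) {
  size_t v = 1;
  for (int i = 0; i < dims.nbDims; ++i) v *= dims.d[i];
  return v;
}

// Example: a {128, 768} hidden state (seq_len x hidden_size) yields
// 98304 elements per sample; enqueue later scales this by the batch size.
```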
@@ -0,0 +1,108 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/layer_norm_op.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {

class LayerNormOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(4) << "convert a fluid layer_norm op to tensorrt layer_norm plugin";
    framework::OpDesc op_desc(op, nullptr);
    PADDLE_ENFORCE_EQ(
        op_desc.Input("X").size(), 1,
        platform::errors::InvalidArgument(
            "input of layer_norm op converter should be 1, got %d",
            op_desc.Input("X").size()));
    PADDLE_ENFORCE_EQ(op_desc.Input("Bias").size(), 1,
                      platform::errors::InvalidArgument(
                          "Bias of layer_norm op converter should be 1, got %d",
                          op_desc.Input("Bias").size()));  // Bias is a weight
    PADDLE_ENFORCE_EQ(
        op_desc.Input("Scale").size(), 1,
        platform::errors::InvalidArgument(
            "Scale of layer_norm op converter should be 1, got %d",
            op_desc.Input("Scale").size()));  // Scale is a weight
    PADDLE_ENFORCE_EQ(
        op_desc.Output("Y").size(), 1,
        platform::errors::InvalidArgument(
            "output of layer_norm op converter should be 1, got %d",
            op_desc.Output("Y").size()));

    auto* X = engine_->GetITensor(op_desc.Input("X").front());
    auto* Bias_v = scope.FindVar(op_desc.Input("Bias").front());
    auto* Scale_v = scope.FindVar(op_desc.Input("Scale").front());
    const int begin_norm_axis =
        op_desc.HasAttr("begin_norm_axis")
            ? boost::get<int>(op_desc.GetAttr("begin_norm_axis"))
            : 1;
    const float eps = op_desc.HasAttr("epsilon")
                          ? boost::get<float>(op_desc.GetAttr("epsilon"))
                          : 1e-5f;
    PADDLE_ENFORCE_NOT_NULL(
        Bias_v, platform::errors::InvalidArgument(
                    "Input(Bias) of layer_norm should not be null."));
    PADDLE_ENFORCE_NOT_NULL(
        Scale_v, platform::errors::InvalidArgument(
                     "Input(Scale) of layer_norm should not be null."));

    auto* Bias_t = Bias_v->GetMutable<framework::LoDTensor>();
    auto* Scale_t = Scale_v->GetMutable<framework::LoDTensor>();

    int input_num = 1;
    for (int i = 0; i < X->getDimensions().nbDims; i++) {
      input_num *= X->getDimensions().d[i];
    }
    std::vector<int64_t> mean_shape{input_num};
    std::vector<int64_t> variance_shape{input_num};

    std::unique_ptr<framework::LoDTensor> bias_tensor(
        new framework::LoDTensor());
    std::unique_ptr<framework::LoDTensor> scale_tensor(
        new framework::LoDTensor());

    bias_tensor->Resize(Bias_t->dims());
    scale_tensor->Resize(Scale_t->dims());

    platform::CPUPlace cpu_place;
    TensorCopySync((*Bias_t), cpu_place, bias_tensor.get());
    TensorCopySync((*Scale_t), cpu_place, scale_tensor.get());

    auto* bias_data = bias_tensor->mutable_data<float>(platform::CPUPlace());
    auto* scale_data = scale_tensor->mutable_data<float>(platform::CPUPlace());

    plugin::LayerNormPlugin* plugin = new plugin::LayerNormPlugin(
        bias_data, bias_tensor->numel(), scale_data, scale_tensor->numel(),
        begin_norm_axis, eps, mean_shape, variance_shape);
    nvinfer1::IPluginLayer* layernorm_layer = engine_->AddPlugin(&X, 1, plugin);

    auto output_name = op_desc.Output("Y").front();
    engine_->SetWeights(op_desc.Input("Bias").front(), std::move(bias_tensor));
    engine_->SetWeights(op_desc.Input("Scale").front(),
                        std::move(scale_tensor));
    RreplenishLayerAndOutput(layernorm_layer, "layer_norm", {output_name},
                             test_mode);
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

REGISTER_TRT_OP_CONVERTER(layer_norm, LayerNormOpConverter);
@@ -0,0 +1,213 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

namespace paddle {
namespace inference {
namespace tensorrt {

class MultiheadMatMulOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(3) << "convert a fluid multihead_matmul op to a corresponding "
               "tensorrt network structure";
    framework::OpDesc op_desc(op, nullptr);
    // Declare inputs
    auto* Q = engine_->GetITensor(op_desc.Input("Q").front());
    auto* K = engine_->GetITensor(op_desc.Input("K").front());
    auto* V = engine_->GetITensor(op_desc.Input("V").front());
    auto* BiasQ = scope.FindVar(op_desc.Input("BiasQ").front());
    auto* BiasK = scope.FindVar(op_desc.Input("BiasK").front());
    auto* BiasV = scope.FindVar(op_desc.Input("BiasV").front());
    auto* BiasQK = engine_->GetITensor(op_desc.Input("BiasQK").front());
    PADDLE_ENFORCE_EQ(op_desc.Input("Q").size(), 1,
                      platform::errors::InvalidArgument(
                          "size of input Q of multihead_matmul should be 1"));
    PADDLE_ENFORCE_EQ(op_desc.Input("K").size(), 1,
                      platform::errors::InvalidArgument(
                          "size of input K of multihead_matmul should be 1"));
    PADDLE_ENFORCE_EQ(op_desc.Input("V").size(), 1,
                      platform::errors::InvalidArgument(
                          "size of input V of multihead_matmul should be 1"));
    PADDLE_ENFORCE_EQ(
        op_desc.Input("BiasQK").size(), 1,
        platform::errors::InvalidArgument(
            "size of input BiasQK of multihead_matmul should be 1"));
    PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1,
                      platform::errors::InvalidArgument(
                          "size of output of multihead_matmul should be 1"));
    PADDLE_ENFORCE_NOT_NULL(
        BiasQ, platform::errors::InvalidArgument(
                   "param BiasQ of multihead_matmul should not be null"));
    PADDLE_ENFORCE_NOT_NULL(
        BiasK, platform::errors::InvalidArgument(
                   "param BiasK of multihead_matmul should not be null"));
    PADDLE_ENFORCE_NOT_NULL(
        BiasV, platform::errors::InvalidArgument(
                   "param BiasV of multihead_matmul should not be null"));
    PADDLE_ENFORCE_EQ(
        BiasQK->getDimensions().nbDims, 3,
        platform::errors::InvalidArgument(
            "dims size of input BiasQK of multihead_matmul should be 3"));
    PADDLE_ENFORCE_EQ(
        op_desc.HasAttr("alpha"), true,
        platform::errors::PreconditionNotMet(
            "attribute alpha of multihead_matmul should not be empty"));
    PADDLE_ENFORCE_EQ(
        op_desc.HasAttr("head_number"), true,
        platform::errors::PreconditionNotMet(
            "attribute head_number of multihead_matmul should not be empty"));

    // Declare attributes
    const bool transpose_q =
        op_desc.HasAttr("transpose_Q")
            ? boost::get<bool>(op_desc.GetAttr("transpose_Q"))
            : false;
    const bool transpose_k =
        op_desc.HasAttr("transpose_K")
            ? boost::get<bool>(op_desc.GetAttr("transpose_K"))
            : true;
    const bool transpose_v =
        op_desc.HasAttr("transpose_V")
            ? boost::get<bool>(op_desc.GetAttr("transpose_V"))
            : false;
    const float alpha = boost::get<float>(op_desc.GetAttr("alpha"));
    const int head_number = boost::get<int>(op_desc.GetAttr("head_number"));

    nvinfer1::Dims q_shape = Q->getDimensions();
    int seq_len = q_shape.d[0];
    int size_per_head = q_shape.d[1] / head_number;
    std::string alpha_name = op_desc.Output("Out")[0] + "_alpha";
    framework::DDim alpha_dim = framework::make_ddim({1});
    std::unique_ptr<framework::LoDTensor> alpha_t(new framework::LoDTensor());
    alpha_t->Resize(alpha_dim);
    float* alpha_data = alpha_t->mutable_data<float>(platform::CPUPlace());
    alpha_data[0] = alpha;

    TensorRTEngine::Weight scale{nvinfer1::DataType::kFLOAT,
                                 static_cast<void*>(alpha_data), 1};
    TensorRTEngine::Weight shift{nvinfer1::DataType::kFLOAT, nullptr, 0};
    TensorRTEngine::Weight power{nvinfer1::DataType::kFLOAT, nullptr, 0};

    auto* bias_q_t = BiasQ->GetMutable<framework::LoDTensor>();
    auto* bias_k_t = BiasK->GetMutable<framework::LoDTensor>();
    auto* bias_v_t = BiasV->GetMutable<framework::LoDTensor>();
    float* bias_q_cpu_data = engine_->GetWeightCPUData(
        op_desc.Input("BiasQ").front(), bias_q_t, false);
    float* bias_k_cpu_data = engine_->GetWeightCPUData(
        op_desc.Input("BiasK").front(), bias_k_t, false);
    float* bias_v_cpu_data = engine_->GetWeightCPUData(
        op_desc.Input("BiasV").front(), bias_v_t, false);
    std::unique_ptr<framework::LoDTensor> bias_q_tensor(
        new framework::LoDTensor());
    std::unique_ptr<framework::LoDTensor> bias_k_tensor(
        new framework::LoDTensor());
    std::unique_ptr<framework::LoDTensor> bias_v_tensor(
        new framework::LoDTensor());
    bias_q_tensor->Resize(bias_q_t->dims());
    bias_k_tensor->Resize(bias_k_t->dims());
    bias_v_tensor->Resize(bias_v_t->dims());
    platform::CPUPlace cpu_place;
    TensorCopySync((*bias_q_t), cpu_place, bias_q_tensor.get());
    TensorCopySync((*bias_k_t), cpu_place, bias_k_tensor.get());
    TensorCopySync((*bias_v_t), cpu_place, bias_v_tensor.get());

    TensorRTEngine::Weight scale_weights_q{nvinfer1::DataType::kFLOAT, nullptr,
                                           0};
    TensorRTEngine::Weight shift_weights_q{
        nvinfer1::DataType::kFLOAT, static_cast<void*>(bias_q_cpu_data),
        bias_q_tensor->memory_size() / sizeof(float)};
    TensorRTEngine::Weight power_weights_q{nvinfer1::DataType::kFLOAT, nullptr,
                                           0};
    TensorRTEngine::Weight scale_weights_k{nvinfer1::DataType::kFLOAT, nullptr,
                                           0};
    TensorRTEngine::Weight shift_weights_k{
        nvinfer1::DataType::kFLOAT, static_cast<void*>(bias_k_cpu_data),
        bias_k_tensor->memory_size() / sizeof(float)};
    TensorRTEngine::Weight power_weights_k{nvinfer1::DataType::kFLOAT, nullptr,
                                           0};
    TensorRTEngine::Weight scale_weights_v{nvinfer1::DataType::kFLOAT, nullptr,
                                           0};
    TensorRTEngine::Weight shift_weights_v{
        nvinfer1::DataType::kFLOAT, static_cast<void*>(bias_v_cpu_data),
        bias_v_tensor->memory_size() / sizeof(float)};
    TensorRTEngine::Weight power_weights_v{nvinfer1::DataType::kFLOAT, nullptr,
                                           0};

    // Add the Q/K/V biases with channel-wise Scale layers.
    auto* q_eltadd_layer = TRT_ENGINE_ADD_LAYER(
        engine_, Scale, *Q, nvinfer1::ScaleMode::kCHANNEL,
        shift_weights_q.get(), scale_weights_q.get(), power_weights_q.get());
    auto* k_eltadd_layer = TRT_ENGINE_ADD_LAYER(
        engine_, Scale, *K, nvinfer1::ScaleMode::kCHANNEL,
        shift_weights_k.get(), scale_weights_k.get(), power_weights_k.get());
    auto* v_eltadd_layer = TRT_ENGINE_ADD_LAYER(
        engine_, Scale, *V, nvinfer1::ScaleMode::kCHANNEL,
        shift_weights_v.get(), scale_weights_v.get(), power_weights_v.get());
    auto* v_transpose_reshape_layer =
        TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(v_eltadd_layer->getOutput(0)));
    auto* q_transpose_reshape_layer =
        TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(q_eltadd_layer->getOutput(0)));
    auto* k_transpose_reshape_layer =
        TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(k_eltadd_layer->getOutput(0)));

    // Split the hidden dimension into (head_number, size_per_head) and move
    // the head axis to the front.
    nvinfer1::Dims3 head_reshape_dim(seq_len, head_number, size_per_head);
    v_transpose_reshape_layer->setReshapeDimensions(head_reshape_dim);
    v_transpose_reshape_layer->setSecondTranspose({1, 0, 2});
    q_transpose_reshape_layer->setReshapeDimensions(head_reshape_dim);
    q_transpose_reshape_layer->setSecondTranspose({1, 0, 2});
    k_transpose_reshape_layer->setReshapeDimensions(head_reshape_dim);
    k_transpose_reshape_layer->setSecondTranspose({1, 0, 2});

    auto* q_scale_layer = TRT_ENGINE_ADD_LAYER(
        engine_, Scale, *(q_transpose_reshape_layer->getOutput(0)),
        nvinfer1::ScaleMode::kUNIFORM, shift.get(), scale.get(), power.get());
    auto* qk_matmul_layer = TRT_ENGINE_ADD_LAYER(
        engine_, MatrixMultiply, *(q_scale_layer->getOutput(0)), transpose_q,
        *(k_transpose_reshape_layer->getOutput(0)), transpose_k);
    auto* qk_eltadd_layer = TRT_ENGINE_ADD_LAYER(
        engine_, ElementWise, *BiasQK, *(qk_matmul_layer->getOutput(0)),
        nvinfer1::ElementWiseOperation::kSUM);
    auto* softmax_layer = TRT_ENGINE_ADD_LAYER(
        engine_, SoftMax, *(qk_eltadd_layer->getOutput(0)));
    softmax_layer->setAxes(4);
    auto* qkv_matmul_layer = TRT_ENGINE_ADD_LAYER(
        engine_, MatrixMultiply, *(softmax_layer->getOutput(0)), false,
        *(v_transpose_reshape_layer->getOutput(0)), transpose_v);
    auto* qkv_transpose_reshape_layer = TRT_ENGINE_ADD_LAYER(
        engine_, Shuffle, *(qkv_matmul_layer->getOutput(0)));
    // Merge the heads back: (head, seq, size) -> (seq, head * size).
    nvinfer1::Dims2 qkv_reshape_dim(seq_len, head_number * size_per_head);
    qkv_transpose_reshape_layer->setFirstTranspose({1, 0, 2});
    qkv_transpose_reshape_layer->setReshapeDimensions(qkv_reshape_dim);

    engine_->SetWeights(alpha_name, std::move(alpha_t));
    engine_->SetWeights(op_desc.Input("BiasQ").front(),
                        std::move(bias_q_tensor));
    engine_->SetWeights(op_desc.Input("BiasK").front(),
                        std::move(bias_k_tensor));
    engine_->SetWeights(op_desc.Input("BiasV").front(),
                        std::move(bias_v_tensor));

    auto output_name = op_desc.Output("Out").front();
    RreplenishLayerAndOutput(qkv_transpose_reshape_layer, "multihead_matmul",
                             {output_name}, test_mode);
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

REGISTER_TRT_OP_CONVERTER(multihead_matmul, MultiheadMatMulOpConverter);
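Taken together, the Scale (bias add), Shuffle (head split/transpose), MatrixMultiply, ElementWise, and SoftMax layers above compute, for each of the `head_number` heads, standard scaled dot-product attention. With the default `transpose_K = true`, and writing alpha for the op's `alpha` attribute (for BERT typically 1/sqrt(size_per_head), a property of the exported model rather than something this converter checks):

```latex
\mathrm{head} = \mathrm{softmax}\!\bigl(\alpha\,(Q + b_Q)(K + b_K)^{\top} + \mathrm{BiasQK}\bigr)\,(V + b_V)
```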
@@ -1,5 +1,5 @@
  nv_library(tensorrt_plugin
    SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
-   prelu_op_plugin.cu trt_plugin_factory.cc
-   pool_op_plugin.cu swish_op_plugin.cu
+   prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
+   pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
    DEPS enforce tensorrt_engine prelu)
@@ -0,0 +1,76 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cassert>
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

// constant used in the erf-based normal cdf
constexpr float A = 1.41421356237309504;  // sqrt(2)

GeluPlugin* CreateGeluPluginDeserialize(const void* buffer, size_t length) {
  return new GeluPlugin(buffer, length);
}
// The registered name must match getPluginType(), which is what serialize()
// writes and what the factory looks up during deserialization.
REGISTER_TRT_PLUGIN("gelu_plugin", CreateGeluPluginDeserialize);

nvinfer1::Dims GeluPlugin::getOutputDimensions(int index,
                                               const nvinfer1::Dims* in_dims,
                                               int nb_inputs) {
  assert(nb_inputs == 1);
  assert(index < this->getNbOutputs());
  nvinfer1::Dims const& input_dims = in_dims[0];
  nvinfer1::Dims output_dims = input_dims;
  return output_dims;
}

template <typename T, unsigned TPB>
__global__ void geluKernel(const T a, int n, const T* input, T* output) {
  const int idx = blockIdx.x * TPB + threadIdx.x;
  if (idx < n) {
    const T in = input[idx];
    // Phi(in) = (1 + erf(in / sqrt(2))) / 2, since a = sqrt(2).
    const T cdf = 0.5 * (1.0 + erf(in * 0.5 * a));
    output[idx] = in * cdf;
  }
}

int computeGelu(cudaStream_t stream, int n, const float* input, float* output) {
  constexpr int blockSize = 256;
  const int gridSize = (n + blockSize - 1) / blockSize;
  geluKernel<float, blockSize><<<gridSize, blockSize, 0, stream>>>(A, n, input,
                                                                   output);
  cudaError_t error = cudaGetLastError();
  if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error);
  return 0;
}

int GeluPlugin::enqueue(int batchSize, const void* const* inputs,
                        void** outputs, void*, cudaStream_t stream) {
  int status = -1;
  const float* input = static_cast<const float*>(inputs[0]);
  float* output = static_cast<float*>(outputs[0]);
  status = computeGelu(stream, input_volume_ * batchSize, input, output);
  return status;
}

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
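For reference, the kernel above computes the exact, erf-based GELU rather than the tanh approximation; with `A = sqrt(2)`, the expression `erf(in * 0.5 * A)` equals erf(x / sqrt(2)):

```latex
\mathrm{GELU}(x) = x\,\Phi(x)
                 = \frac{x}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right)
```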
@@ -0,0 +1,72 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <stdio.h>
#include <cassert>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

class GeluPlugin : public PluginTensorRT {
 protected:
  size_t getSerializationSize() override {
    return getBaseSerializationSize() + SerializedSize(getPluginType()) +
           SerializedSize(input_volume_);
  }

  // TRT will call this func to serialize the configuration of the plugin.
  // It should not be called by users.
  void serialize(void *buffer) override {
    SerializeValue(&buffer, getPluginType());
    serializeBase(buffer);
    SerializeValue(&buffer, input_volume_);
  }

 public:
  explicit GeluPlugin(size_t input_volume) : input_volume_(input_volume) {}

  // It is used for tensorrt deserialization.
  // It should not be called by users.
  GeluPlugin(void const *serialData, size_t serialLength) {
    deserializeBase(serialData, serialLength);
    DeserializeValue(&serialData, &serialLength, &input_volume_);
  }

  ~GeluPlugin() {}

  int initialize() override { return 0; }

  GeluPlugin *clone() const override { return new GeluPlugin(input_volume_); }

  const char *getPluginType() const override { return "gelu_plugin"; }
  int getNbOutputs() const override { return 1; }
  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
                                     int nbInputDims) override;
  int enqueue(int batchSize, const void *const *inputs, void **outputs,
              void *workspace, cudaStream_t stream) override;

 private:
  size_t input_volume_;
};

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
@@ -0,0 +1,84 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>
#include <cassert>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/operators/layer_norm_op.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

LayerNormPlugin *CreateLayerNormPluginDeserialize(const void *buffer,
                                                  size_t length) {
  return new LayerNormPlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("layer_norm_plugin", CreateLayerNormPluginDeserialize);

int LayerNormPlugin::initialize() { return 0; }

nvinfer1::Dims LayerNormPlugin::getOutputDimensions(
    int index, const nvinfer1::Dims *inputDims, int nbInputs) {
  assert(nbInputs == 1);
  assert(index < this->getNbOutputs());
  nvinfer1::Dims const &input_dims = inputDims[0];
  nvinfer1::Dims output_dims = input_dims;
  return output_dims;
}

int LayerNormPlugin::enqueue(int batch_size, const void *const *inputs,
                             void **outputs, void *workspace,
                             cudaStream_t stream) {
  const auto &input_dims = this->getInputDims(0);
  const float *input = reinterpret_cast<const float *>(inputs[0]);
  float *output = reinterpret_cast<float *>(outputs[0]);
  int begin_norm_axis = begin_norm_axis_;
  float eps = eps_;
  // Channel count of the normalized dimension; TRT dims here exclude the
  // batch axis, hence begin_norm_axis - 1.
  int c = input_dims.d[begin_norm_axis - 1];

  scale_t.Resize(framework::make_ddim({c}));
  bias_t.Resize(framework::make_ddim({c}));
  mean_t.Resize(framework::make_ddim(mean_shape_));
  variance_t.Resize(framework::make_ddim(variance_shape_));
  int device_id;
  cudaGetDevice(&device_id);
  float *scale_d = scale_t.mutable_data<float>(platform::CUDAPlace(device_id));
  float *bias_d = bias_t.mutable_data<float>(platform::CUDAPlace(device_id));
  float *mean_d = mean_t.mutable_data<float>(platform::CUDAPlace(device_id));
  float *variance_d =
      variance_t.mutable_data<float>(platform::CUDAPlace(device_id));
  // Upload the per-channel scale and bias on the inference stream.
  cudaMemcpyAsync(scale_d, scale_.data(), sizeof(float) * c,
                  cudaMemcpyHostToDevice, stream);
  cudaMemcpyAsync(bias_d, bias_.data(), sizeof(float) * c,
                  cudaMemcpyHostToDevice, stream);
  std::vector<int> input_shape;
  input_shape.push_back(batch_size);
  for (int i = 0; i < input_dims.nbDims; i++) {
    input_shape.push_back(input_dims.d[i]);
  }
  paddle::operators::LayerNormDirectCUDAFunctor<float> layer_norm;
  layer_norm(stream, input, input_shape, bias_d, scale_d, output, mean_d,
             variance_d, begin_norm_axis, eps);
  return cudaGetLastError() != cudaSuccess;
}

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
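For reference, `LayerNormDirectCUDAFunctor` applies the standard layer-normalization transform over the dimensions from `begin_norm_axis` onward, with the gamma/beta arrays uploaded above as `scale_d`/`bias_d`:

```latex
% mu and sigma^2 are the mean and variance over the normalized dimensions.
y = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^{2} + \epsilon}} + \beta
```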
@@ -0,0 +1,110 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

class LayerNormPlugin : public PluginTensorRT {
  std::vector<float> bias_;
  std::vector<float> scale_;
  framework::Tensor scale_t;
  framework::Tensor bias_t;
  framework::Tensor mean_t;
  framework::Tensor variance_t;
  int begin_norm_axis_;
  float eps_;
  std::vector<int64_t> mean_shape_;
  std::vector<int64_t> variance_shape_;

 protected:
  size_t getSerializationSize() override {
    return getBaseSerializationSize() + SerializedSize(bias_) +
           SerializedSize(scale_) + SerializedSize(begin_norm_axis_) +
           SerializedSize(eps_) + SerializedSize(mean_shape_) +
           SerializedSize(variance_shape_) + SerializedSize(getPluginType());
  }

  // TRT will call this func when it needs to serialize the configuration of
  // the plugin.
  // It should not be called by users.
  void serialize(void *buffer) override {
    SerializeValue(&buffer, getPluginType());
    serializeBase(buffer);
    SerializeValue(&buffer, bias_);
    SerializeValue(&buffer, scale_);
    SerializeValue(&buffer, begin_norm_axis_);
    SerializeValue(&buffer, eps_);
    SerializeValue(&buffer, mean_shape_);
    SerializeValue(&buffer, variance_shape_);
  }

 public:
  LayerNormPlugin(const float *bias, const int bias_num, const float *scale,
                  const int scale_num, int begin_norm_axis, float eps,
                  std::vector<int64_t> mean_shape,
                  std::vector<int64_t> variance_shape)
      : begin_norm_axis_(begin_norm_axis),
        eps_(eps),
        mean_shape_(mean_shape),
        variance_shape_(variance_shape) {
    bias_.resize(bias_num);
    scale_.resize(scale_num);
    std::copy(bias, bias + bias_num, bias_.data());
    std::copy(scale, scale + scale_num, scale_.data());
  }

  // It is used for tensorrt deserialization.
  // It should not be called by users.
  LayerNormPlugin(void const *serialData, size_t serialLength) {
    deserializeBase(serialData, serialLength);
    DeserializeValue(&serialData, &serialLength, &bias_);
    DeserializeValue(&serialData, &serialLength, &scale_);
    DeserializeValue(&serialData, &serialLength, &begin_norm_axis_);
    DeserializeValue(&serialData, &serialLength, &eps_);
    DeserializeValue(&serialData, &serialLength, &mean_shape_);
    DeserializeValue(&serialData, &serialLength, &variance_shape_);
  }
  ~LayerNormPlugin() {}
  int initialize() override;

  LayerNormPlugin *clone() const override {
    return new LayerNormPlugin(bias_.data(), bias_.size(), scale_.data(),
                               scale_.size(), begin_norm_axis_, eps_,
                               mean_shape_, variance_shape_);
  }

  const char *getPluginType() const override { return "layer_norm_plugin"; }
  int getNbOutputs() const override { return 1; }
  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
                                     int nbInputDims) override;
  int enqueue(int batchSize, const void *const *inputs, void **outputs,
              void *workspace, cudaStream_t stream) override;
};

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
@@ -0,0 +1,85 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>

#include "paddle/fluid/inference/tests/api/trt_test_helper.h"

namespace paddle {
namespace inference {

TEST(TensorRT, bert_converter) {
  AnalysisConfig config;
  int batch_size = 1;
  config.SetModel(FLAGS_infer_model);
  config.EnableUseGpu(1200, 0);
  config.SwitchUseFeedFetchOps(false);
  config.EnableTensorRtEngine(1 << 30, batch_size, 10,
                              AnalysisConfig::Precision::kFloat32, false,
                              false);
  auto predictor = CreatePaddlePredictor<AnalysisConfig>(config);
  int64_t i0[128] = {
      96,  54,  78,  37,  106, 35,  122, 33,  95,  63,  81,  60,  65,  68,  45,  96,
      117, 61,  43,  15,  12,  64,  91,  100, 90,  74,  99,  23,  22,  91,  83,  13,
      28,  71,  59,  15,  40,  26,  66,  18,  31,  87,  85,  11,  55,  67,  28,  126,
      7,   89,  39,  67,  88,  29,  66,  38,  98,  1,   66,  38,  95,  56,  48,  95,
      9,   38,  90,  82,  101, 6,   75,  46,  42,  89,  98,  12,  6,   101, 82,  55,
      81,  113, 33,  91,  44,  73,  41,  39,  12,  113, 13,  86,  36,  91,  53,  68,
      103, 67,  65,  92,  27,  76,  24,  107, 54,  94,  63,  10,  15,  32,  91,  45,
      37,  126, 49,  118, 73,  127, 122, 119, 28,  96,  92,  79,  21,  90,  11,  40};
  int64_t i1[128] = {
      0,   1,   2,   3,   4,   5,   6,   7,   8,   9,   10,  11,  12,  13,  14,
      15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,  28,  29,
      30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,
      45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,
      60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,  72,  73,  74,
      75,  76,  77,  78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,
      90,  91,  92,  93,  94,  95,  96,  97,  98,  99,  100, 101, 102, 103, 104,
      105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119,
      120, 121, 122, 123, 124, 125, 126, 127};
  int64_t i2[128] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                     0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                     1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                     1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                     1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
  float i3[128 * 128] = {0.0};
  int64_t i4[1] = {0};

  auto input_names = predictor->GetInputNames();

  auto input_t0 = predictor->GetInputTensor(input_names[0]);
  input_t0->Reshape({batch_size, 128, 1});
  input_t0->copy_from_cpu(i0);
  auto input_t1 = predictor->GetInputTensor(input_names[1]);
  input_t1->Reshape({batch_size, 128, 1});
  input_t1->copy_from_cpu(i1);
  auto input_t2 = predictor->GetInputTensor(input_names[2]);
  input_t2->Reshape({batch_size, 128, 1});
  input_t2->copy_from_cpu(i2);
  auto input_t3 = predictor->GetInputTensor(input_names[3]);
  input_t3->Reshape({batch_size, 128, 128});
  input_t3->copy_from_cpu(i3);
  auto input_t4 = predictor->GetInputTensor(input_names[4]);
  input_t4->Reshape({batch_size, 1});
  input_t4->copy_from_cpu(i4);

  ASSERT_TRUE(predictor->ZeroCopyRun());
}

}  // namespace inference
}  // namespace paddle