[Paddle-TRT]: Ernie Dynamic shape support. (#23138)
* add dynamic plugin support. test=develop
* change emb eltwise layernorm to math function test=develop
* add emb eltwise layernorm test=develop
* can run dynamic shape ernie test=develop
* fix ci test=develop
* add ut for trt ernie dynamic test=develop
* refine dynamic shape c++ interface. test=develop
* fix comments test=develop
* fix comments test=develop

Branch: revert-23830-2.0-beta
parent d0413e58d3
commit 430b0099c9
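Both converters below reject static shape mode and direct users to config.SetTRTDynamicShapeInfo(...). The following is a minimal sketch of how an application would enable TRT dynamic shape mode for an Ernie-style model; the tensor names and shape ranges are illustrative assumptions, not part of this diff, and the AnalysisConfig overloads shown are the ones this Paddle generation is believed to expose.

// Hedged usage sketch: enabling TRT dynamic shape for an Ernie-style model.
// Feed names ("eval_placeholder_*") and the min/max/opt ranges are
// hypothetical placeholders; use your model's actual input names.
#include <map>
#include <string>
#include <vector>
#include "paddle/include/paddle_inference_api.h"

void ConfigureErnieTrtDynamicShape(paddle::AnalysisConfig* config) {
  config->EnableUseGpu(100 /* initial GPU memory (MB) */, 0 /* device id */);
  config->EnableTensorRtEngine(1 << 30 /* workspace */, 1 /* max batch */,
                               5 /* min subgraph size */,
                               paddle::AnalysisConfig::Precision::kFloat32,
                               false /* use_static */,
                               false /* use_calib_mode */);
  // Each id input is [batch, seq_len]; TRT needs its min/max/optimal range.
  std::map<std::string, std::vector<int>> min_shape, max_shape, opt_shape;
  for (std::string name : {"eval_placeholder_0", "eval_placeholder_1",
                           "eval_placeholder_2"}) {
    min_shape[name] = {1, 1};
    max_shape[name] = {16, 128};
    opt_shape[name] = {8, 64};
  }
  config->SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape);
}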
@@ -0,0 +1,115 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {

class EmbEltwiseLayerNormOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
#if IS_TRT_VERSION_GE(6000)
    VLOG(4) << "convert fused_embedding_eltwise_layernorm op to tensorrt layer";

    framework::OpDesc op_desc(op, nullptr);
    auto id_names = op_desc.Input("Ids");
    auto emb_names = op_desc.Input("Embs");

    PADDLE_ENFORCE_EQ(id_names.size(), emb_names.size(),
                      platform::errors::InvalidArgument(
                          "The Ids and Embs input sizes of the fused "
                          "EmbEltwiseLayerNormOp should be the same."));
    int input_num = id_names.size();

    // Declare inputs
    std::vector<nvinfer1::ITensor*> input_ids;
    for (int i = 0; i < input_num; i++) {
      input_ids.push_back(engine_->GetITensor(id_names[i]));
    }

    std::vector<float*> input_embs;
    std::vector<int> emb_sizes;

    // Get the persistable var's data.
    auto get_persistable_data = [&](const std::string& var_name,
                                    framework::DDim* dims) -> float* {
      auto* temp_var = scope.FindVar(var_name);
      auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
      (*dims) = temp_tensor->dims();

      auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false);
      return temp_data;
    };

    int hidden = 0;
    for (int i = 0; i < input_num; i++) {
      framework::DDim emb_dims;
      float* emb_data = get_persistable_data(emb_names[i], &emb_dims);
      int64_t emb_size = framework::product(emb_dims);
      input_embs.push_back(emb_data);
      emb_sizes.push_back(emb_size);
      PADDLE_ENFORCE_EQ(
          emb_dims.size(), 2,
          platform::errors::InvalidArgument(
              "The fused EmbEltwiseLayerNorm's emb should be 2 dims."));
      hidden = emb_dims[1];
    }

    framework::DDim bias_dims, scale_dims;

    auto* bias =
        get_persistable_data(op_desc.Input("Bias").front(), &bias_dims);
    auto* scale =
        get_persistable_data(op_desc.Input("Scale").front(), &scale_dims);
    int64_t bias_size = framework::product(bias_dims);
    int64_t scale_size = framework::product(scale_dims);
    float eps = boost::get<float>(op_desc.GetAttr("epsilon"));
    nvinfer1::ILayer* layer = nullptr;

    if (engine_->with_dynamic_shape()) {
      plugin::EmbEltwiseLayernormPluginDynamic* plugin =
          new plugin::EmbEltwiseLayernormPluginDynamic(input_embs, bias, scale,
                                                       emb_sizes, bias_size,
                                                       scale_size, hidden, eps);
      layer = engine_->AddPluginV2(input_ids.data(), input_num, plugin);
    } else {
      PADDLE_THROW(platform::errors::Fatal(
          "You are running the Ernie(Bert) model in static "
          "shape mode, which is not supported for the time being.\n"
          "You can use the config.SetTRTDynamicShapeInfo(...) interface"
          " to set the shape information to run the dynamic shape mode."));
    }

    auto output_name = op_desc.Output("Out")[0];
    RreplenishLayerAndOutput(layer, "emb_eltwise_layernorm", {output_name},
                             test_mode);
#else
    PADDLE_THROW(platform::errors::Fatal(
        "You are running the TRT Dynamic Shape mode; you need to confirm "
        "that your TRT version is no less than 6.0."));
#endif
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

REGISTER_TRT_OP_CONVERTER(fused_embedding_eltwise_layernorm,
                          EmbEltwiseLayerNormOpConverter);
(File diff suppressed because it is too large.)
@@ -0,0 +1,83 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {

class SkipLayerNormOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
#if IS_TRT_VERSION_GE(6000)
    VLOG(4) << "convert fused skip_layernorm op to tensorrt layer";
    framework::OpDesc op_desc(op, nullptr);
    // Declare inputs
    auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
    auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]);
    std::vector<nvinfer1::ITensor*> inputs;
    inputs.push_back(input1);
    inputs.push_back(input2);

    auto get_persistable_data = [&](const std::string& arg_name,
                                    framework::DDim* dims) -> float* {
      std::string var_name = op_desc.Input(arg_name).front();
      auto* temp_var = scope.FindVar(var_name);
      auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
      (*dims) = temp_tensor->dims();

      auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false);
      return temp_data;
    };

    framework::DDim bias_dims, scale_dims;
    auto* bias = get_persistable_data("Bias", &bias_dims);
    auto* scale = get_persistable_data("Scale", &scale_dims);
    float eps = boost::get<float>(op_desc.GetAttr("epsilon"));
    int bias_size = framework::product(bias_dims);
    int scale_size = framework::product(scale_dims);

    nvinfer1::ILayer* layer = nullptr;
    if (engine_->with_dynamic_shape()) {
      bool ban_fp16 = engine_->disable_trt_plugin_fp16();
      plugin::SkipLayerNormPluginDynamic* plugin =
          new plugin::SkipLayerNormPluginDynamic(bias, scale, bias_size,
                                                 scale_size, eps, ban_fp16);
      layer = engine_->AddPluginV2(inputs.data(), 2, plugin);
    } else {
      PADDLE_THROW(platform::errors::Fatal(
          "You are running the Ernie(Bert) model in static "
          "shape mode, which is not supported for the time being.\n"
          "You can use the config.SetTRTDynamicShapeInfo(...) interface"
          " to set the shape information to run the dynamic shape mode."));
    }

    auto output_name = op_desc.Output("Out")[0];
    RreplenishLayerAndOutput(layer, "skip_layernorm", {output_name}, test_mode);
#else
    PADDLE_THROW(platform::errors::Fatal(
        "You are running the TRT Dynamic Shape mode; you need to confirm "
        "that your TRT version is no less than 6.0."));
#endif
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

REGISTER_TRT_OP_CONVERTER(skip_layernorm, SkipLayerNormOpConverter);
@@ -1,5 +1,7 @@
 nv_library(tensorrt_plugin
   SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
   prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
-  pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu instance_norm_op_plugin.cu
-  DEPS enforce tensorrt_engine prelu tensor)
+  pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
+  instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
+  qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu
+  DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
@@ -0,0 +1,180 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>
#include <cassert>
#include <cub/cub.cuh>  // NOLINT
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)

int EmbEltwiseLayernormPluginDynamic::initialize() {
  // resize (not reserve): indexing embs_gpu_[i] below requires the elements
  // to actually exist.
  embs_gpu_.resize(embs_.size());
  for (size_t i = 0; i < embs_.size(); i++) {
    cudaMalloc(&embs_gpu_[i], sizeof(float) * emb_sizes_[i]);
    cudaMemcpy(embs_gpu_[i], embs_[i], emb_sizes_[i] * sizeof(float),
               cudaMemcpyHostToDevice);
  }

  cudaMalloc(&bias_gpu_, sizeof(float) * bias_size_);
  cudaMemcpy(bias_gpu_, bias_, bias_size_ * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMalloc(&scale_gpu_, sizeof(float) * scale_size_);
  cudaMemcpy(scale_gpu_, scale_, scale_size_ * sizeof(float),
             cudaMemcpyHostToDevice);

  return 0;
}

size_t EmbEltwiseLayernormPluginDynamic::getSerializationSize() const {
  return 0;
}

void EmbEltwiseLayernormPluginDynamic::serialize(void *buffer) const {}

nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
    nvinfer1::IExprBuilder &expr_builder) {
  PADDLE_ENFORCE_EQ(output_index, 0,
                    platform::errors::InvalidArgument(
                        "There is only one output of the EmbEltwiseLayernorm, "
                        "so the index should be zero, "
                        "but it's (%d)",
                        output_index));
  PADDLE_ENFORCE_EQ(
      nb_inputs, 3,
      platform::errors::InvalidArgument(
          "The input number of the EmbEltwiseLayernorm should be 3, but we "
          "found it has (%d) inputs",
          nb_inputs));
  // The output shape is [batch, seq_len, hidden_size, 1, 1].
  nvinfer1::DimsExprs ret;
  ret.nbDims = 5;
  ret.d[0] = inputs[0].d[0];
  ret.d[1] = inputs[0].d[1];
  ret.d[2] = expr_builder.constant(hidden_size_);
  ret.d[3] = expr_builder.constant(1);
  ret.d[4] = expr_builder.constant(1);
  return ret;
}

bool EmbEltwiseLayernormPluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
    int nb_outputs) {
  PADDLE_ENFORCE_NOT_NULL(
      in_out, platform::errors::InvalidArgument(
                  "The input of the EmbEltwiseLayernorm plugin should not be "
                  "nullptr."));

  PADDLE_ENFORCE_LT(
      pos, nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos, nb_inputs + nb_outputs));

  const nvinfer1::PluginTensorDesc &desc = in_out[pos];
  if (desc.format != nvinfer1::TensorFormat::kLINEAR) {
    return false;
  }

  // The three id inputs are int32; the later id inputs must match the first
  // input's [batch, seq_len] dims.
  if (pos == 0) {
    return desc.type == nvinfer1::DataType::kINT32;
  }

  const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
  if (pos == 1 || pos == 2) {
    return desc.type == nvinfer1::DataType::kINT32 &&
           desc.dims.d[0] == prev.dims.d[0] && desc.dims.d[1] == prev.dims.d[1];
  }

  // pos == 3 is the fp32 output.
  if (pos == 3) {
    return desc.type == nvinfer1::DataType::kFLOAT;
  }
  return false;
}

nvinfer1::DataType EmbEltwiseLayernormPluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType *input_types, int nb_inputs) const {
  PADDLE_ENFORCE_EQ(
      index, 0, platform::errors::InvalidArgument(
                    "The EmbEltwiseLayernorm Plugin only has one output, so "
                    "the index value should be 0, but get %d.",
                    index));
  return nvinfer1::DataType::kFLOAT;
}

int EmbEltwiseLayernormPluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc *input_desc,
    const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
    void *const *outputs, void *workspace, cudaStream_t stream) {
  auto id_dims = input_desc[0].dims;
  int batch = id_dims.d[0];
  int seq_len = id_dims.d[1];
  int input_num = embs_.size();

  framework::Tensor in_ptr_tensor, emb_ptr_tensor;
  int device_id;
  cudaGetDevice(&device_id);

  // Pack the device pointers of the id inputs and the embedding tables into
  // two GPU arrays so the kernel can index them per input.
  in_ptr_tensor.Resize({input_num});
  emb_ptr_tensor.Resize({input_num});
  int64_t *in_ptr_gpu_d =
      in_ptr_tensor.mutable_data<int64_t>(platform::CUDAPlace(device_id));
  int64_t *emb_ptr_gpu_d =
      emb_ptr_tensor.mutable_data<int64_t>(platform::CUDAPlace(device_id));

  std::vector<int64_t> in_ptr, emb_ptr;
  for (int i = 0; i < input_num; i++) {
    in_ptr.push_back(reinterpret_cast<uintptr_t>(inputs[i]));
    emb_ptr.push_back(reinterpret_cast<uintptr_t>(embs_gpu_[i]));
  }

  cudaMemcpyAsync(in_ptr_gpu_d, in_ptr.data(), sizeof(int64_t) * input_num,
                  cudaMemcpyHostToDevice, stream);
  cudaMemcpyAsync(emb_ptr_gpu_d, emb_ptr.data(), sizeof(int64_t) * input_num,
                  cudaMemcpyHostToDevice, stream);

  auto out_type = output_desc[0].type;
  PADDLE_ENFORCE_EQ(
      out_type == nvinfer1::DataType::kFLOAT, true,
      platform::errors::InvalidArgument(
          "The EmbEltwiseLayernorm Plugin only supports fp32 input."));

  float *output_d = static_cast<float *>(outputs[0]);
  operators::math::EmbEltwiseLayerNormFunctor<float> emb_eltwise_layernorm_func;
  emb_eltwise_layernorm_func(batch, seq_len, hidden_size_, in_ptr_gpu_d,
                             scale_gpu_, bias_gpu_, emb_ptr_gpu_d, output_d,
                             eps_, input_num, stream);
  return cudaGetLastError() != cudaSuccess;
}
#endif

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
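For orientation, the fused kernel's semantics: each of the input_num int32 id tensors indexes its own embedding table, the looked-up rows are summed elementwise, and layer normalization with scale/bias is applied over the hidden dimension. Below is a plain CPU reference of that computation, under the assumption that this is what EmbEltwiseLayerNormFunctor implements on the GPU; it is a sketch, not the actual Paddle kernel.

// Hedged CPU reference for the fused embedding + eltwise-add + layernorm.
#include <cmath>
#include <cstdint>
#include <vector>

void EmbEltwiseLayerNormRef(
    int batch, int seq_len, int hidden, int input_num,
    const std::vector<const int32_t*>& ids,  // each [batch * seq_len]
    const std::vector<const float*>& embs,   // each [vocab_i, hidden]
    const float* scale, const float* bias,   // each [hidden]
    float eps, float* out) {                 // [batch * seq_len, hidden]
  for (int t = 0; t < batch * seq_len; ++t) {
    float* row = out + t * hidden;
    // Sum the embedding rows selected by each id input.
    for (int h = 0; h < hidden; ++h) row[h] = 0.f;
    for (int k = 0; k < input_num; ++k) {
      const float* table_row = embs[k] + ids[k][t] * hidden;
      for (int h = 0; h < hidden; ++h) row[h] += table_row[h];
    }
    // LayerNorm over the hidden dimension, then affine scale/bias.
    float mean = 0.f, var = 0.f;
    for (int h = 0; h < hidden; ++h) mean += row[h];
    mean /= hidden;
    for (int h = 0; h < hidden; ++h) var += (row[h] - mean) * (row[h] - mean);
    var /= hidden;
    float rstd = 1.f / std::sqrt(var + eps);
    for (int h = 0; h < hidden; ++h)
      row[h] = (row[h] - mean) * rstd * scale[h] + bias[h];
  }
}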
@@ -0,0 +1,113 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <string>
#include <vector>

#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

#if IS_TRT_VERSION_GE(6000)
class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
 public:
  explicit EmbEltwiseLayernormPluginDynamic(std::vector<float*> input_embs,
                                            float* bias, float* scale,
                                            std::vector<int> emb_sizes,
                                            int bias_size, int scale_size,
                                            int hidden_size, float eps)
      : embs_(input_embs),
        bias_(bias),
        scale_(scale),
        emb_sizes_(emb_sizes),
        bias_size_(bias_size),
        scale_size_(scale_size),
        hidden_size_(hidden_size),
        eps_(eps) {}

  EmbEltwiseLayernormPluginDynamic(void const* serialData,
                                   size_t serialLength) {}
  nvinfer1::IPluginV2DynamicExt* clone() const override {
    return new EmbEltwiseLayernormPluginDynamic(
        embs_, bias_, scale_, emb_sizes_, bias_size_, scale_size_, hidden_size_,
        eps_);
  }

  const char* getPluginType() const override {
    return "fused_embedding_eltwise_layernorm_plugin";
  }
  int getNbOutputs() const override { return 1; }
  int initialize() override;

  size_t getSerializationSize() const override;
  void serialize(void* buffer) const override;

  nvinfer1::DimsExprs getOutputDimensions(
      int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
      nvinfer1::IExprBuilder& expr_builder) override;

  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* inOut,
                                 int nbInputs, int nbOutputs) override;

  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc* out,
                       int nbOutputs) override {}

  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc* outputs,
                          int nbOutputs) const override {
    return 0;
  }

  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc,
              const void* const* inputs, void* const* outputs, void* workspace,
              cudaStream_t stream) override;
  nvinfer1::DataType getOutputDataType(int index,
                                       const nvinfer1::DataType* inputTypes,
                                       int nbInputs) const override;

  void destroy() override { delete this; }

 private:
  std::vector<float*> embs_;
  float* bias_;
  float* scale_;

  // data on devices
  float* bias_gpu_;
  float* scale_gpu_;
  std::vector<float*> embs_gpu_;

  std::vector<int> emb_sizes_;
  int bias_size_;
  int scale_size_;
  int hidden_size_;
  float eps_;
};
#endif
}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
(File diff suppressed because it is too large.)
@@ -0,0 +1,95 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <string>
#include <vector>

#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

#if IS_TRT_VERSION_GE(6000)
class QkvToContextPluginDynamic : public DynamicPluginTensorRT {
 public:
  explicit QkvToContextPluginDynamic(int hidden, int head_number, int head_size,
                                     float scale, bool ban_fp16)
      : hidden_(hidden),
        head_number_(head_number),
        head_size_(head_size),
        scale_(scale),
        ban_fp16_(ban_fp16) {}

  QkvToContextPluginDynamic(void const* serialData, size_t serialLength) {}
  nvinfer1::IPluginV2DynamicExt* clone() const override {
    return new QkvToContextPluginDynamic(hidden_, head_number_, head_size_,
                                         scale_, ban_fp16_);
  }

  const char* getPluginType() const override { return "qkv_to_context_plugin"; }
  int getNbOutputs() const override { return 1; }
  int initialize() override;

  size_t getSerializationSize() const override;
  void serialize(void* buffer) const override;

  nvinfer1::DimsExprs getOutputDimensions(
      int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
      nvinfer1::IExprBuilder& expr_builder) override;

  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* inOut,
                                 int nbInputs, int nbOutputs) override;

  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc* out,
                       int nbOutputs) override {}

  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc* outputs,
                          int nbOutputs) const override {
    return 0;
  }

  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc,
              const void* const* inputs, void* const* outputs, void* workspace,
              cudaStream_t stream) override;
  nvinfer1::DataType getOutputDataType(int index,
                                       const nvinfer1::DataType* inputTypes,
                                       int nbInputs) const override;

  void destroy() override { delete this; }

 private:
  int hidden_;
  int head_number_;
  int head_size_;
  float scale_;
  bool ban_fp16_;
};
#endif

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
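Only this plugin's interface is visible in this view; its kernel implementation is in one of the suppressed diffs. From the parameters (hidden, head_number, head_size, scale), the plugin presumably turns a fused QKV tensor into the attention context, i.e. softmax(Q K^T * scale) V per head. The following is a hedged single-head CPU reference of that core computation; the real kernel's batched data layout and fusion are not shown in this diff.

// Hedged reference: scaled dot-product attention for one head, assuming the
// plugin computes context = softmax(Q K^T * scale) V. `scale` is typically
// 1 / sqrt(head_size).
#include <algorithm>
#include <cmath>
#include <vector>

void AttentionHeadRef(int seq_len, int head_size, float scale,
                      const float* Q, const float* K,  // [seq_len, head_size]
                      const float* V,                  // [seq_len, head_size]
                      float* context) {                // [seq_len, head_size]
  std::vector<float> probs(seq_len);
  for (int i = 0; i < seq_len; ++i) {
    // Scaled scores of query i against every key.
    float max_s = -1e30f;
    for (int j = 0; j < seq_len; ++j) {
      float s = 0.f;
      for (int d = 0; d < head_size; ++d)
        s += Q[i * head_size + d] * K[j * head_size + d];
      probs[j] = s * scale;
      max_s = std::max(max_s, probs[j]);
    }
    // Numerically stable softmax over the scores.
    float sum = 0.f;
    for (int j = 0; j < seq_len; ++j)
      sum += (probs[j] = std::exp(probs[j] - max_s));
    for (int j = 0; j < seq_len; ++j) probs[j] /= sum;
    // Probability-weighted sum of the values.
    for (int d = 0; d < head_size; ++d) {
      float acc = 0.f;
      for (int j = 0; j < seq_len; ++j)
        acc += probs[j] * V[j * head_size + d];
      context[i * head_size + d] = acc;
    }
  }
}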
@@ -0,0 +1,150 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cuda_runtime.h>
#include <stdio.h>
#include <cassert>
#include <cub/cub.cuh>  // NOLINT
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)

int SkipLayerNormPluginDynamic::initialize() {
  cudaMalloc(&bias_gpu_, sizeof(float) * bias_size_);
  cudaMemcpy(bias_gpu_, bias_, bias_size_ * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMalloc(&scale_gpu_, sizeof(float) * scale_size_);
  cudaMemcpy(scale_gpu_, scale_, scale_size_ * sizeof(float),
             cudaMemcpyHostToDevice);
  return 0;
}

size_t SkipLayerNormPluginDynamic::getSerializationSize() const { return 0; }

void SkipLayerNormPluginDynamic::serialize(void *buffer) const {}

nvinfer1::DimsExprs SkipLayerNormPluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
    nvinfer1::IExprBuilder &expr_builder) {
  PADDLE_ENFORCE_EQ(
      inputs[0].nbDims, 5,
      platform::errors::InvalidArgument(
          "The input dim of the SkipLayernorm should be 5, but it's (%d) now.",
          inputs[0].nbDims));
  // The output shape is identical to the first input's.
  return inputs[0];
}

bool SkipLayerNormPluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
    int nb_outputs) {
  PADDLE_ENFORCE_NOT_NULL(
      in_out, platform::errors::InvalidArgument(
                  "The input of the SkipLayerNorm plugin should not be "
                  "nullptr."));

  PADDLE_ENFORCE_LT(
      pos, nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos, nb_inputs + nb_outputs));

  const nvinfer1::PluginTensorDesc &in = in_out[pos];
  if (pos == 0) {
#ifdef SUPPORTS_CUDA_FP16
    if (ban_fp16_) {
      return (in.type == nvinfer1::DataType::kFLOAT) &&
             (in.format == nvinfer1::TensorFormat::kLINEAR);
    } else {
      return (in.type == nvinfer1::DataType::kFLOAT ||
              in.type == nvinfer1::DataType::kHALF) &&
             (in.format == nvinfer1::TensorFormat::kLINEAR);
    }
#else
    return (in.type == nvinfer1::DataType::kFLOAT) &&
           (in.format == nvinfer1::TensorFormat::kLINEAR);
#endif
  }
  const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];

  // The second input (pos == 1) and the output must match the preceding
  // tensor's type and format.
  if (pos == 1) {
    return in.type == prev.type && in.format == prev.format;
  }

  // output
  return in.type == prev.type && in.format == prev.format;
}

nvinfer1::DataType SkipLayerNormPluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType *input_types, int nb_inputs) const {
  PADDLE_ENFORCE_EQ(index, 0,
                    platform::errors::InvalidArgument(
                        "The SkipLayerNorm Plugin only has one output, so the "
                        "index value should be 0, but get %d.",
                        index));
  PADDLE_ENFORCE_EQ((input_types[0] == nvinfer1::DataType::kFLOAT ||
                     input_types[0] == nvinfer1::DataType::kHALF),
                    true, platform::errors::InvalidArgument(
                              "The input type should be half or float."));
  return input_types[0];
}

int SkipLayerNormPluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc *input_desc,
    const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
    void *const *outputs, void *workspace, cudaStream_t stream) {
  auto input_dims = input_desc[0].dims;
  size_t num = ProductDim(input_dims);
  int hidden = input_dims.d[2];

  auto input_type = input_desc[0].type;
  if (input_type == nvinfer1::DataType::kFLOAT) {
    const float *input1 = static_cast<const float *>(inputs[0]);
    const float *input2 = static_cast<const float *>(inputs[1]);
    float *output = static_cast<float *>(outputs[0]);
    operators::math::SkipLayerNormFunctor<float> skip_layer_norm_func;
    skip_layer_norm_func(num, hidden, input1, input2, scale_gpu_, bias_gpu_,
                         output, eps_, stream);
  } else if (input_type == nvinfer1::DataType::kHALF) {
#ifdef SUPPORTS_CUDA_FP16
    const half *input1 = static_cast<const half *>(inputs[0]);
    const half *input2 = static_cast<const half *>(inputs[1]);
    half *output = static_cast<half *>(outputs[0]);
    operators::math::SkipLayerNormFunctor<half> skip_layer_norm_func;
    skip_layer_norm_func(num, hidden, input1, input2, scale_gpu_, bias_gpu_,
                         output, static_cast<half>(eps_), stream);
#else
    PADDLE_THROW(platform::errors::Fatal(
        "The CUDA arch you specify should be greater than 600."));
#endif
  } else {
    PADDLE_THROW(platform::errors::Fatal(
        "The SkipLayerNorm TRT Plugin's input type should be float or half."));
  }
  return cudaGetLastError() != cudaSuccess;
}
#endif

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
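For reference, the skip-layernorm fusion adds its two input tensors (the residual "skip" connection) and layer-normalizes the sum over the hidden dimension. A hedged CPU sketch of what SkipLayerNormFunctor is assumed to compute follows; the treatment of `num` as the total element count matches its use in enqueue() above, but the actual Paddle kernel is not reproduced here.

// Hedged CPU reference for skip + layernorm: out = LayerNorm(x + y) with
// affine scale/bias, normalized over the last (hidden) dimension.
#include <cmath>

void SkipLayerNormRef(int num, int hidden, const float* x, const float* y,
                      const float* scale, const float* bias, float eps,
                      float* out) {
  int rows = num / hidden;  // `num` is the total element count.
  for (int r = 0; r < rows; ++r) {
    const float* xr = x + r * hidden;
    const float* yr = y + r * hidden;
    float* o = out + r * hidden;
    float mean = 0.f, var = 0.f;
    for (int h = 0; h < hidden; ++h) {
      o[h] = xr[h] + yr[h];  // residual add
      mean += o[h];
    }
    mean /= hidden;
    for (int h = 0; h < hidden; ++h) var += (o[h] - mean) * (o[h] - mean);
    var /= hidden;
    float rstd = 1.f / std::sqrt(var + eps);
    for (int h = 0; h < hidden; ++h)
      o[h] = (o[h] - mean) * rstd * scale[h] + bias[h];
  }
}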
@@ -0,0 +1,102 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <string>
#include <vector>

#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

#if IS_TRT_VERSION_GE(6000)
class SkipLayerNormPluginDynamic : public DynamicPluginTensorRT {
 public:
  explicit SkipLayerNormPluginDynamic(float* bias, float* scale, int bias_size,
                                      int scale_size, const float eps,
                                      bool ban_fp16)
      : bias_(bias),
        scale_(scale),
        bias_size_(bias_size),
        scale_size_(scale_size),
        eps_(eps),
        ban_fp16_(ban_fp16) {}
  SkipLayerNormPluginDynamic(void const* serialData, size_t serialLength) {}
  nvinfer1::IPluginV2DynamicExt* clone() const override {
    return new SkipLayerNormPluginDynamic(bias_, scale_, bias_size_,
                                          scale_size_, eps_, ban_fp16_);
  }

  const char* getPluginType() const override { return "skip_layernorm_plugin"; }
  int getNbOutputs() const override { return 1; }
  int initialize() override;

  size_t getSerializationSize() const override;
  void serialize(void* buffer) const override;

  nvinfer1::DimsExprs getOutputDimensions(
      int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
      nvinfer1::IExprBuilder& expr_builder) override;

  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* inOut,
                                 int nbInputs, int nbOutputs) override;

  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc* out,
                       int nbOutputs) override {}

  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc* outputs,
                          int nbOutputs) const override {
    return 0;
  }

  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc,
              const void* const* inputs, void* const* outputs, void* workspace,
              cudaStream_t stream) override;
  nvinfer1::DataType getOutputDataType(int index,
                                       const nvinfer1::DataType* inputTypes,
                                       int nbInputs) const override;

  void destroy() override { delete this; }

 private:
  float* bias_;
  float* scale_;

  float* bias_gpu_;
  float* scale_gpu_;

  int bias_size_;
  int scale_size_;

  float eps_;
  bool ban_fp16_;
};
#endif

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
(Some files were not shown because too many files have changed in this diff.)