TensorRT中ernie模型推理性能优化,支持变长输入 (#28367)
* fp16 result ok * change -DWITH_NVINFER_PLUGIN toconfig.EnableTensorRtOSS * auto detect special slice op converter for ernie with trt oss * ernie oss only support fp16 * fix special_slice_plugin serialize bug * matmul in tensorrt ok * ernie unittest ok * add matmul tensorrt unittest * remove demo codeTCChenlong-patch-1
parent
84cc61b2cd
commit
ea851796e5
@ -0,0 +1,91 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
class Scope;
|
||||||
|
namespace proto {
|
||||||
|
class OpDesc;
|
||||||
|
} // namespace proto
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace inference {
|
||||||
|
namespace tensorrt {
|
||||||
|
|
||||||
|
/*
|
||||||
|
* MatMulOp, IMatrixMultiplyLayer in TRT. This Layer doesn't has weights.
|
||||||
|
*/
|
||||||
|
class MatMulOpConverter : public OpConverter {
|
||||||
|
public:
|
||||||
|
void operator()(const framework::proto::OpDesc& op,
|
||||||
|
const framework::Scope& scope, bool test_mode) override {
|
||||||
|
VLOG(3) << "convert a fluid matmul op to tensorrt mul layer without bias";
|
||||||
|
|
||||||
|
framework::OpDesc op_desc(op, nullptr);
|
||||||
|
// Declare inputs
|
||||||
|
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
|
||||||
|
auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]);
|
||||||
|
|
||||||
|
bool transpose_X = BOOST_GET_CONST(bool, op_desc.GetAttr("transpose_X"));
|
||||||
|
bool transpose_Y = BOOST_GET_CONST(bool, op_desc.GetAttr("transpose_Y"));
|
||||||
|
|
||||||
|
auto* layer = TRT_ENGINE_ADD_LAYER(
|
||||||
|
engine_, MatrixMultiply, *const_cast<nvinfer1::ITensor*>(input1), transpose_X,
|
||||||
|
*const_cast<nvinfer1::ITensor*>(input2), transpose_Y);
|
||||||
|
|
||||||
|
float alpha = BOOST_GET_CONST(float, op_desc.GetAttr("alpha"));
|
||||||
|
auto output_name = op_desc.Output("Out")[0];
|
||||||
|
if (fabs(alpha - 1.0) < std::numeric_limits<float>::epsilon()) {
|
||||||
|
engine_->SetITensor(output_name, layer->getOutput(0));
|
||||||
|
} else {
|
||||||
|
auto create_weights = [&](float data, const std::string &type) -> float* {
|
||||||
|
std::unique_ptr<framework::Tensor> tmp_tensor(new framework::Tensor());
|
||||||
|
tmp_tensor->Resize({1});
|
||||||
|
auto* tmp_data = tmp_tensor->mutable_data<float>(platform::CPUPlace());
|
||||||
|
tmp_data[0] = data;
|
||||||
|
engine_->SetWeights(output_name + "_add_scale_op_" + type,
|
||||||
|
std::move(tmp_tensor));
|
||||||
|
return tmp_data;
|
||||||
|
};
|
||||||
|
float* alpha_data = create_weights(alpha, "alpha");
|
||||||
|
float* shift_data = create_weights(0.0, "shift");
|
||||||
|
float* power_data = create_weights(1.0, "power");
|
||||||
|
TensorRTEngine::Weight nv_alpha{nvinfer1::DataType::kFLOAT,
|
||||||
|
static_cast<void*>(alpha_data), 1};
|
||||||
|
TensorRTEngine::Weight nv_shift{nvinfer1::DataType::kFLOAT,
|
||||||
|
static_cast<void*>(shift_data), 1};
|
||||||
|
TensorRTEngine::Weight nv_power{nvinfer1::DataType::kFLOAT,
|
||||||
|
static_cast<void*>(power_data), 1};
|
||||||
|
auto* scale_layer = TRT_ENGINE_ADD_LAYER(
|
||||||
|
engine_, Scale, *layer->getOutput(0),
|
||||||
|
nvinfer1::ScaleMode::kUNIFORM,
|
||||||
|
nv_shift.get(), nv_alpha.get(), nv_power.get());
|
||||||
|
engine_->SetITensor(output_name, scale_layer->getOutput(0));
|
||||||
|
}
|
||||||
|
if (test_mode) { // the test framework can not determine which is the
|
||||||
|
// output, so place the declaration inside.
|
||||||
|
engine_->DeclareOutput(output_name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorrt
|
||||||
|
} // namespace inference
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
REGISTER_TRT_OP_CONVERTER(matmul, MatMulOpConverter);
|
@ -1,61 +0,0 @@
|
|||||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License. */
|
|
||||||
|
|
||||||
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
|
|
||||||
|
|
||||||
namespace paddle {
|
|
||||||
namespace framework {
|
|
||||||
class Scope;
|
|
||||||
namespace proto {
|
|
||||||
class OpDesc;
|
|
||||||
} // namespace proto
|
|
||||||
} // namespace framework
|
|
||||||
} // namespace paddle
|
|
||||||
|
|
||||||
namespace paddle {
|
|
||||||
namespace inference {
|
|
||||||
namespace tensorrt {
|
|
||||||
|
|
||||||
/*
|
|
||||||
* MulOp, IMatrixMultiplyLayer in TRT. This Layer doesn't has weights.
|
|
||||||
*/
|
|
||||||
class MulOpConverter : public OpConverter {
|
|
||||||
public:
|
|
||||||
void operator()(const framework::proto::OpDesc& op,
|
|
||||||
const framework::Scope& scope, bool test_mode) override {
|
|
||||||
VLOG(3) << "convert a fluid mul op to tensorrt mul layer without bias";
|
|
||||||
|
|
||||||
framework::OpDesc op_desc(op, nullptr);
|
|
||||||
// Declare inputs
|
|
||||||
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
|
|
||||||
auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]);
|
|
||||||
// Both the input1 and input2 do not need transpose.
|
|
||||||
auto* layer = TRT_ENGINE_ADD_LAYER(
|
|
||||||
engine_, MatrixMultiply, *const_cast<nvinfer1::ITensor*>(input1), false,
|
|
||||||
*const_cast<nvinfer1::ITensor*>(input2), false);
|
|
||||||
|
|
||||||
auto output_name = op_desc.Output("Out")[0];
|
|
||||||
engine_->SetITensor(output_name, layer->getOutput(0));
|
|
||||||
if (test_mode) { // the test framework can not determine which is the
|
|
||||||
// output, so place the declaration inside.
|
|
||||||
engine_->DeclareOutput(output_name);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace tensorrt
|
|
||||||
} // namespace inference
|
|
||||||
} // namespace paddle
|
|
||||||
|
|
||||||
REGISTER_TRT_OP_CONVERTER(mul, MulOpConverter);
|
|
@ -0,0 +1,177 @@
|
|||||||
|
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <cstring>
|
||||||
|
#include <vector>
|
||||||
|
#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h"
|
||||||
|
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace inference {
|
||||||
|
namespace tensorrt {
|
||||||
|
namespace plugin {
|
||||||
|
|
||||||
|
#if IS_TRT_VERSION_GE(6000)
|
||||||
|
SpecialSlicePluginDynamic::SpecialSlicePluginDynamic() {}
|
||||||
|
|
||||||
|
SpecialSlicePluginDynamic::SpecialSlicePluginDynamic(void const* serial_data,
|
||||||
|
size_t serial_length) {}
|
||||||
|
|
||||||
|
SpecialSlicePluginDynamic::~SpecialSlicePluginDynamic() {}
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2DynamicExt* SpecialSlicePluginDynamic::clone() const {
|
||||||
|
return new SpecialSlicePluginDynamic();
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* SpecialSlicePluginDynamic::getPluginType() const {
|
||||||
|
return "special_slice_plugin";
|
||||||
|
}
|
||||||
|
|
||||||
|
int SpecialSlicePluginDynamic::getNbOutputs() const { return 1; }
|
||||||
|
|
||||||
|
int SpecialSlicePluginDynamic::initialize() { return 0; }
|
||||||
|
|
||||||
|
size_t SpecialSlicePluginDynamic::getSerializationSize() const {
|
||||||
|
size_t serialize_size = 0;
|
||||||
|
return serialize_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SpecialSlicePluginDynamic::serialize(void* buffer) const {}
|
||||||
|
|
||||||
|
nvinfer1::DimsExprs SpecialSlicePluginDynamic::getOutputDimensions(
|
||||||
|
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
|
||||||
|
nvinfer1::IExprBuilder& expr_builder) {
|
||||||
|
nvinfer1::DimsExprs output(inputs[0]);
|
||||||
|
auto one = expr_builder.constant(1);
|
||||||
|
output.d[0] = expr_builder.operation(nvinfer1::DimensionOperation::kSUB,
|
||||||
|
*inputs[1].d[0], *one);
|
||||||
|
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SpecialSlicePluginDynamic::configurePlugin(
|
||||||
|
const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
|
||||||
|
const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) {}
|
||||||
|
|
||||||
|
size_t SpecialSlicePluginDynamic::getWorkspaceSize(
|
||||||
|
const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
|
||||||
|
const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SpecialSlicePluginDynamic::destroy() { delete this; }
|
||||||
|
|
||||||
|
void SpecialSlicePluginDynamic::terminate() {}
|
||||||
|
|
||||||
|
bool SpecialSlicePluginDynamic::supportsFormatCombination(
|
||||||
|
int pos, const nvinfer1::PluginTensorDesc* desc, int nb_inputs,
|
||||||
|
int nb_outputs) {
|
||||||
|
if (pos == 0) // slice tensor
|
||||||
|
return (desc[pos].type == nvinfer1::DataType::kHALF &&
|
||||||
|
desc[pos].format ==
|
||||||
|
nvinfer1::TensorFormat::kLINEAR); // || desc[pos].type ==
|
||||||
|
// nvinfer1::DataType::kFLOAT);
|
||||||
|
|
||||||
|
if (pos == 1) // cu_seqlen
|
||||||
|
return (desc[pos].type == nvinfer1::DataType::kINT32 &&
|
||||||
|
desc[pos].format == nvinfer1::TensorFormat::kLINEAR);
|
||||||
|
|
||||||
|
return (desc[pos].type == nvinfer1::DataType::kHALF &&
|
||||||
|
desc[pos].format ==
|
||||||
|
nvinfer1::TensorFormat::kLINEAR); // || desc[pos].type ==
|
||||||
|
// nvinfer1::DataType::kFLOAT);
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::DataType SpecialSlicePluginDynamic::getOutputDataType(
|
||||||
|
int index, const nvinfer1::DataType* input_types, int nb_inputs) const {
|
||||||
|
PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
|
||||||
|
"The index should be equal to 0"));
|
||||||
|
return input_types[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void SpecialSliceKernel(const T* slice_input,
|
||||||
|
const int32_t* cu_seqlens, T* output) {
|
||||||
|
const int hidden = blockDim.x;
|
||||||
|
const int batch = blockIdx.x;
|
||||||
|
|
||||||
|
output[batch * hidden + threadIdx.x] =
|
||||||
|
slice_input[cu_seqlens[batch] * hidden + threadIdx.x];
|
||||||
|
}
|
||||||
|
|
||||||
|
int SpecialSlicePluginDynamic::enqueue(
|
||||||
|
const nvinfer1::PluginTensorDesc* input_desc,
|
||||||
|
const nvinfer1::PluginTensorDesc* output_desc, const void* const* inputs,
|
||||||
|
void* const* outputs, void* workspace, cudaStream_t stream) {
|
||||||
|
auto input_dims = input_desc[0].dims; // (sum(S), 768, 1, 1)
|
||||||
|
auto out_dims = output_desc[0].dims; // (batch, 768, 1, 1)
|
||||||
|
|
||||||
|
assert(input_desc[0].type == nvinfer1::DataType::kHALF);
|
||||||
|
|
||||||
|
const int32_t hidden = input_dims.d[1];
|
||||||
|
const int num_blocks = out_dims.d[0]; // batch size
|
||||||
|
const int num_threads = hidden;
|
||||||
|
|
||||||
|
const half* slice_input = static_cast<const half*>(inputs[0]);
|
||||||
|
const int32_t* cu_seqlens = static_cast<const int32_t*>(inputs[1]);
|
||||||
|
half* output = static_cast<half*>(outputs[0]);
|
||||||
|
|
||||||
|
SpecialSliceKernel<<<num_blocks, num_threads, 0, stream>>>(
|
||||||
|
slice_input, cu_seqlens, output);
|
||||||
|
|
||||||
|
return cudaGetLastError() != cudaSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
SpecialSlicePluginDynamicCreator::SpecialSlicePluginDynamicCreator() {}
|
||||||
|
|
||||||
|
const char* SpecialSlicePluginDynamicCreator::getPluginName() const {
|
||||||
|
return "special_slice_plugin";
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* SpecialSlicePluginDynamicCreator::getPluginVersion() const {
|
||||||
|
return "1";
|
||||||
|
}
|
||||||
|
|
||||||
|
const nvinfer1::PluginFieldCollection*
|
||||||
|
SpecialSlicePluginDynamicCreator::getFieldNames() {
|
||||||
|
return &field_collection_;
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2* SpecialSlicePluginDynamicCreator::createPlugin(
|
||||||
|
const char* name, const nvinfer1::PluginFieldCollection* fc) {
|
||||||
|
return new SpecialSlicePluginDynamic();
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2* SpecialSlicePluginDynamicCreator::deserializePlugin(
|
||||||
|
const char* name, const void* serial_data, size_t serial_length) {
|
||||||
|
auto plugin = new SpecialSlicePluginDynamic(serial_data, serial_length);
|
||||||
|
return plugin;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SpecialSlicePluginDynamicCreator::setPluginNamespace(
|
||||||
|
const char* lib_namespace) {
|
||||||
|
plugin_namespace_ = lib_namespace;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* SpecialSlicePluginDynamicCreator::getPluginNamespace() const {
|
||||||
|
return plugin_namespace_.c_str();
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // namespace plugin
|
||||||
|
} // namespace tensorrt
|
||||||
|
} // namespace inference
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,96 @@
|
|||||||
|
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <cassert>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "paddle/fluid/framework/tensor.h"
|
||||||
|
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace inference {
|
||||||
|
namespace tensorrt {
|
||||||
|
namespace plugin {
|
||||||
|
|
||||||
|
#if IS_TRT_VERSION_GE(6000)
|
||||||
|
class SpecialSlicePluginDynamic : public DynamicPluginTensorRT {
|
||||||
|
public:
|
||||||
|
SpecialSlicePluginDynamic();
|
||||||
|
SpecialSlicePluginDynamic(void const* serial_data, size_t serial_length);
|
||||||
|
~SpecialSlicePluginDynamic();
|
||||||
|
nvinfer1::IPluginV2DynamicExt* clone() const override;
|
||||||
|
nvinfer1::DimsExprs getOutputDimensions(
|
||||||
|
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
|
||||||
|
nvinfer1::IExprBuilder& exprBuilder) override;
|
||||||
|
bool supportsFormatCombination(int pos,
|
||||||
|
const nvinfer1::PluginTensorDesc* inOut,
|
||||||
|
int nbInputs, int nbOutputs) override;
|
||||||
|
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
|
||||||
|
int nbInputs,
|
||||||
|
const nvinfer1::DynamicPluginTensorDesc* out,
|
||||||
|
int nbOutputs) override;
|
||||||
|
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
|
||||||
|
int nbInputs,
|
||||||
|
const nvinfer1::PluginTensorDesc* outputs,
|
||||||
|
int nbOutputs) const override;
|
||||||
|
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
|
||||||
|
const nvinfer1::PluginTensorDesc* outputDesc,
|
||||||
|
const void* const* inputs, void* const* outputs, void* workspace,
|
||||||
|
cudaStream_t stream) override;
|
||||||
|
|
||||||
|
nvinfer1::DataType getOutputDataType(int index,
|
||||||
|
const nvinfer1::DataType* inputTypes,
|
||||||
|
int nbInputs) const override;
|
||||||
|
|
||||||
|
const char* getPluginType() const override;
|
||||||
|
int getNbOutputs() const override;
|
||||||
|
int initialize() override;
|
||||||
|
void terminate() override;
|
||||||
|
size_t getSerializationSize() const override;
|
||||||
|
void serialize(void* buffer) const override;
|
||||||
|
void destroy() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
int axis_;
|
||||||
|
int num_stack_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class SpecialSlicePluginDynamicCreator : public nvinfer1::IPluginCreator {
|
||||||
|
public:
|
||||||
|
SpecialSlicePluginDynamicCreator();
|
||||||
|
const char* getPluginName() const override;
|
||||||
|
const char* getPluginVersion() const override;
|
||||||
|
const nvinfer1::PluginFieldCollection* getFieldNames() override;
|
||||||
|
nvinfer1::IPluginV2* createPlugin(
|
||||||
|
const char* name, const nvinfer1::PluginFieldCollection* fc) override;
|
||||||
|
nvinfer1::IPluginV2* deserializePlugin(const char* name,
|
||||||
|
const void* serial_data,
|
||||||
|
size_t serial_length) override;
|
||||||
|
void setPluginNamespace(const char* lib_namespace) override;
|
||||||
|
const char* getPluginNamespace() const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string plugin_namespace_;
|
||||||
|
nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
|
||||||
|
std::vector<nvinfer1::PluginField> plugin_attributes_;
|
||||||
|
};
|
||||||
|
REGISTER_TRT_PLUGIN_V2(SpecialSlicePluginDynamicCreator);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // namespace plugin
|
||||||
|
} // namespace tensorrt
|
||||||
|
} // namespace inference
|
||||||
|
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue