[Paddle-TRT] Stack op plugin (#25605)
* add stack_op to CMakeLists * add dim=3 support for scale op * add trt stack op, test=develop * remove debug message * add stack plugin serialize * remove slice, scale op, will add later * enhance error message * revise trt ernie test to cover the stack op CI test, test=develop * add stack op serialization * fix test shape after adding stack op * remove slice op, will add after implementing serialization * roll back to min_graph=5 to avoid using slice op * fix scale op output layer * implement stack op createPlugin * use workspace and move the definition to .cu * move stack plugin creator definition to .cu, test=develop revert-26856-strategy_example2
parent
60ffc22026
commit
ad6e3dd69c
@ -0,0 +1,75 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
|
||||
#include "paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace tensorrt {
|
||||
|
||||
/*
|
||||
* Stack converter from fluid to tensorRT.
|
||||
*/
|
||||
class StackOpConverter : public OpConverter {
|
||||
public:
|
||||
void operator()(const framework::proto::OpDesc& op,
|
||||
const framework::Scope& scope, bool test_mode) override {
|
||||
VLOG(4) << "convert fluid stack op to tensorrt stack layer";
|
||||
|
||||
framework::OpDesc op_desc(op, nullptr);
|
||||
auto input = op_desc.Input("X");
|
||||
int input_num = input.size();
|
||||
nvinfer1::ITensor** inputs =
|
||||
(nvinfer1::ITensor**)malloc(input_num * sizeof(nvinfer1::ITensor*));
|
||||
|
||||
for (int i = 0; i < input_num; ++i) {
|
||||
inputs[i] = engine_->GetITensor(input[i]);
|
||||
}
|
||||
|
||||
int axis = BOOST_GET_CONST(int, op_desc.GetAttr("axis"));
|
||||
if (axis < 0) {
|
||||
axis = axis + inputs[0]->getDimensions().nbDims + 1;
|
||||
}
|
||||
|
||||
nvinfer1::ILayer* layer = nullptr;
|
||||
if (engine_->with_dynamic_shape()) {
|
||||
#if IS_TRT_VERSION_GE(6000)
|
||||
plugin::StackPluginDynamic* plugin =
|
||||
new plugin::StackPluginDynamic(axis, input_num);
|
||||
layer = engine_->AddPluginV2(inputs, input_num, plugin);
|
||||
assert(layer != nullptr);
|
||||
#else
|
||||
PADDLE_THROW(platform::errors::Fatal(
|
||||
"You are running the TRT Dynamic Shape mode, need to confirm that "
|
||||
"your TRT version is no less than 6.0"));
|
||||
#endif
|
||||
} else {
|
||||
PADDLE_THROW(platform::errors::Fatal(
|
||||
"You are running the Ernie(Bert) model in static"
|
||||
"shape mode, which is not supported for the time being.\n"
|
||||
"You can use the config.SetTRTDynamicShapeInfo(...) interface"
|
||||
" to set the shape information to run the dynamic shape mode."));
|
||||
}
|
||||
auto output_name = op_desc.Output("Y").front();
|
||||
RreplenishLayerAndOutput(layer, "stack", {output_name}, test_mode);
|
||||
free(inputs);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorrt
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
||||
|
||||
// Associates the op name `stack` with StackOpConverter via the
// REGISTER_TRT_OP_CONVERTER macro (macro definition not shown in this file).
REGISTER_TRT_OP_CONVERTER(stack, StackOpConverter);
|
@ -1,7 +1,8 @@
|
||||
nv_library(tensorrt_plugin
|
||||
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
|
||||
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
|
||||
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
|
||||
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
|
||||
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
|
||||
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu hard_swish_op_plugin.cu
|
||||
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
|
||||
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
|
||||
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu
|
||||
hard_swish_op_plugin.cu stack_op_plugin.cu
|
||||
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
|
||||
|
@ -0,0 +1,247 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cassert>
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h"
|
||||
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace tensorrt {
|
||||
namespace plugin {
|
||||
|
||||
#if IS_TRT_VERSION_GE(6000)
|
||||
StackPluginDynamic::StackPluginDynamic(int axis, int num_stack)
|
||||
: axis_(axis), num_stack_(num_stack) {}
|
||||
|
||||
// Rebuilds a plugin from a serialized buffer. Fields are read in the same
// order serialize() writes them: axis_ first, then num_stack_.
StackPluginDynamic::StackPluginDynamic(void const* serial_data,
                                       size_t serial_length) {
  // DeserializeValue receives the cursor and remaining length by address.
  void const** cursor = &serial_data;
  size_t* remaining = &serial_length;
  DeserializeValue(cursor, remaining, &axis_);
  DeserializeValue(cursor, remaining, &num_stack_);
}
|
||||
|
||||
// Nothing to release: the destructor body is empty.
StackPluginDynamic::~StackPluginDynamic() {}
|
||||
|
||||
// Returns a newly allocated plugin configured with the same axis and stack
// count as this one. The caller receives a raw pointer to the new object.
nvinfer1::IPluginV2DynamicExt* StackPluginDynamic::clone() const {
  StackPluginDynamic* copy = new StackPluginDynamic(axis_, num_stack_);
  return copy;
}
|
||||
|
||||
// Returns the plugin type string "stack_plugin", the same string that
// StackPluginDynamicCreator::getPluginName() reports.
const char* StackPluginDynamic::getPluginType() const { return "stack_plugin"; }
|
||||
|
||||
// The plugin always produces exactly one output tensor.
int StackPluginDynamic::getNbOutputs() const { return 1; }
|
||||
|
||||
// Performs no setup work and always returns 0.
int StackPluginDynamic::initialize() { return 0; }
|
||||
|
||||
// Returns the number of bytes serialize() writes: the serialized sizes of
// axis_ and num_stack_ added together.
size_t StackPluginDynamic::getSerializationSize() const {
  return SerializedSize(axis_) + SerializedSize(num_stack_);
}
|
||||
|
||||
// Writes the plugin state into `buffer`: axis_ first, then num_stack_. The
// deserializing constructor reads the fields back in this same order.
void StackPluginDynamic::serialize(void* buffer) const {
  void** cursor = &buffer;
  SerializeValue(cursor, axis_);
  SerializeValue(cursor, num_stack_);
}
|
||||
|
||||
// Derives the output shape from the first input's shape by inserting a new
// dimension of extent `nb_inputs` at position axis_.
//
// output_index: not consulted (the plugin has a single output).
// inputs:       input shape expressions; only inputs[0] is read.
// nb_inputs:    number of inputs, used as the extent of the new dimension.
// expr_builder: used to create the constant expression for that extent.
nvinfer1::DimsExprs StackPluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
    nvinfer1::IExprBuilder& expr_builder) {
  const nvinfer1::DimsExprs& first = inputs[0];

  // Start from a copy so the dimensions before axis_ are already in place.
  nvinfer1::DimsExprs result(first);
  result.nbDims = first.nbDims + 1;

  // Dimensions at or after axis_ in the input move one slot to the right.
  for (int dst = axis_ + 1; dst <= first.nbDims; ++dst) {
    result.d[dst] = first.d[dst - 1];
  }
  result.d[axis_] = expr_builder.constant(nb_inputs);
  return result;
}
|
||||
|
||||
void StackPluginDynamic::configurePlugin(
|
||||
const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) {}
|
||||
|
||||
size_t StackPluginDynamic::getWorkspaceSize(
|
||||
const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const {
|
||||
return num_stack_ * sizeof(uintptr_t);
|
||||
}
|
||||
|
||||
// Deletes this plugin object; it must not be used after this call returns.
void StackPluginDynamic::destroy() { delete this; }
|
||||
|
||||
// No-op: initialize() sets nothing up, so there is nothing to tear down.
void StackPluginDynamic::terminate() {}
|
||||
|
||||
bool StackPluginDynamic::supportsFormatCombination(
|
||||
int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs,
|
||||
int nb_outputs) {
|
||||
PADDLE_ENFORCE_NOT_NULL(
|
||||
in_out, platform::errors::InvalidArgument(
|
||||
"The input of stack plugin should not be nullptr."));
|
||||
|
||||
PADDLE_ENFORCE_LT(
|
||||
pos, nb_inputs + nb_outputs,
|
||||
platform::errors::InvalidArgument("The pos(%d) should be less than the "
|
||||
"num(%d) of the input and the output.",
|
||||
pos, nb_inputs + nb_outputs));
|
||||
|
||||
const nvinfer1::PluginTensorDesc& in = in_out[pos];
|
||||
if (pos == 0) {
|
||||
#ifdef SUPPORTS_CUDA_FP16
|
||||
return (in.type == nvinfer1::DataType::kFLOAT ||
|
||||
in.type == nvinfer1::DataType::kHALF) &&
|
||||
(in.format == nvinfer1::TensorFormat::kLINEAR);
|
||||
#else
|
||||
return (in.type == nvinfer1::DataType::kFLOAT) &&
|
||||
(in.format == nvinfer1::TensorFormat::kLINEAR);
|
||||
#endif
|
||||
}
|
||||
const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1];
|
||||
// output
|
||||
return in.type == prev.type && in.format == prev.format;
|
||||
}
|
||||
|
||||
// Returns the data type of the (only) output, which is the type of the first
// input. Throws InvalidArgument when `index` is not 0.
nvinfer1::DataType StackPluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType* input_types, int nb_inputs) const {
  PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
                                  "The index should be equal to 0"));
  const nvinfer1::DataType first_input_type = input_types[0];
  return first_input_type;
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void StackKernel(const T* const* input, T* output, int num_stack,
|
||||
int base_unit) {
|
||||
int stack_id = blockIdx.x;
|
||||
int lead_id = blockIdx.y;
|
||||
|
||||
for (int i = threadIdx.x; i < base_unit; i += blockDim.x) {
|
||||
output[lead_id * num_stack * base_unit + stack_id * base_unit + i] =
|
||||
input[stack_id][lead_id * base_unit + i];
|
||||
}
|
||||
}
|
||||
|
||||
// Launches the stack kernel on `stream`.
//
// input_desc / output_desc: tensor descriptions; input_desc[0].type selects
//                           the float or half kernel and output_desc[0].dims
//                           gives the output shape.
// inputs:    host-side array of num_stack_ input pointers.
// outputs:   outputs[0] is the destination buffer.
// workspace: scratch buffer that receives the copy of the input pointer table
//            (getWorkspaceSize() requests one pointer-sized slot per input).
// stream:    CUDA stream used for the copy and the kernel launch.
//
// Returns 0 on success and a non-zero value when a CUDA call reported an
// error. Throws when an output dimension is not positive, when the extent of
// the stack axis differs from num_stack_, or when the input type is not
// float/half.
int StackPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
                                const nvinfer1::PluginTensorDesc* output_desc,
                                const void* const* inputs, void* const* outputs,
                                void* workspace, cudaStream_t stream) {
  auto out_dims = output_desc[0].dims;
  auto out_num_dims = out_dims.nbDims;

  // Number of contiguous elements behind the stack axis.
  int base_unit = 1;
  for (int i = axis_ + 1; i < out_num_dims; ++i) {
    PADDLE_ENFORCE_GT(out_dims.d[i], 0,
                      platform::errors::InvalidArgument(
                          "Input dimensions should be greater than 0"));
    base_unit *= out_dims.d[i];
  }

  // Product of the dimensions in front of the stack axis.
  int lead_unit = 1;
  for (int i = 0; i < axis_; ++i) {
    PADDLE_ENFORCE_GT(out_dims.d[i], 0,
                      platform::errors::InvalidArgument(
                          "Input dimensions should be greater than 0"));
    lead_unit *= out_dims.d[i];
  }

  PADDLE_ENFORCE_EQ(
      out_dims.d[axis_], num_stack_,
      platform::errors::InvalidArgument("number of stack axis should be same"));

  const int num_stacks = out_dims.d[axis_];

  // The kernel dereferences the pointer table on the device, so upload it to
  // the workspace first. Bail out before launching if the upload was rejected;
  // otherwise the kernel would read an uninitialized table.
  const cudaError_t copy_status =
      cudaMemcpyAsync(workspace, inputs, sizeof(void*) * num_stacks,
                      cudaMemcpyHostToDevice, stream);
  if (copy_status != cudaSuccess) {
    return 1;
  }

  // One block per (input tensor, leading index) pair.
  dim3 num_blocks(num_stacks, lead_unit);
  const int num_threads = 256;
  auto infer_type = input_desc[0].type;

  if (infer_type == nvinfer1::DataType::kFLOAT) {
    float* output = static_cast<float*>(outputs[0]);
    StackKernel<float><<<num_blocks, num_threads, 0, stream>>>(
        reinterpret_cast<const float* const*>(workspace), output, num_stacks,
        base_unit);
  } else if (infer_type == nvinfer1::DataType::kHALF) {
#ifdef SUPPORTS_CUDA_FP16
    __half* output = static_cast<__half*>(outputs[0]);
    StackKernel<__half><<<num_blocks, num_threads, 0, stream>>>(
        reinterpret_cast<const __half* const*>(workspace), output, num_stacks,
        base_unit);
#else
    PADDLE_THROW(platform::errors::Fatal(
        "The cuda archs you specific should greater than 600."));
#endif
  } else {
    PADDLE_THROW(
        platform::errors::Fatal("The Stack TRT Plugin's input type only "
                                "support float or half currently."));
  }
  return cudaGetLastError() != cudaSuccess;
}
|
||||
|
||||
// Empty constructor; members keep their in-class default initializers.
StackPluginDynamicCreator::StackPluginDynamicCreator() {}
|
||||
|
||||
// Returns the plugin name "stack_plugin", matching the string reported by
// StackPluginDynamic::getPluginType().
const char* StackPluginDynamicCreator::getPluginName() const {
  static const char* const kPluginName = "stack_plugin";
  return kPluginName;
}
|
||||
|
||||
// Returns the plugin version string, currently "1".
const char* StackPluginDynamicCreator::getPluginVersion() const { return "1"; }
|
||||
|
||||
// Returns a pointer to the creator's field collection member. The pointer
// refers to storage owned by this creator object.
const nvinfer1::PluginFieldCollection*
StackPluginDynamicCreator::getFieldNames() {
  const nvinfer1::PluginFieldCollection* fields = &field_collection_;
  return fields;
}
|
||||
|
||||
// Builds a StackPluginDynamic from a plugin field collection.
//
// name: not used.
// fc:   must be non-null and supply the int fields "axis" (>= 0) and
//       "num_stack" (> 0); any other field name is rejected.
//
// Returns a newly allocated plugin. Throws when `fc` is null, when an unknown
// field is present, or when "axis"/"num_stack" are missing or out of range
// (the plugin would otherwise index its dimension array with -1).
nvinfer1::IPluginV2* StackPluginDynamicCreator::createPlugin(
    const char* name, const nvinfer1::PluginFieldCollection* fc) {
  PADDLE_ENFORCE_NOT_NULL(
      fc, platform::errors::InvalidArgument(
              "The field collection of stack plugin should not be nullptr."));

  // -1 marks "not supplied"; both values are validated after the loop.
  int axis = -1;
  int num_stack = -1;

  for (int i = 0; i < fc->nbFields; ++i) {
    const nvinfer1::PluginField& field = fc->fields[i];
    // Named field_name so it does not shadow the `name` parameter.
    const std::string field_name(field.name);
    if (field_name == "axis") {
      axis = static_cast<const int*>(field.data)[0];
    } else if (field_name == "num_stack") {
      num_stack = static_cast<const int*>(field.data)[0];
    } else {
      PADDLE_THROW(platform::errors::Fatal("Meet an unknown plugin field '" +
                                           field_name +
                                           "' when creating stack op plugin."));
    }
  }

  PADDLE_ENFORCE_EQ(
      axis >= 0, true,
      platform::errors::InvalidArgument(
          "The 'axis' field of stack plugin should be provided and be "
          "non-negative, but got %d.",
          axis));
  PADDLE_ENFORCE_GT(
      num_stack, 0,
      platform::errors::InvalidArgument(
          "The 'num_stack' field of stack plugin should be provided and be "
          "greater than 0, but got %d.",
          num_stack));

  return new StackPluginDynamic(axis, num_stack);
}
|
||||
|
||||
// Reconstructs a plugin from `serial_length` bytes at `serial_data` using the
// deserializing constructor of StackPluginDynamic. `name` is not used.
// Returns a newly allocated plugin.
nvinfer1::IPluginV2* StackPluginDynamicCreator::deserializePlugin(
    const char* name, const void* serial_data, size_t serial_length) {
  return new StackPluginDynamic(serial_data, serial_length);
}
|
||||
|
||||
// Stores a copy of `lib_namespace` in the creator.
// NOTE(review): a null `lib_namespace` is not handled here; confirm callers
// never pass nullptr.
void StackPluginDynamicCreator::setPluginNamespace(const char* lib_namespace) {
  plugin_namespace_.assign(lib_namespace);
}
|
||||
|
||||
const char* StackPluginDynamicCreator::getPluginNamespace() const {
|
||||
return plugin_namespace_.c_str();
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace tensorrt
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,96 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include <stdio.h>
|
||||
#include <cassert>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/tensor.h"
|
||||
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace tensorrt {
|
||||
namespace plugin {
|
||||
|
||||
#if IS_TRT_VERSION_GE(6000)
|
||||
class StackPluginDynamic : public DynamicPluginTensorRT {
|
||||
public:
|
||||
explicit StackPluginDynamic(int axis, int num_stack);
|
||||
StackPluginDynamic(void const* serial_data, size_t serial_length);
|
||||
~StackPluginDynamic();
|
||||
nvinfer1::IPluginV2DynamicExt* clone() const override;
|
||||
nvinfer1::DimsExprs getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder& exprBuilder) override;
|
||||
bool supportsFormatCombination(int pos,
|
||||
const nvinfer1::PluginTensorDesc* inOut,
|
||||
int nbInputs, int nbOutputs) override;
|
||||
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
|
||||
int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc* out,
|
||||
int nbOutputs) override;
|
||||
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
|
||||
int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc* outputs,
|
||||
int nbOutputs) const override;
|
||||
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
|
||||
const nvinfer1::PluginTensorDesc* outputDesc,
|
||||
const void* const* inputs, void* const* outputs, void* workspace,
|
||||
cudaStream_t stream) override;
|
||||
|
||||
nvinfer1::DataType getOutputDataType(int index,
|
||||
const nvinfer1::DataType* inputTypes,
|
||||
int nbInputs) const override;
|
||||
|
||||
const char* getPluginType() const override;
|
||||
int getNbOutputs() const override;
|
||||
int initialize() override;
|
||||
void terminate() override;
|
||||
size_t getSerializationSize() const override;
|
||||
void serialize(void* buffer) const override;
|
||||
void destroy() override;
|
||||
|
||||
private:
|
||||
int axis_;
|
||||
int num_stack_;
|
||||
};
|
||||
|
||||
class StackPluginDynamicCreator : public nvinfer1::IPluginCreator {
|
||||
public:
|
||||
StackPluginDynamicCreator();
|
||||
const char* getPluginName() const override;
|
||||
const char* getPluginVersion() const override;
|
||||
const nvinfer1::PluginFieldCollection* getFieldNames() override;
|
||||
nvinfer1::IPluginV2* createPlugin(
|
||||
const char* name, const nvinfer1::PluginFieldCollection* fc) override;
|
||||
nvinfer1::IPluginV2* deserializePlugin(const char* name,
|
||||
const void* serial_data,
|
||||
size_t serial_length) override;
|
||||
void setPluginNamespace(const char* lib_namespace) override;
|
||||
const char* getPluginNamespace() const override;
|
||||
|
||||
private:
|
||||
std::string plugin_namespace_;
|
||||
nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
|
||||
std::vector<nvinfer1::PluginField> plugin_attributes_;
|
||||
};
|
||||
// Passes StackPluginDynamicCreator to the REGISTER_TRT_PLUGIN_V2 macro.
// NOTE(review): the macro is defined elsewhere; presumably it registers the
// creator once per including translation unit — confirm that is intended for
// a header.
REGISTER_TRT_PLUGIN_V2(StackPluginDynamicCreator);
|
||||
#endif
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace tensorrt
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
Loading…
Reference in new issue