[Paddle-TRT] roi_align_plugin (#31732)

* add roi_align_plugin

* add roi align unit_test

* add roi align serialization

* remove roi align static plugin because of batch dim issue

* refine roi align unittest and add fp16/serialization

* add trt roi align condition to op_teller

* refine error message

* remove unnecessary reshape layer
Branch: develop
zlsh80826 committed via GitHub · commit e3a38d790a · parent bfb5cf5567
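For context on how these pieces fit together at inference time: because the plugin is dynamic-shape only (see the converter and op_teller changes below), an application must hand TensorRT min/max/opt shapes when enabling the engine. The following is a minimal sketch against the paddle_infer C++ API; the model path and the shapes are illustrative assumptions, not taken from this PR.

#include <map>
#include <string>
#include <vector>

#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config;
  config.SetModel("./model_dir");  // hypothetical exported model directory
  config.EnableUseGpu(100 /* initial GPU memory in MB */, 0 /* device id */);

  // Hand supported subgraphs (roi_align included, once this converter is
  // registered) over to the TensorRT engine.
  config.EnableTensorRtEngine(1 << 30 /* workspace_size */,
                              8 /* max_batch_size */,
                              3 /* min_subgraph_size */,
                              paddle_infer::PrecisionType::kFloat32,
                              false /* use_static */,
                              false /* use_calib_mode */);

  // roi_align is only offered to TRT in dynamic-shape mode, so min/max/opt
  // shapes must be provided for the inputs that feed the plugin.
  std::map<std::string, std::vector<int>> min_shape{{"data", {1, 16, 16, 16}},
                                                    {"rois", {1, 4}}};
  std::map<std::string, std::vector<int>> max_shape{{"data", {2, 16, 64, 64}},
                                                    {"rois", {8, 4}}};
  std::map<std::string, std::vector<int>> opt_shape{{"data", {2, 16, 32, 32}},
                                                    {"rois", {8, 4}}};
  config.SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape);

  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}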

@@ -1192,6 +1192,7 @@ USE_TRT_CONVERTER(scale);
USE_TRT_CONVERTER(stack);
USE_TRT_CONVERTER(clip);
USE_TRT_CONVERTER(gather);
USE_TRT_CONVERTER(roi_align);
USE_TRT_CONVERTER(affine_channel);
USE_TRT_CONVERTER(multiclass_nms);
USE_TRT_CONVERTER(nearest_interp);

@@ -6,6 +6,7 @@ nv_library(tensorrt_converter
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc
emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc
gather_op.cc
roi_align_op.cc
affine_channel_op.cc
multiclass_nms_op.cc
nearest_interp_op.cc

@@ -0,0 +1,86 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/roi_align_op_plugin.h"

namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
}  // namespace proto
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace inference {
namespace tensorrt {

/*
 * Roi Align Op
 */
class RoiAlignOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(3) << "convert a fluid roi_align op to tensorrt roi_align plugin";

    framework::OpDesc op_desc(op, nullptr);
    std::string input_name = op_desc.Input("X").front();
    std::string rois_name = op_desc.Input("ROIs").front();
    std::string output_name = op_desc.Output("Out").front();

    const auto pooled_height =
        BOOST_GET_CONST(int, op_desc.GetAttr("pooled_height"));
    const auto pooled_width =
        BOOST_GET_CONST(int, op_desc.GetAttr("pooled_width"));
    const auto spatial_scale =
        BOOST_GET_CONST(float, op_desc.GetAttr("spatial_scale"));
    const auto sampling_ratio =
        BOOST_GET_CONST(int, op_desc.GetAttr("sampling_ratio"));

    const auto input_tensor = engine_->GetITensor(input_name);
    const auto rois_tensor = engine_->GetITensor(rois_name);

    const nvinfer1::DataType data_type_ = engine_->WithFp16()
                                              ? nvinfer1::DataType::kHALF
                                              : nvinfer1::DataType::kFLOAT;

    std::vector<nvinfer1::ITensor*> inputs{input_tensor, rois_tensor};
    nvinfer1::ILayer* layer = nullptr;

    PADDLE_ENFORCE_EQ(
        engine_->with_dynamic_shape(), true,
        platform::errors::InvalidArgument(
            "The TRT roi_align plugin only supports dynamic shape, "
            "because roi_align changes the batch size of its output."));

    auto* roi_align_plugin = new plugin::RoiAlignPluginDynamic(
        data_type_, pooled_height, pooled_width, spatial_scale, sampling_ratio);
    auto roi_align_layer = engine_->network()->addPluginV2(
        inputs.data(), inputs.size(), *roi_align_plugin);
    layer = roi_align_layer;

    std::vector<std::string> output_names{output_name};
    RreplenishLayerAndOutput(layer, "roi_align", output_names, test_mode);
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

REGISTER_TRT_OP_CONVERTER(roi_align, RoiAlignOpConverter);
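A note on the dynamic-shape restriction enforced above: roi_align produces one output slice per ROI, so the output's first dimension is the number of ROIs rather than the input batch size, which a static-shape TRT engine cannot express (this is also why the static plugin was dropped, per the commit notes). As an illustration, here is a minimal sketch of how the dynamic plugin can declare this in getOutputDimensions, assuming an output layout of [num_rois, channels, pooled_height, pooled_width]; the actual definition lives in roi_align_op_plugin.cu, which is not part of this excerpt.

// Sketch only -- not the file's actual contents. The output's first dim is
// tied to the ROIs input, so it can differ from the feature map's batch dim,
// hence the dynamic-shape-only restriction in the converter above.
nvinfer1::DimsExprs RoiAlignPluginDynamic::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
    nvinfer1::IExprBuilder& exprBuilder) {
  nvinfer1::DimsExprs ret;
  ret.nbDims = 4;
  ret.d[0] = inputs[1].d[0];  // number of ROIs, from the "ROIs" input
  ret.d[1] = inputs[0].d[1];  // channels, from the feature-map input "X"
  ret.d[2] = exprBuilder.constant(pooled_height_);
  ret.d[3] = exprBuilder.constant(pooled_width_);
  return ret;
}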

@@ -111,6 +111,7 @@ struct SimpleOpTypeSetTeller : public Teller {
      "flatten2",
      "flatten",
      "gather",
      "roi_align",
      "affine_channel",
      "multiclass_nms",
      "nearest_interp",
@@ -263,6 +264,29 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
          BOOST_GET_CONST(std::string, desc.GetAttr("interp_method"));
      if (interp_method != "nearest") return false;
    }

    if (op_type == "roi_align") {
      if (!with_dynamic_shape) return false;

      std::vector<std::string> attrs{"pooled_height", "pooled_width",
                                     "spatial_scale", "sampling_ratio"};
      for (auto const& attr : attrs) {
        if (!desc.HasAttr(attr)) return false;
      }

      const auto pooled_height =
          BOOST_GET_CONST(int, desc.GetAttr("pooled_height"));
      if (pooled_height <= 0) return false;

      const auto pooled_width =
          BOOST_GET_CONST(int, desc.GetAttr("pooled_width"));
      if (pooled_width <= 0) return false;

      const auto spatial_scale =
          BOOST_GET_CONST(float, desc.GetAttr("spatial_scale"));
      if (spatial_scale <= 0.f) return false;
    }

    if ((*teller)(op_type, desc, use_no_calib_int8)) return true;
  }

  return false;

@@ -5,6 +5,7 @@ nv_library(tensorrt_plugin
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu
hard_swish_op_plugin.cu stack_op_plugin.cu special_slice_plugin.cu
roi_align_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
nv_test(test_split_plugin SRCS test_split_plugin.cc DEPS

@@ -0,0 +1,112 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <vector>

#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

#if IS_TRT_VERSION_GE(6000)
class RoiAlignPluginDynamic : public DynamicPluginTensorRT {
 public:
  explicit RoiAlignPluginDynamic(const nvinfer1::DataType data_type,
                                 const int pooled_height,
                                 const int pooled_width, float spatial_scale,
                                 int sampling_ratio);
  RoiAlignPluginDynamic(void const* data, size_t length);
  ~RoiAlignPluginDynamic() = default;

  nvinfer1::IPluginV2DynamicExt* clone() const override;
  nvinfer1::DimsExprs getOutputDimensions(
      int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
      nvinfer1::IExprBuilder& exprBuilder) override;
  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* inOut,
                                 int nbInputs, int nbOutputs) override;
  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc* out,
                       int nbOutputs) override;
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc* outputs,
                          int nbOutputs) const override;
  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc,
              const void* const* inputs, void* const* outputs, void* workspace,
              cudaStream_t stream) override;
  nvinfer1::DataType getOutputDataType(int index,
                                       const nvinfer1::DataType* inputTypes,
                                       int nbInputs) const override;

  const char* getPluginType() const override;
  int getNbOutputs() const override;
  int initialize() override;
  void terminate() override;
  size_t getSerializationSize() const override;
  void serialize(void* buffer) const override;
  void destroy() override;

 private:
  template <typename T, typename OutT>
  int enqueue_impl(const nvinfer1::PluginTensorDesc* inputDesc,
                   const nvinfer1::PluginTensorDesc* outputDesc,
                   const void* const* inputs, void* const* outputs,
                   void* workspace, cudaStream_t stream);

  nvinfer1::DataType data_type_;
  int pooled_height_;
  int pooled_width_;
  float spatial_scale_;
  int sampling_ratio_;
  int smem_per_block_;
  std::string namespace_;
};

class RoiAlignPluginDynamicCreator : public nvinfer1::IPluginCreator {
 public:
  RoiAlignPluginDynamicCreator();
  ~RoiAlignPluginDynamicCreator() override = default;

  void setPluginNamespace(const char* lib_namespace) override;
  const char* getPluginNamespace() const override;
  const char* getPluginName() const override;
  const char* getPluginVersion() const override;
  const nvinfer1::PluginFieldCollection* getFieldNames() override;
  nvinfer1::IPluginV2Ext* createPlugin(
      const char* name, const nvinfer1::PluginFieldCollection* fc) override;
  nvinfer1::IPluginV2Ext* deserializePlugin(const char* name,
                                            const void* serial_data,
                                            size_t serial_length) override;

 private:
  std::string namespace_;
  nvinfer1::PluginFieldCollection field_collection_;
};
REGISTER_TRT_PLUGIN_V2(RoiAlignPluginDynamicCreator);
#endif

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
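The serialization added in this PR is what lets TensorRT engines containing the plugin be cached and reloaded (exercised by the test_serialize cases below). The definitions live in roi_align_op_plugin.cu, which is not shown in this excerpt; the following is a hedged sketch of the usual Paddle-TRT pattern, assuming the SerializedSize/SerializeValue/DeserializeValue helpers from trt_plugin_utils.h and assuming exactly the constructor arguments are serialized.

// Sketch only -- not the file's actual contents. The write order in
// serialize() must match the read order in the deserializing constructor.
size_t RoiAlignPluginDynamic::getSerializationSize() const {
  return SerializedSize(data_type_) + SerializedSize(pooled_height_) +
         SerializedSize(pooled_width_) + SerializedSize(spatial_scale_) +
         SerializedSize(sampling_ratio_);
}

void RoiAlignPluginDynamic::serialize(void* buffer) const {
  SerializeValue(&buffer, data_type_);
  SerializeValue(&buffer, pooled_height_);
  SerializeValue(&buffer, pooled_width_);
  SerializeValue(&buffer, spatial_scale_);
  SerializeValue(&buffer, sampling_ratio_);
}

RoiAlignPluginDynamic::RoiAlignPluginDynamic(void const* data, size_t length) {
  DeserializeValue(&data, &length, &data_type_);
  DeserializeValue(&data, &length, &pooled_height_);
  DeserializeValue(&data, &length, &pooled_width_);
  DeserializeValue(&data, &length, &spatial_scale_);
  DeserializeValue(&data, &length, &sampling_ratio_);
}

With this pattern, the creator's deserializePlugin can simply forward serial_data and serial_length to this constructor.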

@@ -0,0 +1,119 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import PassVersionChecker
from paddle.fluid.core import AnalysisConfig


class TRTRoiAlignTest(InferencePassTest):
    def setUp(self):
        self.bs = 2
        self.num_rois = 4
        self.channel = 16
        self.height = 32
        self.width = 32
        self.precision = AnalysisConfig.Precision.Float32
        self.serialize = False
        self.enable_trt = True

    def build(self):
        self.trt_parameters = TRTRoiAlignTest.TensorRTParam(
            1 << 30, self.bs * self.num_rois, 1, self.precision,
            self.serialize, False)
        with fluid.program_guard(self.main_program, self.startup_program):
            data_shape = [-1, self.channel, self.height, self.width]
            data = fluid.data(name='data', shape=data_shape, dtype='float32')
            rois = fluid.data(
                name='rois', shape=[-1, 4], dtype='float32', lod_level=1)
            roi_align_out = fluid.layers.roi_align(data, rois)
            out = fluid.layers.batch_norm(roi_align_out, is_test=True)

        rois_lod = fluid.create_lod_tensor(
            np.random.random([self.bs * self.num_rois, 4]).astype('float32'),
            [[self.num_rois, self.num_rois]], fluid.CPUPlace())

        data_shape[0] = self.bs
        self.feeds = {
            'data': np.random.random(data_shape).astype('float32'),
            'rois': rois_lod,
        }
        self.fetch_list = [out]

    def check_output(self):
        if core.is_compiled_with_cuda():
            use_gpu = True
            atol = 1e-5
            if self.trt_parameters.precision == AnalysisConfig.Precision.Half:
                atol = 1e-3
            self.check_output_with_option(use_gpu, atol, flatten=True)
            self.assertTrue(
                PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))

    def set_dynamic(self):
        min_shape_spec = dict()
        max_shape_spec = dict()
        opt_shape_spec = dict()
        min_shape_spec['data'] = [
            self.bs, self.channel, self.height // 2, self.width // 2
        ]
        min_shape_spec['rois'] = [1, 4]
        max_shape_spec[
            'data'] = [self.bs, self.channel, self.height * 2, self.width * 2]
        max_shape_spec['rois'] = [self.bs * self.num_rois, 4]
        opt_shape_spec[
            'data'] = [self.bs, self.channel, self.height, self.width]
        opt_shape_spec['rois'] = [self.bs * self.num_rois, 4]

        self.dynamic_shape_params = InferencePassTest.DynamicShapeParam(
            min_shape_spec, max_shape_spec, opt_shape_spec, False)

    def run_test(self):
        self.build()
        self.check_output()

    def test_base(self):
        self.run_test()

    def test_fp16(self):
        self.precision = AnalysisConfig.Precision.Half
        self.run_test()

    def test_serialize(self):
        self.serialize = True
        self.run_test()

    def test_dynamic(self):
        self.set_dynamic()
        self.run_test()

    def test_dynamic_fp16(self):
        self.set_dynamic()
        self.precision = AnalysisConfig.Precision.Half
        self.run_test()

    def test_dynamic_serialize(self):
        self.set_dynamic()
        self.serialize = True
        self.run_test()


if __name__ == "__main__":
    unittest.main()