[Paddle-TRT] yolobox (#31755)

* yolobox converter and plugin

* yolobox unittest

* add dynamic shape restriction

* fix git merge log
develop
zlsh80826 4 years ago committed by GitHub
parent c4b60efabd
commit 64ee255ffd

@@ -1192,6 +1192,7 @@ USE_TRT_CONVERTER(scale);
USE_TRT_CONVERTER(stack);
USE_TRT_CONVERTER(clip);
USE_TRT_CONVERTER(gather);
USE_TRT_CONVERTER(yolo_box);
USE_TRT_CONVERTER(roi_align);
USE_TRT_CONVERTER(affine_channel);
USE_TRT_CONVERTER(multiclass_nms);

@@ -6,6 +6,7 @@ nv_library(tensorrt_converter
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc
emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc
gather_op.cc
yolo_box_op.cc
roi_align_op.cc
affine_channel_op.cc
multiclass_nms_op.cc

@@ -0,0 +1,79 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.h"
namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
} // namespace proto
} // namespace framework
} // namespace paddle
namespace paddle {
namespace inference {
namespace tensorrt {
class YoloBoxOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(3) << "convert a fluid yolo box op to tensorrt plugin";

    framework::OpDesc op_desc(op, nullptr);
    // yolo_box has two inputs: the feature map X and the original image size.
    std::string X = op_desc.Input("X").front();
    std::string img_size = op_desc.Input("ImgSize").front();
    auto* X_tensor = engine_->GetITensor(X);
    auto* img_size_tensor = engine_->GetITensor(img_size);

    int class_num = BOOST_GET_CONST(int, op_desc.GetAttr("class_num"));
    std::vector<int> anchors =
        BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("anchors"));
    int downsample_ratio =
        BOOST_GET_CONST(int, op_desc.GetAttr("downsample_ratio"));
    float conf_thresh = BOOST_GET_CONST(float, op_desc.GetAttr("conf_thresh"));
    bool clip_bbox = BOOST_GET_CONST(bool, op_desc.GetAttr("clip_bbox"));
    float scale_x_y = BOOST_GET_CONST(float, op_desc.GetAttr("scale_x_y"));

    int type_id = static_cast<int>(engine_->WithFp16());
    // In the implicit-batch network the input dims are [C, H, W], so d[1] and
    // d[2] give the static feature-map height and width the plugin needs.
    auto input_dim = X_tensor->getDimensions();
    auto* yolo_box_plugin = new plugin::YoloBoxPlugin(
        type_id ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT,
        anchors, class_num, conf_thresh, downsample_ratio, clip_bbox, scale_x_y,
        input_dim.d[1], input_dim.d[2]);

    std::vector<nvinfer1::ITensor*> yolo_box_inputs;
    yolo_box_inputs.push_back(X_tensor);
    yolo_box_inputs.push_back(img_size_tensor);
    auto* yolo_box_layer = engine_->network()->addPluginV2(
        yolo_box_inputs.data(), yolo_box_inputs.size(), *yolo_box_plugin);

    std::vector<std::string> output_names;
    output_names.push_back(op_desc.Output("Boxes").front());
    output_names.push_back(op_desc.Output("Scores").front());
    RreplenishLayerAndOutput(yolo_box_layer, "yolo_box", output_names,
                             test_mode);
  }
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(yolo_box, YoloBoxOpConverter);
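Note: the converter above only fires when the TensorRT subgraph pass is enabled on the analysis predictor. A minimal Python sketch of that configuration follows; the model directory is a placeholder, and the option values simply mirror what the unit test further down passes through TensorRTParam.

# Sketch only: enable the TensorRT engine so yolo_box subgraphs are rewritten
# through the converter registered above. "./yolo_model_dir" is hypothetical.
from paddle.fluid.core import AnalysisConfig, create_paddle_predictor

config = AnalysisConfig("./yolo_model_dir")
config.enable_use_gpu(1000, 0)             # 1000 MB initial pool on GPU 0
config.enable_tensorrt_engine(
    1 << 30,                               # workspace_size
    4,                                     # max_batch_size
    1,                                     # min_subgraph_size
    AnalysisConfig.Precision.Float32,      # Precision.Half selects the kHALF plugin path
    False,                                 # use_static
    False)                                 # use_calib_mode
predictor = create_paddle_predictor(config)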

@@ -111,6 +111,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"flatten2",
"flatten",
"gather",
"yolo_box",
"roi_align",
"affine_channel",
"multiclass_nms",
@@ -198,6 +199,15 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
if (!with_dynamic_shape || desc.Input("Axis").size() > 0) return false;
}
if (op_type == "yolo_box") {
if (with_dynamic_shape) return false;
bool has_attrs =
(desc.HasAttr("class_num") && desc.HasAttr("anchors") &&
desc.HasAttr("downsample_ratio") && desc.HasAttr("conf_thresh") &&
desc.HasAttr("clip_bbox") && desc.HasAttr("scale_x_y"));
return has_attrs;
}
if (op_type == "affine_channel") {
if (!desc.HasAttr("data_layout")) return false;
auto data_layout = framework::StringToDataLayout(

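The teller entry above captures two things: yolo_box is only offloaded on the static-shape (implicit batch) path for now, and all six attributes the plugin constructor consumes must be present on the op. A rough Python paraphrase of that check (the helper name is made up for illustration):

# Illustrative paraphrase of the yolo_box branch in OpTeller::Tell.
REQUIRED_ATTRS = ("class_num", "anchors", "downsample_ratio",
                  "conf_thresh", "clip_bbox", "scale_x_y")

def yolo_box_is_convertible(op_desc, with_dynamic_shape):
    if with_dynamic_shape:  # dynamic shape is restricted by this PR
        return False
    return all(op_desc.has_attr(name) for name in REQUIRED_ATTRS)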
@@ -5,6 +5,7 @@ nv_library(tensorrt_plugin
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu
hard_swish_op_plugin.cu stack_op_plugin.cu special_slice_plugin.cu
yolo_box_op_plugin.cu
roi_align_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)

@@ -0,0 +1,117 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class YoloBoxPlugin : public nvinfer1::IPluginV2Ext {
 public:
  explicit YoloBoxPlugin(const nvinfer1::DataType data_type,
                         const std::vector<int>& anchors, const int class_num,
                         const float conf_thresh, const int downsample_ratio,
                         const bool clip_bbox, const float scale_x_y,
                         const int input_h, const int input_w);
  YoloBoxPlugin(const void* data, size_t length);
  ~YoloBoxPlugin() override;

  const char* getPluginType() const override;
  const char* getPluginVersion() const override;
  int getNbOutputs() const override;
  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
                                     int nb_input_dims) override;
  bool supportsFormat(nvinfer1::DataType type,
                      nvinfer1::TensorFormat format) const override;
  size_t getWorkspaceSize(int max_batch_size) const override;
  int enqueue(int batch_size, const void* const* inputs, void** outputs,
              void* workspace, cudaStream_t stream) override;
  template <typename T>
  int enqueue_impl(int batch_size, const void* const* inputs, void** outputs,
                   void* workspace, cudaStream_t stream);
  int initialize() override;
  void terminate() override;
  size_t getSerializationSize() const override;
  void serialize(void* buffer) const override;
  void destroy() override;
  void setPluginNamespace(const char* lib_namespace) override;
  const char* getPluginNamespace() const override;
  nvinfer1::DataType getOutputDataType(int index,
                                       const nvinfer1::DataType* input_type,
                                       int nb_inputs) const override;
  bool isOutputBroadcastAcrossBatch(int output_index,
                                    const bool* input_is_broadcast,
                                    int nb_inputs) const override;
  bool canBroadcastInputAcrossBatch(int input_index) const override;
  void configurePlugin(const nvinfer1::Dims* input_dims, int nb_inputs,
                       const nvinfer1::Dims* output_dims, int nb_outputs,
                       const nvinfer1::DataType* input_types,
                       const nvinfer1::DataType* output_types,
                       const bool* input_is_broadcast,
                       const bool* output_is_broadcast,
                       nvinfer1::PluginFormat float_format,
                       int max_batch_size) override;
  nvinfer1::IPluginV2Ext* clone() const override;

 private:
  nvinfer1::DataType data_type_;
  std::vector<int> anchors_;
  int* anchors_device_;
  int class_num_;
  float conf_thresh_;
  int downsample_ratio_;
  bool clip_bbox_;
  float scale_x_y_;
  int input_h_;
  int input_w_;
  std::string namespace_;
};

class YoloBoxPluginCreator : public nvinfer1::IPluginCreator {
 public:
  YoloBoxPluginCreator();
  ~YoloBoxPluginCreator() override = default;

  void setPluginNamespace(const char* lib_namespace) override;
  const char* getPluginNamespace() const override;
  const char* getPluginName() const override;
  const char* getPluginVersion() const override;
  const nvinfer1::PluginFieldCollection* getFieldNames() override;
  nvinfer1::IPluginV2Ext* createPlugin(
      const char* name, const nvinfer1::PluginFieldCollection* fc) override;
  nvinfer1::IPluginV2Ext* deserializePlugin(const char* name,
                                            const void* serial_data,
                                            size_t serial_length) override;

 private:
  std::string namespace_;
  nvinfer1::PluginFieldCollection field_collection_;
};

REGISTER_TRT_PLUGIN_V2(YoloBoxPluginCreator);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
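For reference, the plugin's two outputs follow the Paddle yolo_box definition: Boxes has shape [N, H*W*anchor_num, 4] and Scores has shape [N, H*W*anchor_num, class_num], where anchor_num = len(anchors) / 2 and the input feature map must carry anchor_num * (5 + class_num) channels. A small sketch of that arithmetic using the unit test's parameters:

# Expected yolo_box output shapes for the test configuration further down.
anchors = [10, 13, 16, 30, 33, 23]
class_num, height, width, bs = 80, 64, 64, 4

anchor_num = len(anchors) // 2                               # 3
assert anchor_num * (5 + class_num) == 255                   # required input channels
boxes_shape = (bs, height * width * anchor_num, 4)           # (4, 12288, 4)
scores_shape = (bs, height * width * anchor_num, class_num)  # (4, 12288, 80)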

@@ -0,0 +1,76 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import PassVersionChecker
from paddle.fluid.core import AnalysisConfig
class TRTYoloBoxTest(InferencePassTest):
    def setUp(self):
        self.set_params()
        # Build the program under test: a feature map plus the original image
        # size feed yolo_box, and the scores are reshaped so a batch_norm can
        # consume them afterwards.
        with fluid.program_guard(self.main_program, self.startup_program):
            image_shape = [self.bs, self.channel, self.height, self.width]
            image = fluid.data(name='image', shape=image_shape, dtype='float32')
            image_size = fluid.data(
                name='image_size', shape=[self.bs, 2], dtype='int32')
            boxes, scores = self.append_yolobox(image, image_size)
            scores = fluid.layers.reshape(scores, (self.bs, -1))
            out = fluid.layers.batch_norm(scores, is_test=True)

        self.feeds = {
            'image': np.random.random(image_shape).astype('float32'),
            'image_size': np.random.randint(
                32, 64, size=(self.bs, 2)).astype('int32'),
        }
        self.enable_trt = True
        # workspace_size, max_batch_size, min_subgraph_size, precision,
        # use_static, use_calib_mode
        self.trt_parameters = TRTYoloBoxTest.TensorRTParam(
            1 << 30, self.bs, 1, AnalysisConfig.Precision.Float32, False, False)
        self.fetch_list = [out, boxes]

    def set_params(self):
        self.bs = 4
        self.channel = 255
        self.height = 64
        self.width = 64
        self.class_num = 80
        self.anchors = [10, 13, 16, 30, 33, 23]
        self.conf_thresh = .1
        self.downsample_ratio = 32

    def append_yolobox(self, image, image_size):
        return fluid.layers.yolo_box(
            x=image,
            img_size=image_size,
            class_num=self.class_num,
            anchors=self.anchors,
            conf_thresh=self.conf_thresh,
            downsample_ratio=self.downsample_ratio)

    def test_check_output(self):
        if core.is_compiled_with_cuda():
            use_gpu = True
            self.check_output_with_option(use_gpu, flatten=True)
            self.assertTrue(
                PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))


if __name__ == "__main__":
    unittest.main()