[Paddle-TRT] yolobox (#31755)
* yolobox converter and plugin * yolobox unittest * add dynamic shape restriction * fix git merge logdevelop
parent
c4b60efabd
commit
64ee255ffd
@ -0,0 +1,79 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <vector>
|
||||
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
|
||||
#include "paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
class Scope;
|
||||
namespace proto {
|
||||
class OpDesc;
|
||||
} // namespace proto
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace tensorrt {
|
||||
|
||||
class YoloBoxOpConverter : public OpConverter {
|
||||
public:
|
||||
void operator()(const framework::proto::OpDesc& op,
|
||||
const framework::Scope& scope, bool test_mode) override {
|
||||
VLOG(3) << "convert a fluid yolo box op to tensorrt plugin";
|
||||
|
||||
framework::OpDesc op_desc(op, nullptr);
|
||||
std::string X = op_desc.Input("X").front();
|
||||
std::string img_size = op_desc.Input("ImgSize").front();
|
||||
|
||||
auto* X_tensor = engine_->GetITensor(X);
|
||||
auto* img_size_tensor = engine_->GetITensor(img_size);
|
||||
|
||||
int class_num = BOOST_GET_CONST(int, op_desc.GetAttr("class_num"));
|
||||
std::vector<int> anchors =
|
||||
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("anchors"));
|
||||
|
||||
int downsample_ratio =
|
||||
BOOST_GET_CONST(int, op_desc.GetAttr("downsample_ratio"));
|
||||
float conf_thresh = BOOST_GET_CONST(float, op_desc.GetAttr("conf_thresh"));
|
||||
bool clip_bbox = BOOST_GET_CONST(bool, op_desc.GetAttr("clip_bbox"));
|
||||
float scale_x_y = BOOST_GET_CONST(float, op_desc.GetAttr("scale_x_y"));
|
||||
|
||||
int type_id = static_cast<int>(engine_->WithFp16());
|
||||
auto input_dim = X_tensor->getDimensions();
|
||||
auto* yolo_box_plugin = new plugin::YoloBoxPlugin(
|
||||
type_id ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT,
|
||||
anchors, class_num, conf_thresh, downsample_ratio, clip_bbox, scale_x_y,
|
||||
input_dim.d[1], input_dim.d[2]);
|
||||
|
||||
std::vector<nvinfer1::ITensor*> yolo_box_inputs;
|
||||
yolo_box_inputs.push_back(X_tensor);
|
||||
yolo_box_inputs.push_back(img_size_tensor);
|
||||
|
||||
auto* yolo_box_layer = engine_->network()->addPluginV2(
|
||||
yolo_box_inputs.data(), yolo_box_inputs.size(), *yolo_box_plugin);
|
||||
|
||||
std::vector<std::string> output_names;
|
||||
output_names.push_back(op_desc.Output("Boxes").front());
|
||||
output_names.push_back(op_desc.Output("Scores").front());
|
||||
|
||||
RreplenishLayerAndOutput(yolo_box_layer, "yolo_box", output_names,
|
||||
test_mode);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorrt
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_TRT_OP_CONVERTER(yolo_box, YoloBoxOpConverter);
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,117 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/inference/tensorrt/engine.h"
|
||||
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace tensorrt {
|
||||
namespace plugin {
|
||||
|
||||
class YoloBoxPlugin : public nvinfer1::IPluginV2Ext {
|
||||
public:
|
||||
explicit YoloBoxPlugin(const nvinfer1::DataType data_type,
|
||||
const std::vector<int>& anchors, const int class_num,
|
||||
const float conf_thresh, const int downsample_ratio,
|
||||
const bool clip_bbox, const float scale_x_y,
|
||||
const int input_h, const int input_w);
|
||||
YoloBoxPlugin(const void* data, size_t length);
|
||||
~YoloBoxPlugin() override;
|
||||
|
||||
const char* getPluginType() const override;
|
||||
const char* getPluginVersion() const override;
|
||||
int getNbOutputs() const override;
|
||||
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
|
||||
int nb_input_dims) override;
|
||||
bool supportsFormat(nvinfer1::DataType type,
|
||||
nvinfer1::TensorFormat format) const override;
|
||||
size_t getWorkspaceSize(int max_batch_size) const override;
|
||||
int enqueue(int batch_size, const void* const* inputs, void** outputs,
|
||||
void* workspace, cudaStream_t stream) override;
|
||||
template <typename T>
|
||||
int enqueue_impl(int batch_size, const void* const* inputs, void** outputs,
|
||||
void* workspace, cudaStream_t stream);
|
||||
int initialize() override;
|
||||
void terminate() override;
|
||||
size_t getSerializationSize() const override;
|
||||
void serialize(void* buffer) const override;
|
||||
void destroy() override;
|
||||
void setPluginNamespace(const char* lib_namespace) override;
|
||||
const char* getPluginNamespace() const override;
|
||||
|
||||
nvinfer1::DataType getOutputDataType(int index,
|
||||
const nvinfer1::DataType* input_type,
|
||||
int nb_inputs) const override;
|
||||
bool isOutputBroadcastAcrossBatch(int output_index,
|
||||
const bool* input_is_broadcast,
|
||||
int nb_inputs) const override;
|
||||
bool canBroadcastInputAcrossBatch(int input_index) const override;
|
||||
void configurePlugin(const nvinfer1::Dims* input_dims, int nb_inputs,
|
||||
const nvinfer1::Dims* output_dims, int nb_outputs,
|
||||
const nvinfer1::DataType* input_types,
|
||||
const nvinfer1::DataType* output_types,
|
||||
const bool* input_is_broadcast,
|
||||
const bool* output_is_broadcast,
|
||||
nvinfer1::PluginFormat float_format,
|
||||
int max_batct_size) override;
|
||||
nvinfer1::IPluginV2Ext* clone() const override;
|
||||
|
||||
private:
|
||||
nvinfer1::DataType data_type_;
|
||||
std::vector<int> anchors_;
|
||||
int* anchors_device_;
|
||||
int class_num_;
|
||||
float conf_thresh_;
|
||||
int downsample_ratio_;
|
||||
bool clip_bbox_;
|
||||
float scale_x_y_;
|
||||
int input_h_;
|
||||
int input_w_;
|
||||
std::string namespace_;
|
||||
};
|
||||
|
||||
class YoloBoxPluginCreator : public nvinfer1::IPluginCreator {
|
||||
public:
|
||||
YoloBoxPluginCreator();
|
||||
~YoloBoxPluginCreator() override = default;
|
||||
|
||||
void setPluginNamespace(const char* lib_namespace) override;
|
||||
const char* getPluginNamespace() const override;
|
||||
const char* getPluginName() const override;
|
||||
const char* getPluginVersion() const override;
|
||||
const nvinfer1::PluginFieldCollection* getFieldNames() override;
|
||||
|
||||
nvinfer1::IPluginV2Ext* createPlugin(
|
||||
const char* name, const nvinfer1::PluginFieldCollection* fc) override;
|
||||
nvinfer1::IPluginV2Ext* deserializePlugin(const char* name,
|
||||
const void* serial_data,
|
||||
size_t serial_length) override;
|
||||
|
||||
private:
|
||||
std::string namespace_;
|
||||
nvinfer1::PluginFieldCollection field_collection_;
|
||||
};
|
||||
|
||||
REGISTER_TRT_PLUGIN_V2(YoloBoxPluginCreator);
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace tensorrt
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,76 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from inference_pass_test import InferencePassTest
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.core as core
|
||||
from paddle.fluid.core import PassVersionChecker
|
||||
from paddle.fluid.core import AnalysisConfig
|
||||
|
||||
|
||||
class TRTYoloBoxTest(InferencePassTest):
|
||||
def setUp(self):
|
||||
self.set_params()
|
||||
with fluid.program_guard(self.main_program, self.startup_program):
|
||||
image_shape = [self.bs, self.channel, self.height, self.width]
|
||||
image = fluid.data(name='image', shape=image_shape, dtype='float32')
|
||||
image_size = fluid.data(
|
||||
name='image_size', shape=[self.bs, 2], dtype='int32')
|
||||
boxes, scores = self.append_yolobox(image, image_size)
|
||||
scores = fluid.layers.reshape(scores, (self.bs, -1))
|
||||
out = fluid.layers.batch_norm(scores, is_test=True)
|
||||
|
||||
self.feeds = {
|
||||
'image': np.random.random(image_shape).astype('float32'),
|
||||
'image_size': np.random.randint(
|
||||
32, 64, size=(self.bs, 2)).astype('int32'),
|
||||
}
|
||||
self.enable_trt = True
|
||||
self.trt_parameters = TRTYoloBoxTest.TensorRTParam(
|
||||
1 << 30, self.bs, 1, AnalysisConfig.Precision.Float32, False, False)
|
||||
self.fetch_list = [out, boxes]
|
||||
|
||||
def set_params(self):
|
||||
self.bs = 4
|
||||
self.channel = 255
|
||||
self.height = 64
|
||||
self.width = 64
|
||||
self.class_num = 80
|
||||
self.anchors = [10, 13, 16, 30, 33, 23]
|
||||
self.conf_thresh = .1
|
||||
self.downsample_ratio = 32
|
||||
|
||||
def append_yolobox(self, image, image_size):
|
||||
return fluid.layers.yolo_box(
|
||||
x=image,
|
||||
img_size=image_size,
|
||||
class_num=self.class_num,
|
||||
anchors=self.anchors,
|
||||
conf_thresh=self.conf_thresh,
|
||||
downsample_ratio=self.downsample_ratio)
|
||||
|
||||
def test_check_output(self):
|
||||
if core.is_compiled_with_cuda():
|
||||
use_gpu = True
|
||||
self.check_output_with_option(use_gpu, flatten=True)
|
||||
self.assertTrue(
|
||||
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in new issue