Merge pull request #12761 from NHZlX/global_pooling_trt

Add support for global pooling for trt
revert-12469-sum_op_dim_fix
Zhaolong Xing 7 years ago committed by GitHub
commit 310708726b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -33,6 +33,7 @@ class Pool2dOpConverter : public OpConverter {
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1); PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]); auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
bool global_pooling = boost::get<bool>(op_desc.GetAttr("global_pooling"));
std::string pool_type = std::string pool_type =
boost::get<std::string>(op_desc.GetAttr("pooling_type")); boost::get<std::string>(op_desc.GetAttr("pooling_type"));
std::vector<int> ksize = std::vector<int> ksize =
@ -42,7 +43,13 @@ class Pool2dOpConverter : public OpConverter {
std::vector<int> paddings = std::vector<int> paddings =
boost::get<std::vector<int>>(op_desc.GetAttr("paddings")); boost::get<std::vector<int>>(op_desc.GetAttr("paddings"));
const nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]); nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]);
if (global_pooling == true) {
nvinfer1::Dims input_shape = input1->getDimensions();
int nbDims = input_shape.nbDims;
nv_ksize.d[0] = input_shape.d[nbDims - 2];
nv_ksize.d[1] = input_shape.d[nbDims - 1];
}
const nvinfer1::DimsHW nv_strides(strides[0], strides[1]); const nvinfer1::DimsHW nv_strides(strides[0], strides[1]);
const nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]); const nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]);

@ -20,7 +20,7 @@ namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
TEST(Pool2dOpConverter, main) { void test_pool2d(bool global_pooling) {
framework::Scope scope; framework::Scope scope;
std::unordered_set<std::string> parameters; std::unordered_set<std::string> parameters;
TRTConvertValidation validator(5, parameters, scope, 1 << 15); TRTConvertValidation validator(5, parameters, scope, 1 << 15);
@ -28,7 +28,10 @@ TEST(Pool2dOpConverter, main) {
// The ITensor's Dims should not contain the batch size. // The ITensor's Dims should not contain the batch size.
// So, the ITensor's Dims of input and output should be C * H * W. // So, the ITensor's Dims of input and output should be C * H * W.
validator.DeclInputVar("pool2d-X", nvinfer1::Dims3(3, 4, 4)); validator.DeclInputVar("pool2d-X", nvinfer1::Dims3(3, 4, 4));
validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 2, 2)); if (global_pooling)
validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 1, 1));
else
validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 2, 2));
// Prepare Op description // Prepare Op description
framework::OpDesc desc; framework::OpDesc desc;
@ -45,6 +48,7 @@ TEST(Pool2dOpConverter, main) {
desc.SetAttr("ksize", ksize); desc.SetAttr("ksize", ksize);
desc.SetAttr("strides", strides); desc.SetAttr("strides", strides);
desc.SetAttr("paddings", paddings); desc.SetAttr("paddings", paddings);
desc.SetAttr("global_pooling", global_pooling);
LOG(INFO) << "set OP"; LOG(INFO) << "set OP";
validator.SetOp(*desc.Proto()); validator.SetOp(*desc.Proto());
@ -53,6 +57,10 @@ TEST(Pool2dOpConverter, main) {
validator.Execute(3); validator.Execute(3);
} }
TEST(Pool2dOpConverter, normal) { test_pool2d(false); }
TEST(Pool2dOpConverter, test_global_pooling) { test_pool2d(true); }
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle

Loading…
Cancel
Save