[Paddle-TRT] nearest_interp op (#31626)
* nearest_interp op converter w/ dynamic/static * fix data_layout include * add trt nearest unit_test * add nearest_interp NHWC test * update trt nearest interp nhwc testcase * remove asterisk for python2 compatibility * add empty line to prevent conflict * nearest_interp op converter w/ dynamic/static * fix data_layout include * add trt nearest unit_test * add nearest_interp NHWC test * update trt nearest interp nhwc testcase * remove asterisk for python2 compatibility * add empty line to prevent conflict * change the priority of out_h, out_w
2.0.1-rocm-post
parent
7ccf6b6030
commit
bfced39eb6
@ -0,0 +1,114 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/data_layout.h"
|
||||
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
|
||||
|
||||
// Forward declarations of the framework types that appear only by reference
// in the converter's call signature below.
namespace paddle {
namespace framework {
namespace proto {
class OpDesc;
}  // namespace proto
class Scope;
}  // namespace framework
}  // namespace paddle
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace tensorrt {
|
||||
|
||||
class NearestInterpolateOpConverter : public OpConverter {
|
||||
public:
|
||||
void operator()(const framework::proto::OpDesc& op,
|
||||
const framework::Scope& scope, bool test_mode) override {
|
||||
VLOG(3) << "convert a fluid nearest_interp op";
|
||||
|
||||
framework::OpDesc op_desc(op, nullptr);
|
||||
|
||||
std::string input_name = op_desc.Input("X").front();
|
||||
std::string output_name = op_desc.Output("Out").front();
|
||||
|
||||
auto input = engine_->GetITensor(input_name);
|
||||
|
||||
auto data_layout = framework::StringToDataLayout(
|
||||
BOOST_GET_CONST(std::string, op_desc.GetAttr("data_layout")));
|
||||
auto interp_method =
|
||||
BOOST_GET_CONST(std::string, op_desc.GetAttr("interp_method"));
|
||||
bool align_corners =
|
||||
BOOST_GET_CONST(bool, op_desc.GetAttr("align_corners"));
|
||||
|
||||
auto input_names = op_desc.Input("X");
|
||||
auto scale = BOOST_GET_CONST(float, op_desc.GetAttr("scale"));
|
||||
auto out_h = BOOST_GET_CONST(int, op_desc.GetAttr("out_h"));
|
||||
auto out_w = BOOST_GET_CONST(int, op_desc.GetAttr("out_w"));
|
||||
|
||||
auto layer = TRT_ENGINE_ADD_LAYER(engine_, Resize, *input);
|
||||
layer->setAlignCorners(align_corners);
|
||||
|
||||
auto in_dim = input->getDimensions();
|
||||
|
||||
float scale_h = 1.f;
|
||||
float scale_w = 1.f;
|
||||
|
||||
std::vector<float> scales;
|
||||
|
||||
if (scale > 0.f && (out_h <= 0 && out_w <= 0)) {
|
||||
scale_h = scale;
|
||||
scale_w = scale;
|
||||
} else {
|
||||
// axis are different in static/dynamic mode
|
||||
PADDLE_ENFORCE_GT(
|
||||
out_h, 0, platform::errors::InvalidArgument(
|
||||
"out_h must be greater than 0 if scale is not set."));
|
||||
PADDLE_ENFORCE_GT(
|
||||
out_w, 0, platform::errors::InvalidArgument(
|
||||
"out_w must be greater than 0 if scale is not set."));
|
||||
|
||||
bool with_dynamic = engine_->with_dynamic_shape();
|
||||
|
||||
int h_axis = (data_layout == framework::DataLayout::kNCHW) + with_dynamic;
|
||||
int w_axis =
|
||||
(data_layout == framework::DataLayout::kNCHW) + 1 + with_dynamic;
|
||||
|
||||
scale_h =
|
||||
static_cast<float>(out_h) / static_cast<float>(in_dim.d[h_axis]);
|
||||
scale_w =
|
||||
static_cast<float>(out_w) / static_cast<float>(in_dim.d[w_axis]);
|
||||
}
|
||||
|
||||
if (engine_->with_dynamic_shape()) {
|
||||
scales.push_back(1.f);
|
||||
}
|
||||
|
||||
if (data_layout == framework::DataLayout::kNCHW) {
|
||||
scales.push_back(1.f);
|
||||
scales.push_back(scale_h);
|
||||
scales.push_back(scale_w);
|
||||
} else if (data_layout == framework::DataLayout::kNHWC) {
|
||||
// NHWC
|
||||
scales.push_back(scale_h);
|
||||
scales.push_back(scale_w);
|
||||
scales.push_back(1.f);
|
||||
} else {
|
||||
PADDLE_THROW(platform::errors::InvalidArgument(
|
||||
"Data layout must be NCHW or NHWC."));
|
||||
}
|
||||
layer->setScales(scales.data(), scales.size());
|
||||
|
||||
RreplenishLayerAndOutput(layer, "nearest_interp", {output_name}, test_mode);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorrt
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
||||
|
||||
// Register NearestInterpolateOpConverter for the `nearest_interp` op type.
REGISTER_TRT_OP_CONVERTER(nearest_interp, NearestInterpolateOpConverter);
|
@ -0,0 +1,192 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from inference_pass_test import InferencePassTest
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.core as core
|
||||
from paddle.fluid.core import PassVersionChecker
|
||||
from paddle.fluid.core import AnalysisConfig
|
||||
|
||||
|
||||
class TRTNearestInterpTest(InferencePassTest):
    """Inference test for `resize_nearest` followed by `batch_norm` with TRT on.

    Subclasses override `set_params` to vary the layout, the align_corners
    flag, and whether the output size comes from `scale` or `resize_shape`.
    """

    def setUp(self):
        """Build the test program and the random input feed."""
        self.set_params()

        # Everything after the batch axis is identical for the program input
        # (batch axis -1) and for the concrete feed (batch axis self.bs).
        if self.data_layout == 'NCHW':
            trailing_dims = [
                self.channels, self.origin_shape[0], self.origin_shape[1]
            ]
        else:
            trailing_dims = [
                self.origin_shape[0], self.origin_shape[1], self.channels
            ]

        with fluid.program_guard(self.main_program, self.startup_program):
            data = fluid.data(
                name='data', shape=[-1] + trailing_dims, dtype='float32')
            resized = self.append_nearest_interp(data)
            out = fluid.layers.batch_norm(resized, is_test=True)

        feed_shape = [self.bs] + trailing_dims
        self.feeds = {'data': np.random.random(feed_shape).astype('float32'), }
        self.enable_trt = True
        self.trt_parameters = TRTNearestInterpTest.TensorRTParam(
            1 << 30, self.bs, 1, AnalysisConfig.Precision.Float32, False, False)
        self.fetch_list = [out]

    def set_params(self):
        """Default case: NCHW, align_corners on, positive scale of 1.

        Because the scale is positive, `resize_shape` is not passed to the op.
        """
        self.data_layout = 'NCHW'
        self.align_corners = True
        self.scale = 1
        self.origin_shape = (32, 32)  # HW
        self.resize_shape = (64, 64)  # HW
        self.channels = 3
        self.bs = 4

    def append_nearest_interp(self, data):
        """Append `resize_nearest` to `data` and return its output.

        A positive `self.scale` is passed as `scale`; otherwise
        `self.resize_shape` is passed as `out_shape`.
        """
        if not self.scale > 0.:
            return fluid.layers.resize_nearest(
                data,
                out_shape=self.resize_shape,
                align_corners=self.align_corners,
                data_format=self.data_layout)
        return fluid.layers.resize_nearest(
            data,
            scale=self.scale,
            align_corners=self.align_corners,
            data_format=self.data_layout)

    def test_check_output(self):
        """Run the output check on GPU; does nothing without CUDA."""
        if not core.is_compiled_with_cuda():
            return
        use_gpu = True
        self.check_output_with_option(use_gpu, flatten=True)
        self.assertTrue(
            PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
|
||||
|
||||
|
||||
class TRTNearestInterpTest1(TRTNearestInterpTest):
    """NCHW, align_corners on, output size given by resize_shape (64x64)."""

    def set_params(self):
        self.data_layout = 'NCHW'
        self.align_corners = True
        # Non-positive scale: the base class passes resize_shape instead.
        self.scale = -1
        self.origin_shape = (32, 32)  # HW
        self.resize_shape = (64, 64)  # HW
        self.channels = 3
        self.bs = 4
|
||||
|
||||
|
||||
class TRTNearestInterpTest2(TRTNearestInterpTest):
    """NCHW, align_corners off, output size given by scale 2."""

    def set_params(self):
        self.data_layout = 'NCHW'
        self.align_corners = False
        # Positive scale: the base class passes scale, not resize_shape.
        self.scale = 2.
        self.origin_shape = (32, 32)  # HW
        self.resize_shape = (64, 64)  # HW
        self.channels = 3
        self.bs = 4
|
||||
|
||||
|
||||
class TRTNearestInterpTest3(TRTNearestInterpTest):
    """NCHW, align_corners off, output size given by resize_shape (64x64)."""

    def set_params(self):
        self.data_layout = 'NCHW'
        self.align_corners = False
        # Non-positive scale: the base class passes resize_shape instead.
        self.scale = -1
        self.origin_shape = (32, 32)  # HW
        self.resize_shape = (64, 64)  # HW
        self.channels = 3
        self.bs = 4
|
||||
|
||||
|
||||
class TRTNearestInterpTest4(TRTNearestInterpTest):
    """NCHW, align_corners off, non-square resize_shape (47x48)."""

    def set_params(self):
        self.data_layout = 'NCHW'
        self.align_corners = False
        # Non-positive scale: the base class passes resize_shape instead.
        self.scale = -1
        self.origin_shape = (32, 32)  # HW
        self.resize_shape = (47, 48)  # HW
        self.channels = 3
        self.bs = 4
|
||||
|
||||
|
||||
class TRTNearestInterpTest5(TRTNearestInterpTest):
    """NHWC, align_corners on, output size given by resize_shape (64x64)."""

    def set_params(self):
        self.data_layout = 'NHWC'
        self.align_corners = True
        # Non-positive scale: the base class passes resize_shape instead.
        self.scale = -1
        self.origin_shape = (32, 32)  # HW
        self.resize_shape = (64, 64)  # HW
        self.channels = 3
        self.bs = 4
|
||||
|
||||
|
||||
class TRTNearestInterpTest6(TRTNearestInterpTest):
    """NHWC, align_corners off, output size given by scale 2."""

    def set_params(self):
        self.data_layout = 'NHWC'
        self.align_corners = False
        # Positive scale: the base class passes scale, not resize_shape.
        self.scale = 2.
        self.origin_shape = (32, 32)  # HW
        self.resize_shape = (64, 64)  # HW
        self.channels = 3
        self.bs = 4
|
||||
|
||||
|
||||
class TRTNearestInterpTest7(TRTNearestInterpTest):
    """NHWC, align_corners off, output size given by resize_shape (64x64)."""

    def set_params(self):
        self.data_layout = 'NHWC'
        self.align_corners = False
        # Non-positive scale: the base class passes resize_shape instead.
        self.scale = -1
        self.origin_shape = (32, 32)  # HW
        self.resize_shape = (64, 64)  # HW
        self.channels = 3
        self.bs = 4
|
||||
|
||||
|
||||
class TRTNearestInterpTest8(TRTNearestInterpTest):
    """NHWC, align_corners off, non-square resize_shape (47x48)."""

    def set_params(self):
        self.data_layout = 'NHWC'
        self.align_corners = False
        # Non-positive scale: the base class passes resize_shape instead.
        self.scale = -1
        self.origin_shape = (32, 32)  # HW
        self.resize_shape = (47, 48)  # HW
        self.channels = 3
        self.bs = 4
|
||||
|
||||
|
||||
class TRTNearestInterpTest9(TRTNearestInterpTest):
    """NHWC, align_corners off, non-square resize_shape (47x48).

    NOTE(review): these parameters are identical to TRTNearestInterpTest8, so
    this case adds no coverage as written -- confirm which configuration was
    intended here.
    """

    def set_params(self):
        self.data_layout = 'NHWC'
        self.align_corners = False
        # Non-positive scale: the base class passes resize_shape instead.
        self.scale = -1
        self.origin_shape = (32, 32)  # HW
        self.resize_shape = (47, 48)  # HW
        self.channels = 3
        self.bs = 4
|
||||
|
||||
|
||||
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
Loading…
Reference in new issue