[Paddle-TRT] support group_norm (#31040)
* add group norm plugin * fix compile problems * move concat axis check to trt op teller * add nbDims for scale and bias nv dims * add group norm unit test * fix unittest * add trt version restriction for group norm op teller * fix unittestrevert-31068-fix_conv3d_windows
parent
c209751c8d
commit
00b09e86ac
@ -0,0 +1,122 @@
|
||||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <vector>
|
||||
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
class Scope;
|
||||
namespace proto {
|
||||
class OpDesc;
|
||||
} // namespace proto
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace tensorrt {
|
||||
|
||||
class GroupNormOpConverter : public OpConverter {
|
||||
public:
|
||||
void operator()(const framework::proto::OpDesc& op,
|
||||
const framework::Scope& scope, bool test_mode) override {
|
||||
VLOG(3) << "convert a fluid group_norm op";
|
||||
|
||||
framework::OpDesc op_desc(op, nullptr);
|
||||
|
||||
auto* input_itensor = engine_->GetITensor(op_desc.Input("X").front());
|
||||
|
||||
int groups = BOOST_GET_CONST(int, op_desc.GetAttr("groups"));
|
||||
float epsilon = BOOST_GET_CONST(float, op_desc.GetAttr("epsilon"));
|
||||
|
||||
std::string scale_name = op_desc.Input("Scale").front();
|
||||
std::string bias_name = op_desc.Input("Bias").front();
|
||||
|
||||
// get the presistable var's data
|
||||
auto get_persistable_data = [&](const std::string& var_name,
|
||||
framework::DDim* dims) -> float* {
|
||||
auto* temp_var = scope.FindVar(var_name);
|
||||
auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
|
||||
(*dims) = temp_tensor->dims();
|
||||
|
||||
auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false);
|
||||
return temp_data;
|
||||
};
|
||||
|
||||
framework::DDim scale_dims;
|
||||
framework::DDim bias_dims;
|
||||
float* scale_data = get_persistable_data(scale_name, &scale_dims);
|
||||
float* bias_data = get_persistable_data(bias_name, &bias_dims);
|
||||
|
||||
int64_t scale_numel = framework::product(scale_dims);
|
||||
int64_t bias_numel = framework::product(bias_dims);
|
||||
|
||||
TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT,
|
||||
static_cast<void*>(scale_data),
|
||||
static_cast<size_t>(scale_numel)};
|
||||
TensorRTEngine::Weight bias_weights{nvinfer1::DataType::kFLOAT,
|
||||
static_cast<void*>(bias_data),
|
||||
static_cast<size_t>(bias_numel)};
|
||||
|
||||
nvinfer1::Dims scale_nv_dims;
|
||||
nvinfer1::Dims bias_nv_dims;
|
||||
scale_nv_dims.nbDims = scale_dims.size();
|
||||
bias_nv_dims.nbDims = bias_dims.size();
|
||||
for (int i = 0; i < scale_dims.size(); i++) {
|
||||
scale_nv_dims.d[i] = scale_dims.at(i);
|
||||
}
|
||||
for (int i = 0; i < bias_dims.size(); i++) {
|
||||
bias_nv_dims.d[i] = bias_dims.at(i);
|
||||
}
|
||||
|
||||
auto* scale_layer = TRT_ENGINE_ADD_LAYER(engine_, Constant, scale_nv_dims,
|
||||
scale_weights.get());
|
||||
auto* bias_layer = TRT_ENGINE_ADD_LAYER(engine_, Constant, bias_nv_dims,
|
||||
bias_weights.get());
|
||||
|
||||
std::vector<nvinfer1::ITensor*> plugin_inputs;
|
||||
plugin_inputs.emplace_back(input_itensor);
|
||||
plugin_inputs.emplace_back(scale_layer->getOutput(0));
|
||||
plugin_inputs.emplace_back(bias_layer->getOutput(0));
|
||||
|
||||
const std::vector<nvinfer1::PluginField> fields{
|
||||
{"eps", &epsilon, nvinfer1::PluginFieldType::kFLOAT32, 1},
|
||||
{"num_groups", &groups, nvinfer1::PluginFieldType::kINT32, 1},
|
||||
};
|
||||
|
||||
nvinfer1::PluginFieldCollection* plugin_collections =
|
||||
static_cast<nvinfer1::PluginFieldCollection*>(
|
||||
malloc(sizeof(*plugin_collections) +
|
||||
fields.size() * sizeof(nvinfer1::PluginField)));
|
||||
plugin_collections->nbFields = static_cast<int>(fields.size());
|
||||
plugin_collections->fields = fields.data();
|
||||
|
||||
auto creator =
|
||||
GetPluginRegistry()->getPluginCreator("GroupNormalizationPlugin", "1");
|
||||
auto group_norm_plugin =
|
||||
creator->createPlugin("GroupNormalizationPlugin", plugin_collections);
|
||||
free(plugin_collections);
|
||||
|
||||
auto group_norm_plugin_layer = engine_->network()->addPluginV2(
|
||||
plugin_inputs.data(), plugin_inputs.size(), *group_norm_plugin);
|
||||
|
||||
auto output_name = op_desc.Output("Y")[0];
|
||||
RreplenishLayerAndOutput(group_norm_plugin_layer, "group_norm",
|
||||
{output_name}, test_mode);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorrt
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_TRT_OP_CONVERTER(group_norm, GroupNormOpConverter);
|
@ -0,0 +1,78 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from inference_pass_test import InferencePassTest
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.core as core
|
||||
from paddle.fluid.core import PassVersionChecker
|
||||
from paddle.fluid.core import AnalysisConfig
|
||||
|
||||
|
||||
class TRTGroupNormTest(InferencePassTest):
|
||||
def setUp(self):
|
||||
with fluid.program_guard(self.main_program, self.startup_program):
|
||||
data = fluid.data(
|
||||
name="data", shape=[-1, 512, 12, 12], dtype="float32")
|
||||
relu_out = fluid.layers.relu(data)
|
||||
relu6_out = fluid.layers.relu6(relu_out)
|
||||
tanh_out = fluid.layers.tanh(relu6_out)
|
||||
conv_out = fluid.layers.conv2d(
|
||||
input=tanh_out,
|
||||
num_filters=512,
|
||||
filter_size=3,
|
||||
groups=1,
|
||||
padding=[1, 1],
|
||||
bias_attr=False,
|
||||
act=None)
|
||||
out = self.append_group_norm(conv_out)
|
||||
|
||||
self.feeds = {
|
||||
"data": np.random.random([1, 512, 12, 12]).astype("float32"),
|
||||
}
|
||||
self.enable_trt = True
|
||||
self.trt_parameters = TRTGroupNormTest.TensorRTParam(
|
||||
1 << 30, 32, 1, AnalysisConfig.Precision.Float32, False, False)
|
||||
self.dynamic_shape_params = TRTGroupNormTest.DynamicShapeParam({
|
||||
'data': [1, 512, 12, 12]
|
||||
}, {'data': [1, 512, 12, 12]}, {'data': [1, 512, 12, 12]}, False)
|
||||
self.fetch_list = [out]
|
||||
|
||||
def append_group_norm(self, data):
|
||||
param_attr = fluid.ParamAttr(
|
||||
name='group_norm_scale',
|
||||
initializer=fluid.initializer.Constant(value=1.0))
|
||||
bias_attr = fluid.ParamAttr(
|
||||
name='group_norm_bias',
|
||||
initializer=fluid.initializer.Constant(value=0.0))
|
||||
return fluid.layers.group_norm(
|
||||
data,
|
||||
groups=32,
|
||||
epsilon=0.000009999999747378752,
|
||||
param_attr=param_attr,
|
||||
bias_attr=bias_attr)
|
||||
|
||||
def test_check_output(self):
|
||||
if core.is_compiled_with_cuda():
|
||||
use_gpu = True
|
||||
self.check_output_with_option(use_gpu)
|
||||
self.assertTrue(
|
||||
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in new issue