support clip op trt converter (#29411)

Pei Yang authored 4 years ago, committed via GitHub
parent 1dd7b97b66
commit f860de4af7

@@ -1100,6 +1100,7 @@ USE_TRT_CONVERTER(skip_layernorm);
 USE_TRT_CONVERTER(slice);
 USE_TRT_CONVERTER(scale);
 USE_TRT_CONVERTER(stack);
+USE_TRT_CONVERTER(clip);
 #endif
 namespace paddle_infer {

@@ -307,7 +307,7 @@ class PD_INFER_DECL PaddlePredictor {
   /// This will save the IO copy for transferring inputs and outputs to predictor
   /// workspace
   /// and get some performance improvement.
-  /// To use it, one should call the AnalysisConfig.SwitchUseFeedFetchOp(true)
+  /// To use it, one should call the AnalysisConfig.SwitchUseFeedFetchOp(false)
   /// and then use the `GetInputTensor` and `GetOutputTensor`
   /// to directly write or read the input/output tensors.
   /// \return Whether the run is successful
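As a sketch of the zero-copy workflow this comment describes (the model path and the tensor names "x" and "out" below are placeholders, not part of this change; error handling omitted):

#include <functional>
#include <numeric>
#include <vector>

#include "paddle/fluid/inference/api/paddle_inference_api.h"

void ZeroCopyDemo() {
  paddle::AnalysisConfig config;
  config.SetModel("model_dir");        // placeholder model location
  config.SwitchUseFeedFetchOp(false);  // required before using ZeroCopyRun
  config.SwitchSpecifyInputNames(true);
  auto predictor = paddle::CreatePaddlePredictor(config);

  // Write the input tensor in place instead of feeding a copy.
  auto input = predictor->GetInputTensor("x");
  std::vector<float> in(1 * 3 * 224 * 224, 1.0f);
  input->Reshape({1, 3, 224, 224});
  input->copy_from_cpu(in.data());

  predictor->ZeroCopyRun();  // runs without the feed/fetch IO copies

  // Read the output tensor directly.
  auto output = predictor->GetOutputTensor("out");
  auto shape = output->shape();
  int numel = std::accumulate(shape.begin(), shape.end(), 1,
                              std::multiplies<int>());
  std::vector<float> out(numel);
  output->copy_to_cpu(out.data());
}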

@@ -4,7 +4,7 @@ nv_library(tensorrt_converter
 batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
 pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
 shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc
-emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc
+emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc
 DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS

@@ -0,0 +1,63 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace paddle {
+namespace framework {
+class Scope;
+
+namespace proto {
+class OpDesc;
+}  // namespace proto
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+/*
+ * ClipOp
+ */
+class ClipOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope, bool test_mode) override {
+#if IS_TRT_VERSION_GE(5130)
+    VLOG(3) << "convert a paddle clip op to tensorrt IActivationLayer.";
+    framework::OpDesc op_desc(op, nullptr);
+    // Declare inputs
+    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
+    float min = BOOST_GET_CONST(float, op_desc.GetAttr("min"));
+    float max = BOOST_GET_CONST(float, op_desc.GetAttr("max"));
+    auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Activation, *input,
+                                       nvinfer1::ActivationType::kCLIP);
+    layer->setAlpha(min);
+    layer->setBeta(max);
+    auto output_name = op_desc.Output("Out")[0];
+    RreplenishLayerAndOutput(layer, "clip", {output_name}, test_mode);
+#else
+    PADDLE_THROW(
+        platform::errors::Fatal("clip TRT converter is only supported on TRT "
+                                "5.1.3.0 or higher version."));
+#endif
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(clip, ClipOpConverter);
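For reference, TensorRT's kCLIP activation computes y = min(max(x, alpha), beta), so Paddle's clip attributes map directly onto alpha = min and beta = max. In raw TensorRT terms the converter above reduces to roughly the following sketch, assuming an existing nvinfer1::INetworkDefinition* network, an nvinfer1::ITensor* input, and floats min_value/max_value (all hypothetical names):

// kCLIP is available from TensorRT 5.1.3, hence the IS_TRT_VERSION_GE(5130)
// guard in the converter above.
auto* clip_layer =
    network->addActivation(*input, nvinfer1::ActivationType::kCLIP);
clip_layer->setAlpha(min_value);  // lower clipping bound (Paddle attr "min")
clip_layer->setBeta(max_value);   // upper clipping bound (Paddle attr "max")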

@@ -32,8 +32,10 @@ struct SimpleOpTypeSetTeller : public Teller {
 #if IS_TRT_VERSION_GE(5130)
     teller_set.insert("relu6");
     teller_set.insert("hard_sigmoid");
+    teller_set.insert("clip");
     int8_teller_set.insert("relu6");
     int8_teller_set.insert("hard_sigmoid");
+    int8_teller_set.insert("clip");
 #endif
 #if IS_TRT_VERSION_GE(6000)
     teller_set.insert("fused_embedding_eltwise_layernorm");
@@ -132,8 +134,9 @@ bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc,
       auto* var_desc = block->FindVar(var_name);
       const auto shape = var_desc->GetShape();
       if (shape.size() < 3) {
-        VLOG(1) << "matmul op dims < 3 not supported in tensorrt, but got dims "
-                << shape.size() << ", so jump it.";
+        VLOG(1)
+            << "matmul op dims < 3 not supported in tensorrt, but got dims "
+            << shape.size() << ", so jump it.";
         return false;
       }
     }

@@ -343,6 +343,11 @@ class TensorRTSubgraphPassHardSigmoidTest(TensorRTSubgraphPassActivationTest):
         return fluid.layers.hard_sigmoid(x)
 
 
+class TensorRTSubgraphPassClipTest(TensorRTSubgraphPassActivationTest):
+    def append_act(self, x):
+        return fluid.layers.clip(x, 0, 1)
+
+
 class TensorRTSubgraphPassTanhTest(TensorRTSubgraphPassActivationTest):
     def append_act(self, x):
         return fluid.layers.tanh(x)
