|
|
|
@ -65,13 +65,21 @@ class HardSwishOpConverter : public OpConverter {
|
|
|
|
|
const float offset = op_desc.HasAttr("offset")
|
|
|
|
|
? BOOST_GET_CONST(float, op_desc.GetAttr("offset"))
|
|
|
|
|
: 3.0f;
|
|
|
|
|
|
|
|
|
|
nvinfer1::ILayer* layer = nullptr;
|
|
|
|
|
|
|
|
|
|
plugin::HardSwishPlugin* plugin =
|
|
|
|
|
new plugin::HardSwishPlugin(threshold, scale, offset);
|
|
|
|
|
layer = engine_->AddPlugin(&input, input_num, plugin);
|
|
|
|
|
|
|
|
|
|
if (threshold == scale) {
|
|
|
|
|
auto* hsig_layer = TRT_ENGINE_ADD_LAYER(
|
|
|
|
|
engine_, Activation, *input, nvinfer1::ActivationType::kHARD_SIGMOID);
|
|
|
|
|
hsig_layer->setAlpha(1.0 / scale);
|
|
|
|
|
hsig_layer->setBeta(offset / scale);
|
|
|
|
|
nvinfer1::IElementWiseLayer* eltwise_layer = TRT_ENGINE_ADD_LAYER(
|
|
|
|
|
engine_, ElementWise, *input, *(hsig_layer->getOutput(0)),
|
|
|
|
|
nvinfer1::ElementWiseOperation::kPROD);
|
|
|
|
|
layer = eltwise_layer;
|
|
|
|
|
} else {
|
|
|
|
|
plugin::HardSwishPlugin* plugin =
|
|
|
|
|
new plugin::HardSwishPlugin(threshold, scale, offset);
|
|
|
|
|
layer = engine_->AddPlugin(&input, input_num, plugin);
|
|
|
|
|
}
|
|
|
|
|
auto output_name = op_desc.Output("Out")[0];
|
|
|
|
|
RreplenishLayerAndOutput(layer, "hard_swish", {output_name}, test_mode);
|
|
|
|
|
}
|
|
|
|
|