sigmoid mul fusion

pull/7802/head
zhengjun10 4 years ago
parent 181cdab640
commit 95d9c39f9d

@ -185,6 +185,7 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/layer_norm_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/batchmatmul_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/sigmoid_mul_fusion.cc
${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc

@ -6,6 +6,6 @@ inceptionv3.mindir
googlenet.mindir
retinaface.mindir
mobilefacenet.mindir
efficientnet.mindir
# efficientnet.mindir
resnext50.mindir
ocr_mobilenetV2.mindir

@ -1 +1 @@
efficientnet.mindir
# efficientnet.mindir

@ -3,7 +3,7 @@
1 retinaface.mindir
1 mobilefacenet.mindir
1 ocr_mobilenetV2.mindir
2 efficientnet.mindir
# 2 efficientnet.mindir
3 gender_res_large_deploy
3 ml_ocr_detect_20200305
3 hiai_cv_focusShootOCRModel_07

@ -42,6 +42,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/fusion/quant_dtype_cast_fusion.cc
../optimizer/fusion/layer_norm_fusion.cc
../optimizer/fusion/batchmatmul_fusion.cc
../optimizer/fusion/sigmoid_mul_fusion.cc
../optimizer/graph/weight_format_transform_pass.cc
../optimizer/graph/weight_format_hardcode_pass.cc
../optimizer/graph/clip_convert_activation_pass.cc

@ -27,6 +27,7 @@
#include "tools/optimizer/fusion/quant_dtype_cast_fusion.h"
#include "tools/optimizer/fusion/layer_norm_fusion.h"
#include "tools/optimizer/fusion/batchmatmul_fusion.h"
#include "tools/optimizer/fusion/sigmoid_mul_fusion.h"
#include "tools/optimizer/graph/identity_remove_pass.h"
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
#include "tools/optimizer/graph/weight_format_transform_pass.h"
@ -64,6 +65,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
pm->AddPass(std::make_shared<opt::ConvScaleFusion>());
pm->AddPass(std::make_shared<opt::LayerNormFusion>());
pm->AddPass(std::make_shared<opt::BatchMatMulFusion>());
pm->AddPass(std::make_shared<opt::SigmoidMulFusion>());
pm->AddPass(std::make_shared<opt::ConvActivationFusion>(true, "conv_relu", schema::PrimitiveType_Activation,
schema::ActivationType_RELU));
pm->AddPass(std::make_shared<opt::ConvActivationFusion>(true, "conv_relu6", schema::PrimitiveType_Activation,

@ -0,0 +1,68 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/sigmoid_mul_fusion.h"
#include <memory>
#include "src/ops/primitive_c.h"
#include "src/ops/activation.h"
#include "src/param_value_lite.h"
#include "schema/inner/model_generated.h"
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore::opt {
namespace {
bool IsActivationNode(const BaseRef &n) {
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
auto type = opt::GetCNodeType(n);
return type == schema::PrimitiveType_Activation;
}
return false;
}
bool IsMulNode(const BaseRef &n) {
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
auto type = opt::GetCNodeType(n);
return type == schema::PrimitiveType_Mul;
}
return false;
}
} // namespace
const BaseRef SigmoidMulFusion::DefinePattern() const {
auto input_var = std::make_shared<Var>();
auto activation_var = std::make_shared<CondVar>(IsActivationNode);
auto mul_var = std::make_shared<CondVar>(IsMulNode);
auto activation_input = VectorRef({activation_var, input_var});
return VectorRef({mul_var, input_var, activation_input});
}
// x * sigmoid(x) ->swish(x)
const AnfNodePtr SigmoidMulFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(node != nullptr);
auto mul_cnode = node->cast<CNodePtr>();
MS_ASSERT(mul_cnode != nullptr);
auto activation_cnode = mul_cnode->input(2)->cast<CNodePtr>();
MS_ASSERT(activation_cnode != nullptr);
// activation must sigmoid
auto primitive = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(activation_cnode->input(0));
auto activation_prim = utils::cast<std::shared_ptr<mindspore::lite::Activation>>(primitive);
if (activation_prim->GetType() != schema::ActivationType_SIGMOID) {
return nullptr;
}
activation_prim->SetType(schema::ActivationType_SWISH);
return activation_cnode;
}
} // namespace mindspore::opt

@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_SIGMOID_MUL_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_SIGMOID_MUL_FUSION_H_
#include "backend/optimizer/common/optimizer.h"
#include "tools/converter/converter_context.h"
namespace mindspore {
namespace opt {
class SigmoidMulFusion : public PatternProcessPass {
public:
explicit SigmoidMulFusion(bool multigraph = true) : PatternProcessPass("sigmoid_mul_fusion", multigraph) {}
~SigmoidMulFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_SIGMOID_MUL_FUSION_H_
Loading…
Cancel
Save