!13597 refactor normfusion

From: @wangzhe128
Reviewed-by: @hangangqiang, @zhang_xue_tong
Signed-off-by: @hangangqiang
pull/13597/MERGE
Committed by mindspore-ci-bot via Gitee
commit 8232628001
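
Summary of the change: the TF and ONNX normalization fusions are merged into a single norm_fusion source pair. The shared rewrite logic (CreateNormNode, GetNormTypeAndAxis, CheckPattern, Process) moves into a new NormFusion base class whose DefinePattern is pure virtual; the Is*Node predicates move out of the headers into the anonymous namespace of norm_fusion.cc; the fusion-success logs drop from INFO to DEBUG; and onnx_layer_norm_fusion.{cc,h} are deleted. An unrelated hunk also drops the int64 input path from the fp16 Cast kernel. A sketch of the resulting hierarchy, read off the header hunks below (constructors and members elided in this view):

    class NormFusion : public PatternProcessPass {      // shared matching/rewrite logic
     public:
      virtual const BaseRef DefinePattern() const = 0;  // supplied by each subclass
      const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
    };
    class TfNormFusion : public NormFusion {            // TF layer_norm/instance_norm pattern
     public:
      const BaseRef DefinePattern() const override;
    };
    class OnnxLayerNormFusion : public NormFusion {     // ONNX layer_norm pattern
     public:
      const BaseRef DefinePattern() const override;
    };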

@@ -121,17 +121,6 @@ int CastFp16CPUKernel::DoCast(int thread_id) {
MS_LOG(ERROR) << "Unsupported output data type " << output_data_type;
return RET_ERROR;
}
} else if (input_data_type == kNumberTypeInt64) {
switch (output_data_type) {
case kNumberTypeFloat16:
Int64ToFloat32(reinterpret_cast<int64_t *>(input->MutableData()) + offset,
reinterpret_cast<float *>(output_data) + offset, data_num);
break;
default:
MS_LOG(ERROR) << "Unsupported output data type " << output_data_type;
return RET_ERROR;
}
} else {
MS_LOG(ERROR) << "Unsupported input data type " << input_data_type;
return RET_ERROR;
@@ -147,5 +136,4 @@ int CastFp16CPUKernel::Run() {
}
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Cast, LiteKernelCreator<CastFp16CPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt64, PrimitiveType_Cast, LiteKernelCreator<CastFp16CPUKernel>)
} // namespace mindspore::kernel

@@ -233,8 +233,7 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/fusion/conv_bn_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/constant_folding_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/tf_norm_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/onnx_layer_norm_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/norm_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/batchmatmul_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/sigmoid_mul_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/conv_conv_fusion.cc

@@ -44,8 +44,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/fusion/conv_tuplegetitem_fusion.cc
../optimizer/fusion/constant_folding_fusion.cc
../optimizer/fusion/quant_dtype_cast_fusion.cc
../optimizer/fusion/tf_norm_fusion.cc
../optimizer/fusion/onnx_layer_norm_fusion.cc
../optimizer/fusion/norm_fusion.cc
../optimizer/fusion/batchmatmul_fusion.cc
../optimizer/fusion/sigmoid_mul_fusion.cc
../optimizer/fusion/conv_conv_fusion.cc

@@ -27,8 +27,7 @@
#include "tools/optimizer/fusion/conv_bn_fusion.h"
#include "tools/optimizer/fusion/conv_tuplegetitem_fusion.h"
#include "tools/optimizer/fusion/constant_folding_fusion.h"
#include "tools/optimizer/fusion/tf_norm_fusion.h"
#include "tools/optimizer/fusion/onnx_layer_norm_fusion.h"
#include "tools/optimizer/fusion/norm_fusion.h"
#include "tools/optimizer/fusion/batchmatmul_fusion.h"
#include "tools/optimizer/fusion/sigmoid_mul_fusion.h"
#include "tools/optimizer/fusion/conv_conv_fusion.h"

@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/tf_norm_fusion.h"
#include "tools/optimizer/fusion/norm_fusion.h"
#include <memory>
#include "ops/fusion/layer_norm_fusion.h"
#include "ops/fusion/reduce_fusion.h"
@@ -27,7 +27,61 @@
namespace mindspore {
namespace opt {
namespace {
lite::STATUS GetReduceAxes(const BaseRef &n, std::vector<int> *axes) {
inline bool IsAddNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimAddFusion);
}
return false;
}
inline bool IsSquaredDifferenceNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSquaredDifference);
}
return false;
}
inline bool IsRsqrtNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimRsqrt);
}
return false;
}
inline bool IsMulNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimMulFusion);
}
return false;
}
inline bool IsSubNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSubFusion);
}
return false;
}
inline bool IsPowNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimPowFusion);
}
return false;
}
inline bool IsSqrtNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSqrt);
}
return false;
}
inline bool IsDivNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimDiv) ||
CheckPrimitiveType(utils::cast<AnfNodePtr>(n), std::make_shared<Primitive>("DivFusion"));
}
return false;
}
STATUS GetReduceAxes(const BaseRef &n, std::vector<int> *axes) {
MS_ASSERT(n != nullptr);
if (utils::isa<ParameterPtr>(n)) {
auto axes_param = utils::cast<ParameterPtr>(n);
@@ -72,31 +126,9 @@ bool IsReduceNode(const EquivPtr &equiv, const VarPtr &input_prim, const VarPtr
}
} // namespace
const BaseRef TfNormFusion::DefinePattern() const {
VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_});
auto squared_diffference1 = std::make_shared<CondVar>(IsSquaredDifferenceNode);
VectorRef squared_diffference1_ref = VectorRef({squared_diffference1, input_, mean1_ref});
auto mul1 = std::make_shared<CondVar>(IsMulNode);
VectorRef mean2_ref = VectorRef({mean2_, squared_diffference1_ref, mean2_axes_});
auto add1 = std::make_shared<CondVar>(IsAddNode);
VectorRef add1_ref = VectorRef({add1, mean2_ref, epsilon_});
auto rsqrt1 = std::make_shared<CondVar>(IsRsqrtNode);
VectorRef rsqrt1_ref = VectorRef({rsqrt1, add1_ref});
auto mul2 = std::make_shared<CondVar>(IsMulNode);
VectorRef mul2_ref = VectorRef({mul2, rsqrt1_ref, gamma_});
VectorRef mul1_ref = VectorRef({mul1, input_, mul2_ref});
auto mul3 = std::make_shared<CondVar>(IsMulNode);
VectorRef mul3_ref = VectorRef({mul3, mean1_ref, mul2_ref});
auto sub1 = std::make_shared<CondVar>(IsSubNode);
VectorRef sub1_ref = VectorRef({sub1, beta_, mul3_ref});
auto add2 = std::make_shared<CondVar>(IsAddNode);
VectorRef add2_ref = VectorRef({add2, mul1_ref, sub1_ref});
return add2_ref;
}
CNodePtr TfNormFusion::CreateNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
const schema::PrimitiveType type, float epsilon, int begin_norm_axis,
int begin_params_axis) const {
CNodePtr NormFusion::CreateNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
const schema::PrimitiveType type, float epsilon, int begin_norm_axis,
int begin_params_axis) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(equiv != nullptr);
auto norm_primitive = std::make_unique<schema::PrimitiveT>();
@@ -128,9 +160,9 @@ CNodePtr TfNormFusion::CreateNormNode(const FuncGraphPtr &func_graph, const Equi
return new_node;
}
bool TfNormFusion::GetNormTypeAndAxis(const CNodePtr &input_cnode, const std::vector<int> &mean_axes,
const std::vector<int> &params_shape, schema::PrimitiveType *type,
int *begin_norm_axis, int *begin_params_axis) const {
bool NormFusion::GetNormTypeAndAxis(const CNodePtr &input_cnode, const std::vector<int> &mean_axes,
const std::vector<int> &params_shape, schema::PrimitiveType *type,
int *begin_norm_axis, int *begin_params_axis) const {
MS_ASSERT(input_cnode != nullptr);
MS_ASSERT(type != nullptr);
MS_ASSERT(begin_norm_axis != nullptr);
@@ -183,8 +215,8 @@ bool TfNormFusion::GetNormTypeAndAxis(const CNodePtr &input_cnode, const std::ve
return true;
}
bool TfNormFusion::CheckPattern(const EquivPtr &equiv, schema::PrimitiveType *type, float *epsilon,
int *begin_norm_axis, int *begin_params_axis) const {
bool NormFusion::CheckPattern(const EquivPtr &equiv, schema::PrimitiveType *type, float *epsilon, int *begin_norm_axis,
int *begin_params_axis) const {
MS_ASSERT(equiv != nullptr);
MS_ASSERT(epsilon != nullptr);
MS_ASSERT(type != nullptr);
@@ -251,8 +283,8 @@ bool TfNormFusion::CheckPattern(const EquivPtr &equiv, schema::PrimitiveType *ty
return true;
}
const AnfNodePtr TfNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
const AnfNodePtr NormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(node != nullptr);
MS_ASSERT(equiv != nullptr);
@@ -276,12 +308,48 @@ const AnfNodePtr TfNormFusion::Process(const FuncGraphPtr &func_graph, const Anf
norm_cnode->set_abstract(add2_cnode->abstract()->Clone());
if (type == schema::PrimitiveType_LayerNormFusion) {
norm_cnode->set_fullname_with_scope("layer_norm_" + add2_cnode->fullname_with_scope());
MS_LOG(INFO) << "layer_norm node:" << norm_cnode->fullname_with_scope() << " fusion success";
MS_LOG(DEBUG) << "layer_norm node:" << norm_cnode->fullname_with_scope() << " fusion success";
} else if (type == schema::PrimitiveType_InstanceNorm) {
norm_cnode->set_fullname_with_scope("instance_norm_" + add2_cnode->fullname_with_scope());
MS_LOG(INFO) << "instance_norm node:" << norm_cnode->fullname_with_scope() << " fusion success";
MS_LOG(DEBUG) << "instance_norm node:" << norm_cnode->fullname_with_scope() << " fusion success";
}
return norm_cnode;
}
const BaseRef TfNormFusion::DefinePattern() const {
VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_});
auto squared_diffference1 = std::make_shared<CondVar>(IsSquaredDifferenceNode);
VectorRef squared_diffference1_ref = VectorRef({squared_diffference1, input_, mean1_ref});
auto mul1 = std::make_shared<CondVar>(IsMulNode);
VectorRef mean2_ref = VectorRef({mean2_, squared_diffference1_ref, mean2_axes_});
auto add1 = std::make_shared<CondVar>(IsAddNode);
VectorRef add1_ref = VectorRef({add1, mean2_ref, epsilon_});
auto rsqrt1 = std::make_shared<CondVar>(IsRsqrtNode);
VectorRef rsqrt1_ref = VectorRef({rsqrt1, add1_ref});
auto mul2 = std::make_shared<CondVar>(IsMulNode);
VectorRef mul2_ref = VectorRef({mul2, rsqrt1_ref, gamma_});
VectorRef mul1_ref = VectorRef({mul1, input_, mul2_ref});
auto mul3 = std::make_shared<CondVar>(IsMulNode);
VectorRef mul3_ref = VectorRef({mul3, mean1_ref, mul2_ref});
auto sub1 = std::make_shared<CondVar>(IsSubNode);
VectorRef sub1_ref = VectorRef({sub1, beta_, mul3_ref});
auto add2 = std::make_shared<CondVar>(IsAddNode);
VectorRef add2_ref = VectorRef({add2, mul1_ref, sub1_ref});
return add2_ref;
}
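Note (editorial gloss on the pattern above, not text from the commit): with mean1 = mean(input) over mean1_axes and mean2 = mean((input - mean1)^2) over mean2_axes, the matched subgraph computes

    mul2 = rsqrt(mean2 + epsilon) * gamma
    add2 = input * mul2 + (beta - mean1 * mul2)
         = gamma * (input - mean1) / sqrt(mean2 + epsilon) + beta

i.e. the distributed form in which TF graphs emit layer/instance normalization.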
const BaseRef OnnxLayerNormFusion::DefinePattern() const {
VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_});
VectorRef sub1_ref = VectorRef({std::make_shared<CondVar>(IsSubNode), input_, mean1_ref});
VectorRef sub2_ref = VectorRef({std::make_shared<CondVar>(IsSubNode), input_, mean1_ref});
VectorRef pow_ref = VectorRef({std::make_shared<CondVar>(IsPowNode), sub2_ref, std::make_shared<Var>()});
VectorRef mean2_ref = VectorRef({mean2_, pow_ref, mean2_axes_});
VectorRef add1_ref = VectorRef({std::make_shared<CondVar>(IsAddNode), mean2_ref, epsilon_});
VectorRef sqrt_ref = VectorRef({std::make_shared<CondVar>(IsSqrtNode), add1_ref});
VectorRef div_ref = VectorRef({std::make_shared<CondVar>(IsDivNode), sub1_ref, sqrt_ref});
VectorRef mul_ref = VectorRef({std::make_shared<CondVar>(IsMulNode), gamma_, div_ref});
VectorRef add2_ref = VectorRef({std::make_shared<CondVar>(IsAddNode), mul_ref, beta_});
return add2_ref;
}
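Note (editorial gloss, not text from the commit): the ONNX pattern matches the undistributed form directly,

    add2 = gamma * (input - mean1) / sqrt(mean2 + epsilon) + beta

where the Pow exponent is bound by an anonymous Var, and the subtraction is matched twice (sub1_ref feeding the Div, sub2_ref feeding the Pow) so that graphs whose exporter materializes two separate Sub nodes still match.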
} // namespace opt
} // namespace mindspore

@@ -28,10 +28,10 @@
namespace mindspore {
namespace opt {
/// fuse layer_norm, instance_norm into one operator
class TfNormFusion : public PatternProcessPass {
/// fuse layer_norm or instance_norm into one operator
class NormFusion : public PatternProcessPass {
public:
explicit TfNormFusion(const std::string &name = "tf_norm_fusion", bool multigraph = true)
explicit NormFusion(const std::string &name = "tf_norm_fusion", bool multigraph = true)
: PatternProcessPass(name, multigraph) {
input_ = std::make_shared<Var>();
mean1_ = std::make_shared<Var>();
@@ -43,8 +43,8 @@ class TfNormFusion : public PatternProcessPass {
epsilon_ = std::make_shared<Var>();
}
~TfNormFusion() override = default;
const BaseRef DefinePattern() const override;
~NormFusion() override = default;
virtual const BaseRef DefinePattern() const = 0;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
@@ -67,40 +67,20 @@ class TfNormFusion : public PatternProcessPass {
VarPtr epsilon_ = nullptr;
};
inline bool IsAddNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimAddFusion);
}
return false;
}
inline bool IsSquaredDifferenceNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSquaredDifference);
}
return false;
}
inline bool IsRsqrtNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimRsqrt);
}
return false;
}
/// fuse tf layer_norm or instance_norm into one operator
class TfNormFusion : public NormFusion {
public:
~TfNormFusion() override = default;
const BaseRef DefinePattern() const override;
};
inline bool IsMulNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimMulFusion);
}
return false;
}
/// fuse onnx layer_norm into one operator
class OnnxLayerNormFusion : public NormFusion {
public:
~OnnxLayerNormFusion() override = default;
const BaseRef DefinePattern() const override;
};
inline bool IsSubNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSubFusion);
}
return false;
}
} // namespace opt
} // namespace mindspore
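
For context, a minimal sketch of how the two concrete passes could be scheduled in the converter; the PassManager/AddPass names mirror the existing fusion-registration style and are assumptions here, not part of this diff:

    // Hypothetical registration, mirroring how other fusion passes are added.
    auto fusion_pm = std::make_shared<opt::PassManager>("fusion pass manager", false);
    fusion_pm->AddPass(std::make_shared<opt::TfNormFusion>());
    fusion_pm->AddPass(std::make_shared<opt::OnnxLayerNormFusion>());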

@@ -1,37 +0,0 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/onnx_layer_norm_fusion.h"
#include <memory>
#include "ops/rsqrt.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore {
namespace opt {
const BaseRef OnnxLayerNormFusion::DefinePattern() const {
VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_});
VectorRef sub1_ref = VectorRef({std::make_shared<CondVar>(IsSubNode), input_, mean1_ref});
VectorRef sub2_ref = VectorRef({std::make_shared<CondVar>(IsSubNode), input_, mean1_ref});
VectorRef pow_ref = VectorRef({std::make_shared<CondVar>(IsPowNode), sub2_ref, std::make_shared<Var>()});
VectorRef mean2_ref = VectorRef({mean2_, pow_ref, mean2_axes_});
VectorRef add1_ref = VectorRef({std::make_shared<CondVar>(IsAddNode), mean2_ref, epsilon_});
VectorRef sqrt_ref = VectorRef({std::make_shared<CondVar>(IsSqrtNode), add1_ref});
VectorRef div_ref = VectorRef({std::make_shared<CondVar>(IsDivNode), sub1_ref, sqrt_ref});
VectorRef mul_ref = VectorRef({std::make_shared<CondVar>(IsMulNode), gamma_, div_ref});
VectorRef add2_ref = VectorRef({std::make_shared<CondVar>(IsAddNode), mul_ref, beta_});
return add2_ref;
}
} // namespace opt
} // namespace mindspore

@@ -1,60 +0,0 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_LAYER_NORM_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_LAYER_NORM_FUSION_H_
#include <vector>
#include <memory>
#include <string>
#include "tools/optimizer/fusion/tf_norm_fusion.h"
#include "backend/optimizer/common/optimizer.h"
#include "utils/utils.h"
namespace mindspore {
namespace opt {
class OnnxLayerNormFusion : public TfNormFusion {
public:
explicit OnnxLayerNormFusion(const std::string &name = "onnx_layer_norm_fusion", bool multigraph = true)
: TfNormFusion(name, multigraph) {}
~OnnxLayerNormFusion() override = default;
const BaseRef DefinePattern() const override;
};
inline bool IsPowNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimPowFusion);
}
return false;
}
inline bool IsSqrtNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSqrt);
}
return false;
}
inline bool IsDivNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimDiv) ||
CheckPrimitiveType(utils::cast<AnfNodePtr>(n), std::make_shared<Primitive>("DivFusion"));
}
return false;
}
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_LAYER_NORM_FUSION_H_