!13996 [lite]gelu fusion and fp16 slice

From: @xu_anyue
Reviewed-by: @hangangqiang, @HilbertDavid
Signed-off-by: @hangangqiang
pull/13996/MERGE
Committed by mindspore-ci-bot via Gitee
commit 27f7bfef50

@@ -76,18 +76,12 @@ int SliceCPUKernel::Run() {
}
int SliceCPUKernel::Run() {
- auto ret = PreProcess();
- if (ret != RET_OK) {
- MS_LOG(ERROR) << "PreProcess fail!ret: " << ret;
- return ret;
- }
if (param_->size_[1] < op_parameter_->thread_num_) {
DoSliceNoParallel(in_tensors_.at(0)->data_c(), out_tensors_.at(0)->data_c(), param_,
lite::DataTypeSize(in_tensors_.at(0)->data_type()));
return RET_OK;
}
- ret = ParallelLaunch(this->context_->thread_pool_, SliceLaunch, this, op_parameter_->thread_num_);
+ auto ret = ParallelLaunch(this->context_->thread_pool_, SliceLaunch, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "slice launch fail!ret: " << ret;
return RET_ERROR;
@@ -96,6 +90,5 @@ int SliceCPUKernel::Run() {
}
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_SliceFusion, LiteKernelCreator<SliceCPUKernel>)
- REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_SliceFusion, LiteKernelCreator<SliceCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SliceFusion, LiteKernelCreator<SliceCPUKernel>)
} // namespace mindspore::kernel
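Note: the float16 registration removed above is not dropped. The new slice_fp16.cc added below re-registers the same (kCPU, kNumberTypeFloat16, PrimitiveType_SliceFusion) key, now bound to a dedicated fp16 kernel:

REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_SliceFusion, LiteKernelCreator<SliceFp16CPUKernel>)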

@@ -0,0 +1,75 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp16/slice_fp16.h"
#include "src/kernel_registry.h"
#include "nnacl/base/slice_base.h"
#include "nnacl/fp16/cast_fp16.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_SliceFusion;
namespace mindspore::kernel {
int SliceFp16Launch(void *cdata, int task_id) {
if (cdata == nullptr) {
MS_LOG(ERROR) << "Input cdata is nullptr!";
return RET_ERROR;
}
auto kernel = reinterpret_cast<SliceFp16CPUKernel *>(cdata);
return kernel->SliceFp16ParallelRun(task_id);
}
SliceFp16CPUKernel::~SliceFp16CPUKernel() {
if (input_data_ != nullptr) {
context_->allocator->Free(input_data_);
input_data_ = nullptr;
}
}
int SliceFp16CPUKernel::Init() {
auto input_tensor = in_tensors_.at(0);
if (input_tensor->data_type() == kNumberTypeFloat32 && input_tensor->data_c() != nullptr) {
input_data_ =
reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum());
}
return SliceCPUKernel::Init();
}
int SliceFp16CPUKernel::SliceFp16ParallelRun(int thread_id) {
void *input_data = input_data_ == nullptr ? in_tensors_.at(0)->data_c() : input_data_;
DoSlice(input_data, out_tensors_.at(0)->data_c(), param_, thread_id, lite::DataTypeSize(kNumberTypeFloat16));
return RET_OK;
}
int SliceFp16CPUKernel::Run() {
void *input_data = input_data_ == nullptr ? in_tensors_.at(0)->data_c() : input_data_;
if (param_->size_[1] < op_parameter_->thread_num_) {
DoSliceNoParallel(input_data, out_tensors_.at(0)->data_c(), param_, lite::DataTypeSize(kNumberTypeFloat16));
return RET_OK;
}
auto ret = ParallelLaunch(this->context_->thread_pool_, SliceFp16Launch, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "fp16 slice launch fail!ret: " << ret;
return RET_ERROR;
}
return RET_OK;
}
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_SliceFusion, LiteKernelCreator<SliceFp16CPUKernel>)
} // namespace mindspore::kernel
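For readers unfamiliar with the nnacl helper used in Init() above, Float32ToFloat16 is, in effect, an element-wise narrowing cast. A minimal sketch of its semantics (an assumption for illustration: the shipped routine behind nnacl/fp16/cast_fp16.h is equivalent but typically NEON-vectorized, with float16_t being the ARM half-precision type):

// Sketch only, not the shipped implementation.
void Float32ToFloat16Sketch(const float *input, float16_t *output, int number) {
  for (int i = 0; i < number; ++i) {
    output[i] = (float16_t)input[i];  // narrowing cast, round-to-nearest
  }
}

This conversion cost is why SliceFp16CPUKernel::Init() casts a constant fp32 input once and caches the result in input_data_, rather than re-casting on every Run().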

@@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SLICE_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SLICE_FP16_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/slice_base.h"
namespace mindspore::kernel {
class SliceFp16CPUKernel : public SliceCPUKernel {
public:
SliceFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: SliceCPUKernel(parameter, inputs, outputs, ctx) {}
~SliceFp16CPUKernel() override;
int Init() override;
int Run() override;
int SliceFp16ParallelRun(int thread_id);
private:
float16_t *input_data_ = nullptr;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SLICE_FP16_H_

@@ -243,6 +243,9 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/matmul_add_fusion.cc
+ ${LITE_DIR}/tools/optimizer/fusion/gelu_fusion.cc
+ ${LITE_DIR}/tools/optimizer/fusion/tf_gelu_fusion.cc
+ ${LITE_DIR}/tools/optimizer/fusion/onnx_gelu_fusion.cc
${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc

@@ -98,7 +98,7 @@ void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) {
MS_LOG(ERROR) << "value node is invalid.";
return;
}
- if (value_node->value() != nullptr && (opt::CheckPrimitiveType(make_tuple_node, opt::kPrimMakeTuple) ||
+ if (value_node->value() != nullptr && (opt::CheckPrimitiveType(make_tuple_node, prim::kPrimMakeTuple) ||
opt::CheckPrimitiveType(make_tuple_node, opt::kPrimMakeTupleV2))) {
has_make_tuple = true;
for (size_t j = 1; j < make_tuple_node->inputs().size(); ++j) {
@@ -372,7 +372,7 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
ret = RET_MEMORY_FAILED;
break;
}
- if (opt::CheckPrimitiveType(cnode, opt::kPrimReturn)) {
+ if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) {
node->name = mindspore::ops::kNameReturn;
ret = SetGraphoutputIndex(cnode, subgraph_index, meta_graphT, sub_graphT, node.get());
if (ret != RET_OK) {

@@ -53,6 +53,9 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/fusion/tf_bidirection_gru_fusion.cc
../optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
../optimizer/fusion/matmul_add_fusion.cc
+ ../optimizer/fusion/gelu_fusion.cc
+ ../optimizer/fusion/tf_gelu_fusion.cc
+ ../optimizer/fusion/onnx_gelu_fusion.cc
../optimizer/graph/weight_format_transform_pass.cc
../optimizer/graph/weight_format_hardcode_pass.cc
../optimizer/graph/clip_convert_activation_pass.cc

@@ -37,6 +37,8 @@
#include "tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h"
#include "tools/optimizer/fusion/matmul_add_fusion.h"
#include "tools/optimizer/graph/primitive_adjust_pass.h"
+ #include "tools/optimizer/fusion/tf_gelu_fusion.h"
+ #include "tools/optimizer/fusion/onnx_gelu_fusion.h"
#include "tools/optimizer/graph/mindir_adjust_pass.h"
#include "tools/optimizer/graph/redundant_op_remove_pass.h"
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
@@ -89,6 +91,8 @@ int AnfTransform::AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &opti
fusion_pm->AddPass(std::make_shared<opt::TfliteLstmCellFusion>());
fusion_pm->AddPass(std::make_shared<opt::TfLstmCellFusion>());
fusion_pm->AddPass(std::make_shared<opt::TfBidirectionGruFusion>());
+ fusion_pm->AddPass(std::make_shared<opt::TfGeLUFusion>());
+ fusion_pm->AddPass(std::make_shared<opt::OnnxGeLUFusion>());
}
if (config->fmk == lite::converter::FmkType_MS) {
auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>();

@@ -54,10 +54,10 @@ bool IsRealKernel(const AnfNodePtr &node) {
auto input = cnode->inputs()[0];
bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
IsPrimitive(input, prim::kPrimTensorSummary) ||
- IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, kPrimMakeTuple) ||
+ IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
- IsPrimitive(input, kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
+ IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
return !is_virtual_node;
}
@@ -335,7 +335,7 @@ bool IsRealCNodeKernel(const AnfNodePtr &node) {
return false;
}
// return considered as a real node
- if (CheckPrimitiveType(node, kPrimReturn)) {
+ if (CheckPrimitiveType(node, prim::kPrimReturn)) {
return true;
}
return IsRealKernel(node);

@@ -35,8 +35,8 @@ using mindspore::lite::RET_OK;
using mindspore::lite::STATUS;
namespace mindspore {
namespace opt {
- inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("Return");
- inline const PrimitivePtr kPrimMakeTuple = std::make_shared<Primitive>("MakeTuple");
+ inline const PrimitivePtr kPrimDivFusion = std::make_shared<Primitive>("DivFusion");
+ inline const PrimitivePtr kPrimErf = std::make_shared<Primitive>("Erf");
inline const PrimitivePtr kPrimMakeTupleV2 = std::make_shared<Primitive>("make_tuple");
inline const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("Identity");
std::vector<int> CastToInt(const ValuePtr &value);
@@ -145,6 +145,15 @@ ParameterPtr BuildIntVec2DParameterNode(const FuncGraphPtr &func_graph, const st
ParameterPtr BuildFloatValueParameterNode(const FuncGraphPtr &func_graph, const float &data,
const std::string &node_name);
+ template <const PrimitivePtr *prim = nullptr>
+ inline bool IsSpecifiedNode(const BaseRef &n) {
+ if (utils::isa<AnfNodePtr>(n)) {
+ auto anf_node = utils::cast<AnfNodePtr>(n);
+ return CheckPrimitiveType(anf_node, *prim);
+ }
+ return false;
+ }
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_GLLO_UTILS_H_
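A usage sketch for the new IsSpecifiedNode template, mirroring how the norm-fusion hunks below consume it. One caveat worth noting: the template dereferences prim unconditionally, so it must be instantiated with the address of a real primitive and never left at the nullptr default:

// Pattern leaf matching any MulFusion CNode (taken verbatim from norm_fusion.cc below):
auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);

This single template replaces the eight near-identical IsAddNode/IsMulNode/IsSubNode/... helpers deleted from norm_fusion.cc in this commit.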

@@ -0,0 +1,85 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/gelu_fusion.h"
#include <memory>
#include <string>
#include "ops/fusion/activation.h"
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore {
namespace opt {
CNodePtr GeLUFusion::CreateGeLUNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(node != nullptr);
auto gelu_prim = std::make_shared<ops::Activation>();
gelu_prim->set_activation_type(mindspore::GELU);
auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]);
MS_ASSERT(input_node != nullptr);
auto gelu_cnode = func_graph->NewCNode(gelu_prim, {input_node});
gelu_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_gelu");
gelu_cnode->set_abstract(node->abstract()->Clone());
return gelu_cnode;
}
const float GeLUFusion::GetParameterValue(const EquivPtr &equiv, const VarPtr &input) const {
MS_ASSERT(equiv != nullptr);
MS_ASSERT(input != nullptr);
float value = -1;
auto node = utils::cast<AnfNodePtr>((*equiv)[input]);
if (node == nullptr || !utils::isa<ParameterPtr>(node)) {
return value;
}
auto parameter_node = node->cast<ParameterPtr>();
if (!parameter_node->has_default() || parameter_node->default_param() == nullptr) {
return value;
}
auto param_value_lite = parameter_node->default_param()->cast<ParamValueLitePtr>();
if (param_value_lite == nullptr) {
return value;
}
if (param_value_lite->tensor_type() != kNumberTypeFloat32 && param_value_lite->tensor_type() != kNumberTypeFloat) {
return value;
}
if (param_value_lite->tensor_size() != sizeof(float)) {
return value;
}
return *static_cast<float *>(param_value_lite->tensor_addr());
}
const AnfNodePtr GeLUFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(node != nullptr);
MS_ASSERT(equiv != nullptr);
MS_LOG(DEBUG) << "gelu_fusion pass";
if (!utils::isa<CNodePtr>(node)) {
return nullptr;
}
if (!CheckPattern(equiv)) {
return nullptr;
}
auto cnode = CreateGeLUNode(func_graph, node, equiv);
if (cnode == nullptr) {
MS_LOG(DEBUG) << "new gelu node failed.";
return nullptr;
}
return cnode;
}
} // namespace opt
} // namespace mindspore
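For reference, the two concrete passes added below (OnnxGeLUFusion and TfGeLUFusion) recognize numeric instantiations of the same definition, which in LaTeX reads:

\mathrm{gelu}(x) = x\,\Phi(x) = \frac{x}{2}\left[1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right] \approx \frac{x}{2}\left[1 + \tanh\!\left(\sqrt{\tfrac{2}{\pi}}\left(x + 0.044715\,x^{3}\right)\right)\right]

GetParameterValue() above is the hook that lets each subclass compare the constants baked into a matched graph against these closed-form values.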

@@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_GELU_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_GELU_FUSION_H_
#include <vector>
#include <memory>
#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore {
namespace opt {
class GeLUFusion : public PatternProcessPass {
public:
explicit GeLUFusion(const std::string &name = "gelu_fusion", bool multigraph = true)
: PatternProcessPass(name, multigraph), input_(std::make_shared<Var>()) {}
~GeLUFusion() override = default;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
protected:
virtual bool CheckPattern(const EquivPtr &equiv) const = 0;
const float GetParameterValue(const EquivPtr &equiv, const VarPtr &input) const;
VarPtr input_ = nullptr;
private:
CNodePtr CreateGeLUNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv) const;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_GELU_FUSION_H_
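Usage: as the anf_transform.cc hunk earlier in this commit shows, the two concrete subclasses are simply appended to the converter's fusion pass manager:

fusion_pm->AddPass(std::make_shared<opt::TfGeLUFusion>());
fusion_pm->AddPass(std::make_shared<opt::OnnxGeLUFusion>());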

@@ -20,19 +20,19 @@
namespace mindspore {
namespace opt {
namespace {
- constexpr size_t AddInputSize = 3;
- constexpr size_t MatMulInputSize = 3;
+ constexpr size_t kAddInputSize = 3;
+ constexpr size_t kMatMulInputSize = 3;
bool CheckAndGetMatMulIndex(const CNodePtr &cnode, size_t *index) {
MS_ASSERT(cnode != nullptr);
MS_ASSERT(index != nullptr);
- if (cnode->size() != AddInputSize) {
+ if (cnode->size() != kAddInputSize) {
return false;
}
size_t matmul_index = 0;
for (size_t i = 1; i < cnode->size(); ++i) {
if (CheckPrimitiveType(cnode->input(i), prim::kPrimMatMul)) {
auto matmul_cnode = cnode->input(i)->cast<CNodePtr>();
- if (matmul_cnode->size() > MatMulInputSize) {
+ if (matmul_cnode->size() > kMatMulInputSize) {
continue;
}
matmul_index = i;
@@ -63,7 +63,7 @@ bool MatMulAddFusion::Run(const FuncGraphPtr &func_graph) {
continue;
}
auto matmul_cnode = cnode->input(index)->cast<CNodePtr>();
- auto bias_node = cnode->input(AddInputSize - index);
+ auto bias_node = cnode->input(kAddInputSize - index);
if (!utils::isa<Parameter>(bias_node) || !bias_node->cast<ParameterPtr>()->default_param()) {
continue;
}

@@ -17,7 +17,6 @@
#include <memory>
#include "ops/fusion/layer_norm_fusion.h"
#include "ops/fusion/reduce_fusion.h"
- #include "ops/rsqrt.h"
#include "mindspore/core/ops/instance_norm.h"
#include "src/param_value_lite.h"
#include "utils/utils.h"
@@ -27,60 +26,6 @@
namespace mindspore {
namespace opt {
namespace {
- inline bool IsAddNode(const BaseRef &n) {
- if (utils::isa<AnfNodePtr>(n)) {
- return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimAddFusion);
- }
- return false;
- }
- inline bool IsSquaredDifferenceNode(const BaseRef &n) {
- if (utils::isa<AnfNodePtr>(n)) {
- return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSquaredDifference);
- }
- return false;
- }
- inline bool IsRsqrtNode(const BaseRef &n) {
- if (utils::isa<AnfNodePtr>(n)) {
- return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimRsqrt);
- }
- return false;
- }
- inline bool IsMulNode(const BaseRef &n) {
- if (utils::isa<AnfNodePtr>(n)) {
- return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimMulFusion);
- }
- return false;
- }
- inline bool IsSubNode(const BaseRef &n) {
- if (utils::isa<AnfNodePtr>(n)) {
- return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSubFusion);
- }
- return false;
- }
- inline bool IsPowNode(const BaseRef &n) {
- if (utils::isa<AnfNodePtr>(n)) {
- return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimPowFusion);
- }
- return false;
- }
- inline bool IsSqrtNode(const BaseRef &n) {
- if (utils::isa<AnfNodePtr>(n)) {
- return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSqrt);
- }
- return false;
- }
- inline bool IsDivNode(const BaseRef &n) {
- if (utils::isa<AnfNodePtr>(n)) {
- return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimDiv) ||
- CheckPrimitiveType(utils::cast<AnfNodePtr>(n), std::make_shared<Primitive>("DivFusion"));
- }
- return false;
- }
STATUS GetReduceAxes(const BaseRef &n, std::vector<int> *axes) {
MS_ASSERT(node != nullptr);
if (utils::isa<ParameterPtr>(n)) {
@@ -195,7 +140,7 @@ bool NormFusion::GetNormTypeAndAxis(const CNodePtr &input_cnode, const std::vect
}
}
if (mean_axes.back() >= 0 && mean_axes.back() + 1 != static_cast<int>(shape.size())) {
- MS_LOG(DEBUG) << "mean node is not reduce to last axis";
+ MS_LOG(DEBUG) << "mean node is not reduce to last axis.";
return false;
}
@@ -318,37 +263,41 @@ const AnfNodePtr NormFusion::Process(const FuncGraphPtr &func_graph, const AnfNo
const BaseRef TfNormFusion::DefinePattern() const {
VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_});
- auto squared_diffference1 = std::make_shared<CondVar>(IsSquaredDifferenceNode);
+ auto squared_diffference1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSquaredDifference>);
VectorRef squared_diffference1_ref = VectorRef({squared_diffference1, input_, mean1_ref});
- auto mul1 = std::make_shared<CondVar>(IsMulNode);
+ auto mul1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
VectorRef mean2_ref = VectorRef({mean2_, squared_diffference1_ref, mean2_axes_});
- auto add1 = std::make_shared<CondVar>(IsAddNode);
+ auto add1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
VectorRef add1_ref = VectorRef({add1, mean2_ref, epsilon_});
- auto rsqrt1 = std::make_shared<CondVar>(IsRsqrtNode);
+ auto rsqrt1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimRsqrt>);
VectorRef rsqrt1_ref = VectorRef({rsqrt1, add1_ref});
- auto mul2 = std::make_shared<CondVar>(IsMulNode);
+ auto mul2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
VectorRef mul2_ref = VectorRef({mul2, rsqrt1_ref, gamma_});
VectorRef mul1_ref = VectorRef({mul1, input_, mul2_ref});
- auto mul3 = std::make_shared<CondVar>(IsMulNode);
+ auto mul3 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
VectorRef mul3_ref = VectorRef({mul3, mean1_ref, mul2_ref});
- auto sub1 = std::make_shared<CondVar>(IsSubNode);
+ auto sub1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>);
VectorRef sub1_ref = VectorRef({sub1, beta_, mul3_ref});
- auto add2 = std::make_shared<CondVar>(IsAddNode);
+ auto add2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
VectorRef add2_ref = VectorRef({add2, mul1_ref, sub1_ref});
return add2_ref;
}
const BaseRef OnnxLayerNormFusion::DefinePattern() const {
VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_});
- VectorRef sub1_ref = VectorRef({std::make_shared<CondVar>(IsSubNode), input_, mean1_ref});
- VectorRef sub2_ref = VectorRef({std::make_shared<CondVar>(IsSubNode), input_, mean1_ref});
- VectorRef pow_ref = VectorRef({std::make_shared<CondVar>(IsPowNode), sub2_ref, std::make_shared<Var>()});
+ VectorRef sub1_ref =
+ VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>), input_, mean1_ref});
+ VectorRef sub2_ref =
+ VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>), input_, mean1_ref});
+ VectorRef pow_ref =
+ VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimPowFusion>), sub2_ref, std::make_shared<Var>()});
VectorRef mean2_ref = VectorRef({mean2_, pow_ref, mean2_axes_});
- VectorRef add1_ref = VectorRef({std::make_shared<CondVar>(IsAddNode), mean2_ref, epsilon_});
- VectorRef sqrt_ref = VectorRef({std::make_shared<CondVar>(IsSqrtNode), add1_ref});
- VectorRef div_ref = VectorRef({std::make_shared<CondVar>(IsDivNode), sub1_ref, sqrt_ref});
- VectorRef mul_ref = VectorRef({std::make_shared<CondVar>(IsMulNode), gamma_, div_ref});
- VectorRef add2_ref = VectorRef({std::make_shared<CondVar>(IsAddNode), mul_ref, beta_});
+ VectorRef add1_ref =
+ VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), mean2_ref, epsilon_});
+ VectorRef sqrt_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSqrt>), add1_ref});
+ VectorRef div_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&kPrimDivFusion>), sub1_ref, sqrt_ref});
+ VectorRef mul_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), gamma_, div_ref});
+ VectorRef add2_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), mul_ref, beta_});
return add2_ref;
}
} // namespace opt

@@ -14,8 +14,8 @@
* limitations under the License.
*/
- #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_
- #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_
+ #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_NORM_FUSION_H_
+ #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_NORM_FUSION_H_
#include <vector>
#include <memory>
@@ -31,7 +31,7 @@ namespace opt {
/// fuse layer_norm or instance_norm into one operator
class NormFusion : public PatternProcessPass {
public:
- explicit NormFusion(const std::string &name = "tf_norm_fusion", bool multigraph = true)
+ explicit NormFusion(const std::string &name = "norm_fusion", bool multigraph = true)
: PatternProcessPass(name, multigraph) {
input_ = std::make_shared<Var>();
mean1_ = std::make_shared<Var>();
@@ -44,7 +44,6 @@ class NormFusion : public PatternProcessPass {
}
~NormFusion() override = default;
- virtual const BaseRef DefinePattern() const = 0;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
@@ -70,6 +69,9 @@ class NormFusion : public PatternProcessPass {
/// fuse tf layer_norm or instance_norm into one operator
class TfNormFusion : public NormFusion {
public:
+ explicit TfNormFusion(const std::string &name = "tf_norm_fusion", bool multigraph = true)
+ : NormFusion(name, multigraph) {}
+ ~TfNormFusion() override = default;
const BaseRef DefinePattern() const override;
};
@@ -77,11 +79,13 @@ class TfNormFusion : public NormFusion {
/// fuse onnx layer_norm into one operator
class OnnxLayerNormFusion : public NormFusion {
public:
+ explicit OnnxLayerNormFusion(const std::string &name = "onnx_layer_norm_fusion", bool multigraph = true)
+ : NormFusion(name, multigraph) {}
+ ~OnnxLayerNormFusion() override = default;
const BaseRef DefinePattern() const override;
};
} // namespace opt
} // namespace mindspore
- #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_LAYER_NORM_FUSION_H_
+ #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_NORM_FUSION_H_

@@ -0,0 +1,55 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/onnx_gelu_fusion.h"
namespace mindspore {
namespace opt {
namespace {
constexpr float DIFF_THRESHOLD = 0.0001;
constexpr float DIV_Y = 1.41421;
constexpr float ADD_Y = 1.0;
constexpr float MUL1_y = 0.5;
} // namespace
// gelu(x) = 1/2 * x * [1 + erf(x / sqrt(2))]
const BaseRef OnnxGeLUFusion::DefinePattern() const {
VectorRef div_ref({std::make_shared<CondVar>(IsSpecifiedNode<&kPrimDivFusion>), input_, div_y_});
VectorRef erf_ref({std::make_shared<CondVar>(IsSpecifiedNode<&kPrimErf>), div_ref});
VectorRef add_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), erf_ref, add_y_});
VectorRef mul1_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), input_, mul1_y_});
VectorRef mul2_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), mul1_ref, add_ref});
return mul2_ref;
}
bool OnnxGeLUFusion::CheckPattern(const EquivPtr &equiv) const {
MS_ASSERT(equiv != nullptr);
float div_y = GetParameterValue(equiv, div_y_);
if (div_y < 0 || fabs(div_y - DIV_Y) > DIFF_THRESHOLD) {
return false;
}
float add_y = GetParameterValue(equiv, add_y_);
if (add_y < 0 || fabs(add_y - ADD_Y) > DIFF_THRESHOLD) {
return false;
}
float mul1_y = GetParameterValue(equiv, mul1_y_);
if (mul1_y < 0 || fabs(mul1_y - MUL1_y) > DIFF_THRESHOLD) {
return false;
}
return true;
}
} // namespace opt
} // namespace mindspore
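A quick sanity check on the constants above against the erf form of GELU: DIV_Y and MUL1_y are, within DIFF_THRESHOLD, the expected \sqrt{2} and 1/2,

\frac{x}{1.41421} \approx \frac{x}{\sqrt{2}}, \qquad 0.5\,x\,\bigl(1 + \operatorname{erf}(x/\sqrt{2})\bigr) = \mathrm{gelu}(x),

so CheckPattern() accepts exactly the graphs that compute the formula in the comment above DefinePattern().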

@@ -0,0 +1,49 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_GELU_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_GELU_FUSION_H_
#include <vector>
#include <memory>
#include <string>
#include "tools/optimizer/fusion/gelu_fusion.h"
namespace mindspore {
namespace opt {
class OnnxGeLUFusion : public GeLUFusion {
public:
explicit OnnxGeLUFusion(const std::string &name = "onnx_gelu_fusion", bool multigraph = true)
: GeLUFusion(name, multigraph) {
div_y_ = std::make_shared<Var>();
add_y_ = std::make_shared<Var>();
mul1_y_ = std::make_shared<Var>();
}
~OnnxGeLUFusion() override = default;
private:
bool CheckPattern(const EquivPtr &equiv) const override;
const BaseRef DefinePattern() const override;
private:
VarPtr div_y_ = nullptr;
VarPtr add_y_ = nullptr;
VarPtr mul1_y_ = nullptr;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_GELU_FUSION_H_

@@ -133,7 +133,7 @@ AnfNodePtr TfBidirectionGruFusion::GetCondGraphPattern(const PrimitiveVarMapPtr
auto is_less1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
auto is_less2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
auto is_logical_and = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLogicalAnd));
- auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimReturn));
+ auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
VectorRef less1_ref = VectorRef({is_less1, is_parameter1, is_parameter2});
VectorRef less2_ref = VectorRef({is_less2, is_parameter3, is_parameter4});
VectorRef logicaland_ref = VectorRef({is_logical_and, less1_ref, less2_ref});
@@ -183,13 +183,13 @@ AnfNodePtr TfBidirectionGruFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr
VectorRef select_hidden = VectorRef({std::make_shared<Var>("Switch"), greater_equal, placeholders[4], new_hidden});
- auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimMakeTuple));
+ auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMakeTuple));
std::vector<BaseRef> outputs = {is_make_tuple, add1, placeholders[1], add,
output, select_hidden, placeholders[5], placeholders[6],
placeholders[7]};
outputs.insert(outputs.end(), placeholders.begin() + 8, placeholders.end());
VectorRef make_tuple_node = VectorRef(outputs);
- auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimReturn));
+ auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
VectorRef return_node = VectorRef({is_return, make_tuple_node});
VarPtr fg = std::make_shared<Var>("RootG");

@@ -0,0 +1,88 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/tf_gelu_fusion.h"
#include "ops/op_utils.h"
namespace mindspore {
namespace opt {
namespace {
constexpr float DIFF_THRESHOLD = 0.0001;
constexpr float POW_Y = 3;
constexpr float MUL1_Y = 0.044715;
constexpr float MUL2_X = 0.79788;
constexpr float ADD2_X = 1.0;
constexpr float MUL3_X = 0.5;
bool CheckTanh(const EquivPtr &equiv, const VarPtr &input) {
MS_ASSERT(equiv != nullptr);
MS_ASSERT(input != nullptr);
auto anf_node = utils::cast<AnfNodePtr>((*equiv)[input]);
MS_ASSERT(anf_node != nullptr);
AnfNodePtr value_node = anf_node;
if (anf_node->isa<CNode>()) {
value_node = anf_node->cast<CNodePtr>()->input(0);
}
auto act_prim = GetValueNode<PrimitivePtr>(value_node);
if (act_prim == nullptr) {
return false;
}
return act_prim->GetAttr(ops::kActivationType) != nullptr &&
GetValue<int64_t>(act_prim->GetAttr(ops::kActivationType)) == mindspore::TANH;
}
} // namespace
// gelu(x) = 1/2 * x * [1 + tanh(0.79788 * (x + 0.044715 * x ^ 3))]
const BaseRef TfGeLUFusion::DefinePattern() const {
VectorRef pow_ref({power_, input_, power_y_});
VectorRef mul1_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), mul1_x_, pow_ref});
VectorRef add1_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), input_, mul1_ref});
VectorRef mul2_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), mul2_x_, add1_ref});
VectorRef tanh_ref({tanh_, mul2_ref});
VectorRef add2_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), add2_x_, tanh_ref});
VectorRef mul3_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), mul3_x_, add2_ref});
VectorRef mul4_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), input_, mul3_ref});
return mul4_ref;
}
bool TfGeLUFusion::CheckPattern(const EquivPtr &equiv) const {
MS_ASSERT(equiv != nullptr);
if (!CheckTanh(equiv, tanh_)) {
return false;
}
float pow_y = GetParameterValue(equiv, power_y_);
if (pow_y < 0 || fabs(pow_y - POW_Y) > DIFF_THRESHOLD) {
return false;
}
float mul1_y = GetParameterValue(equiv, mul1_x_);
if (mul1_y < 0 || fabs(mul1_y - MUL1_Y) > DIFF_THRESHOLD) {
return false;
}
float mul2_x = GetParameterValue(equiv, mul2_x_);
if (mul2_x < 0 || fabs(mul2_x - MUL2_X) > DIFF_THRESHOLD) {
return false;
}
float add2_x = GetParameterValue(equiv, add2_x_);
if (add2_x < 0 || fabs(add2_x - ADD2_X) > DIFF_THRESHOLD) {
return false;
}
float mul3_x = GetParameterValue(equiv, mul3_x_);
if (mul3_x < 0 || fabs(mul3_x - MUL3_X) > DIFF_THRESHOLD) {
return false;
}
return true;
}
} // namespace opt
} // namespace mindspore
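The same sanity check for the tanh approximation matched above: within DIFF_THRESHOLD the constants pin down

0.79788 \approx \sqrt{2/\pi}, \qquad \frac{x}{2}\left[1 + \tanh\!\left(\sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right] \approx \mathrm{gelu}(x),

with POW_Y = 3 supplying the cubic term and MUL3_X = 0.5 the leading 1/2.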

@@ -0,0 +1,57 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_GELU_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_GELU_FUSION_H_
#include <vector>
#include <memory>
#include <string>
#include "tools/optimizer/fusion/gelu_fusion.h"
namespace mindspore {
namespace opt {
class TfGeLUFusion : public GeLUFusion {
public:
explicit TfGeLUFusion(const std::string &name = "tf_gelu_fusion", bool multigraph = true)
: GeLUFusion(name, multigraph) {
power_ = std::make_shared<Var>();
power_y_ = std::make_shared<Var>();
mul1_x_ = std::make_shared<Var>();
mul2_x_ = std::make_shared<Var>();
tanh_ = std::make_shared<Var>();
add2_x_ = std::make_shared<Var>();
mul3_x_ = std::make_shared<Var>();
}
~TfGeLUFusion() override = default;
private:
bool CheckPattern(const EquivPtr &equiv) const override;
const BaseRef DefinePattern() const override;
private:
VarPtr power_ = nullptr;
VarPtr power_y_ = nullptr;
VarPtr mul1_x_ = nullptr;
VarPtr mul2_x_ = nullptr;
VarPtr tanh_ = nullptr;
VarPtr add2_x_ = nullptr;
VarPtr mul3_x_ = nullptr;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_GELU_FUSION_H_

@@ -98,11 +98,11 @@ AnfNodePtr TfLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primi
VectorRef set_item = VectorRef({std::make_shared<Var>(""), placeholders[3], placeholders[2], new_hidden});
- auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimMakeTuple));
+ auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMakeTuple));
std::vector<BaseRef> outputs = {is_make_tuple, add3, placeholders[1], add2, set_item, output_cell, output_hidden};
outputs.insert(outputs.end(), placeholders.begin() + 6, placeholders.end());
VectorRef make_tuple_node = VectorRef(outputs);
- auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimReturn));
+ auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
VectorRef return_node = VectorRef({is_return, make_tuple_node});
VarPtr fg = std::make_shared<Var>("RootG");

@@ -116,7 +116,7 @@ AnfNodePtr TfliteLstmCellFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &p
auto is_less1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
auto is_less2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
auto is_logical_and = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLogicalAnd));
- auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimReturn));
+ auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
VectorRef less1_ref = VectorRef({is_less1, is_parameter1, is_parameter2});
VectorRef less2_ref = VectorRef({is_less2, is_parameter3, is_parameter4});
VectorRef logicaland_ref = VectorRef({is_logical_and, less1_ref, less2_ref});
@@ -174,11 +174,11 @@ AnfNodePtr TfliteLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &p
VectorRef set_item = VectorRef({std::make_shared<Var>("SetItem"), placeholders[3], placeholders[2], output});
- auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimMakeTuple));
+ auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMakeTuple));
std::vector<BaseRef> outputs = {is_make_tuple, add3, placeholders[1], add2, set_item, cell_output, hidden_output};
outputs.insert(outputs.end(), placeholders.begin() + 6, placeholders.end());
VectorRef make_tuple_node = VectorRef(outputs);
- auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, kPrimReturn));
+ auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
VectorRef return_node = VectorRef({is_return, make_tuple_node});
VarPtr fg = std::make_shared<Var>("RootG");

@@ -41,8 +41,8 @@ ParamValueLitePtr NewParamValueLitePtr(lite::Tensor *tensor) {
bool IsSpecialType(const CNodePtr &cnode) {
if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || CheckPrimitiveType(cnode, prim::kPrimDepend) ||
- CheckPrimitiveType(cnode, prim::kPrimControlDepend) || CheckPrimitiveType(cnode, kPrimMakeTuple) ||
- CheckPrimitiveType(cnode, kPrimReturn) || CheckPrimitiveType(cnode, std::make_shared<Primitive>("While")) ||
+ CheckPrimitiveType(cnode, prim::kPrimControlDepend) || CheckPrimitiveType(cnode, prim::kPrimMakeTuple) ||
+ CheckPrimitiveType(cnode, prim::kPrimReturn) || CheckPrimitiveType(cnode, std::make_shared<Primitive>("While")) ||
CheckPrimitiveType(cnode, std::make_shared<Primitive>("If"))) {
return true;
}

@@ -81,7 +81,7 @@ bool WhilePass::Run(const FuncGraphPtr &graph) {
// concat body to cond
std::vector<AnfNodePtr> body_to_cond_inputs{cond_vnode};
- if (CheckPrimitiveType(body_output_cnode, kPrimMakeTuple)) {
+ if (CheckPrimitiveType(body_output_cnode, prim::kPrimMakeTuple)) {
for (size_t i = 1; i < body_output_cnode->inputs().size(); ++i) {
body_to_cond_inputs.emplace_back(body_output_cnode->input(i));
}
