!13703 [lite] MatMul and Add fusion

From: @xu_anyue
pull/13703/MERGE
committed by mindspore-ci-bot via Gitee
commit 69af6643d3

@ -45,7 +45,7 @@ int GatherFp16CPUKernel::Init() {
reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum());
}
(reinterpret_cast<GatherParameter *>(op_parameter_))->axis_ = *(reinterpret_cast<int *>(in_tensors_.at(2)->data_c()));
if (!InferShapeDone()) {
return RET_OK;
}
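Note on the Gather fp16 hunk above: the axis is now read from the third input tensor's data at Init time instead of from a compile-time attribute. For reference, Gather with axis = 0 selects rows of the params tensor by index; a plain C++ sketch of that semantics (illustrative only, not the kernel code):

#include <cstdio>
#include <vector>

int main() {
  const std::vector<float> params = {10, 20, 30, 40};  // shape {4}
  const std::vector<int> indices = {3, 0};             // shape {2}
  // In the kernel the axis comes from in_tensors_.at(2)->data_c(); fixed here.
  const int axis = 0;
  std::vector<float> out;
  for (int idx : indices) out.push_back(params[idx]);
  printf("axis=%d -> %g %g\n", axis, out[0], out[1]);  // 40 10
  return 0;
}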

@ -15,7 +15,9 @@
*/
#include "src/runtime/kernel/npu/matmul_npu.h"
#include <memory>
#include "src/kernel_registry.h"
#include "src/runtime/agent/npu/npu_converter_utils.h"
using mindspore::kernel::KERNEL_ARCH::kNPU;
using mindspore::lite::KernelRegistrar;
@ -24,6 +26,11 @@ using mindspore::schema::PrimitiveType_MatMul;
namespace mindspore::kernel {
int MatMulNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter) {
if (inputs.size() == 3) {
if (inputs[2]->shape().size() != 1) {
return RET_ERROR;
}
}
return RET_OK;
}
@ -33,7 +40,33 @@ int MatMulNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, con
op_->set_input_x1(*npu_inputs[0]);
op_->set_input_x2(*npu_inputs[1]);
if (npu_inputs.size() == 3) {
op_->set_input_bias(*npu_inputs[2]);
matmul_parameter_->has_bias_ = true;
add_op_ = new (std::nothrow) hiai::op::Add(name_ + "_add");
if (add_op_ == nullptr) {
MS_LOG(ERROR) << "new add op failed.";
return RET_ERROR;
}
add_op_->set_input_x1(*op_);
auto bias_shape = inputs[2]->shape();
auto bias_tensor = std::make_shared<ge::Tensor>();
if (bias_tensor == nullptr) {
MS_LOG(ERROR) << "new bias_tensor failed.";
return RET_ERROR;
}
ge::TensorDesc bias_tensor_desc(lite::ConverterToNPUShape({1, bias_shape[0], 1, 1}), ge::FORMAT_NCHW,
lite::ConverterToNPUDataType(inputs[2]->data_type()));
if (outputs[0]->shape().size() == 2) {
bias_tensor_desc.SetShape(lite::ConverterToNPUShape({1, bias_shape[0]}));
}
bias_tensor->SetTensorDesc(bias_tensor_desc);
bias_tensor->SetData(reinterpret_cast<const uint8_t *>(inputs[2]->data_c()), inputs[2]->Size());
bias_ = new (std::nothrow) hiai::op::Const(name_ + "_bias");
if (bias_ == nullptr) {
MS_LOG(ERROR) << "new bias const failed.";
return RET_ERROR;
}
bias_->set_attr_value(bias_tensor);
add_op_->set_input_x2(*bias_);
}
op_->set_attr_transpose_x1(matmul_parameter_->a_transpose_);
@ -41,13 +74,26 @@ int MatMulNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, con
return RET_OK;
}
ge::Operator *mindspore::kernel::MatMulNPUKernel::GetNPUOp() { return this->op_; }
ge::Operator *mindspore::kernel::MatMulNPUKernel::GetNPUOp() {
if (matmul_parameter_->has_bias_) {
return add_op_;
}
return op_;
}
MatMulNPUKernel::~MatMulNPUKernel() {
if (op_ != nullptr) {
delete op_;
op_ = nullptr;
}
if (add_op_ != nullptr) {
delete add_op_;
add_op_ = nullptr;
}
if (bias_ != nullptr) {
delete bias_;
bias_ = nullptr;
}
}
REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_MatMul, NPUKernelCreator<MatMulNPUKernel>)
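A note on the bias wiring above: the 1-D bias of length C is packaged as an NCHW constant of shape {1, C, 1, 1} (or {1, C} when the output is 2-D) so that hiai::op::Add broadcasts it along the channel axis. A minimal plain-C++ illustration of that broadcast, assuming nothing about the HiAI runtime:

#include <cstdio>

int main() {
  const int N = 1, C = 2, H = 2, W = 2;
  float x[1][2][2][2] = {{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}};
  const float bias[2] = {10, 20};  // 1-D bias, logically shaped {1, C, 1, 1}
  for (int n = 0; n < N; ++n)
    for (int c = 0; c < C; ++c)
      for (int h = 0; h < H; ++h)
        for (int w = 0; w < W; ++w)
          x[n][c][h][w] += bias[c];  // same bias[c] for every (n, h, w)
  printf("%g %g\n", x[0][0][0][0], x[0][1][0][0]);  // 11 25
  return 0;
}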

@ -39,6 +39,8 @@ class MatMulNPUKernel : public NPUKernel {
private:
hiai::op::MatMul *op_ = nullptr;
hiai::op::Add *add_op_ = nullptr;
hiai::op::Const *bias_ = nullptr;
MatMulParameter *matmul_parameter_;
};
} // namespace mindspore::kernel

@ -11,12 +11,12 @@ STRING(REPLACE " -fvisibility=hidden " " -fvisibility=default " CMAKE_C_FLAGS "$
STRING(REPLACE " -fvisibility=hidden " " -fvisibility=default " CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
if(ENABLE_CONVERTER)
set(CCSRC_SRC
## ccsrc
${CCSRC_DIR}/backend/optimizer/common/pattern_engine.cc
${CCSRC_DIR}/backend/optimizer/common/visit.cc
${CCSRC_DIR}/backend/optimizer/common/optimizer.cc
)
else()
set(TEST_LITE_SRC ${LITE_DIR}/src/common/log_adapter.cc)
add_compile_definitions(USE_ANDROID_LOG)
@ -38,10 +38,10 @@ file(GLOB KERNEL_OP_SRC
file(GLOB KERNEL_OP_TRAIN_SRC
${LITE_DIR}/nnacl/fp32_grad/*.c
${LITE_DIR}/src/runtime/kernel/arm/fp32_grad/*.cc
)
if(SUPPORT_TRAIN)
list(APPEND KERNEL_OP_SRC ${KERNEL_OP_TRAIN_SRC})
endif()
if(PLATFORM_ARM64)
# assembly
@ -114,9 +114,9 @@ if(SUPPORT_GPU STREQUAL vulkan)
endif()
if(PLATFORM_ARM32 OR PLATFORM_ARM64)
if(ENABLE_CONVERTER)
set(BUILD_MINDDATA "off")
endif()
endif()
### runtime framework
add_definitions(-DENABLE_V0)
@ -189,19 +189,19 @@ if(ENABLE_MINDRT)
include_directories(${CORE_DIR}/mindrt/)
include_directories(${CORE_DIR}/mindrt/src/)
set(TEST_LITE_SRC ${TEST_LITE_SRC}
${LITE_DIR}/src/lite_mindrt.cc
${LITE_DIR}/src/mindrt_executor.cc
${CORE_DIR}/mindrt/src/litebus.cc
${CORE_DIR}/mindrt/src/actor/actor.cc
${CORE_DIR}/mindrt/src/actor/actormgr.cc
${CORE_DIR}/mindrt/src/actor/actorpolicy.cc
${CORE_DIR}/mindrt/src/actor/actorthread.cc
${CORE_DIR}/mindrt/src/actor/aid.cc
${CORE_DIR}/mindrt/src/async/async.cc
${CORE_DIR}/mindrt/src/async/future.cc
${CORE_DIR}/mindrt/src/async/uuid_base.cc
${CORE_DIR}/mindrt/src/async/uuid_generator.cc
)
endif()
@ -242,6 +242,7 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/fusion/tf_lstm_cell_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/matmul_add_fusion.cc
${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc
@ -286,16 +287,16 @@ else()
endif()
### test src
file(GLOB_RECURSE TEST_CASE_KERNEL_SRC
${TEST_DIR}/ut/src/runtime/kernel/arm/common/*.cc
${TEST_DIR}/ut/src/runtime/kernel/arm/fp32/*.cc
${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc
${TEST_DIR}/ut/src/runtime/kernel/arm/string/*.cc
${TEST_DIR}/ut/nnacl/infer/*.cc
)
file(GLOB_RECURSE TEST_CASE_KERNEL_TRAIN_SRC
${TEST_DIR}/ut/src/runtime/kernel/arm/fp32_grad/*.cc
)
set(TEST_SRC
${TEST_LITE_SRC}
@ -306,7 +307,7 @@ set(TEST_SRC
${TEST_DIR}/ut/src/infer_test.cc
${TEST_DIR}/ut/src/utils_test.cc
${TEST_DIR}/ut/src/scheduler_test.cc
)
if(ENABLE_CONVERTER)
set(TEST_SRC
@ -358,7 +359,7 @@ endif()
if(ENABLE_FP16 AND SUPPORT_TRAIN)
file(GLOB_RECURSE TEST_CASE_KERNEL_FP16_SRC_GRAD
${TEST_DIR}/ut/src/runtime/kernel/arm/fp16_grad/*.cc)
list(APPEND TEST_SRC ${TEST_CASE_KERNEL_FP16_SRC_GRAD})
endif()

@ -52,6 +52,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/fusion/tf_lstm_cell_fusion.cc
../optimizer/fusion/tf_bidirection_gru_fusion.cc
../optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
../optimizer/fusion/matmul_add_fusion.cc
../optimizer/graph/weight_format_transform_pass.cc
../optimizer/graph/weight_format_hardcode_pass.cc
../optimizer/graph/clip_convert_activation_pass.cc

@ -35,6 +35,7 @@
#include "tools/optimizer/fusion/tf_lstm_cell_fusion.h"
#include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
#include "tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h"
#include "tools/optimizer/fusion/matmul_add_fusion.h"
#include "tools/optimizer/graph/primitive_adjust_pass.h"
#include "tools/optimizer/graph/mindir_adjust_pass.h"
#include "tools/optimizer/graph/redundant_op_remove_pass.h"
@ -107,6 +108,9 @@ int AnfTransform::AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &opti
fusion_pm->AddPass(remove_unused_transpose_pass);
}
fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>());
if (!config->trainModel) {
fusion_pm->AddPass(std::make_shared<opt::MatMulAddFusion>());
}
optimizer->AddPassManager(fusion_pm);
return RET_OK;
}

@ -1,48 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/onnx/onnx_gemm_parser.h"
#include <vector>
#include <memory>
#include "ops/make_tuple.h"
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxGemmParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto prim = std::make_unique<ops::MakeTuple>();
auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser("MatMul");
if (node_parser == nullptr) {
MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed.";
return nullptr;
}
auto *matmul_primitive = node_parser->Parse(onnx_graph, onnx_node);
prim->AddAttr("MatMul", std::shared_ptr<ops::PrimitiveC>(matmul_primitive));
node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser("BiasAdd");
if (node_parser == nullptr) {
MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed.";
return nullptr;
}
auto *bias_add_primitive = node_parser->Parse(onnx_graph, onnx_node);
prim->AddAttr("BiasAdd", std::shared_ptr<ops::PrimitiveC>(bias_add_primitive));
return prim.release();
}
OnnxNodeRegistrar g_onnxGemmParser("Gemm", new OnnxGemmParser());
} // namespace lite
} // namespace mindspore

@ -1,34 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore {
namespace lite {
class OnnxGemmParser : public OnnxNodeParser {
public:
OnnxGemmParser() : OnnxNodeParser("Gemm") {}
~OnnxGemmParser() override = default;
ops::PrimitiveC *Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H

@ -46,5 +46,6 @@ ops::PrimitiveC *OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, con
}
OnnxNodeRegistrar g_onnxMatmulParser("MatMul", new OnnxMatmulParser());
OnnxNodeRegistrar g_onnxGemmParser("Gemm", new OnnxMatmulParser());
} // namespace lite
} // namespace mindspore
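Context for the two parser hunks above: ONNX Gemm computes Y = alpha * op(A) * op(B) + beta * C, with alpha = beta = 1 and no transposes by default, i.e. exactly the MatMul-plus-bias pattern. Registering the MatMul parser for Gemm and letting MatMulAddFusion fold the bias therefore replaces the deleted special-case path. A reference sketch of the default-attribute semantics (plain C++, illustrative only):

#include <cstdio>
#include <vector>

// Gemm with default attributes: Y[m][n] = sum_k A[m][k] * B[k][n] + C[n],
// identical to MatMul followed by a row-broadcast bias add.
std::vector<float> Gemm(const std::vector<float> &A, const std::vector<float> &B,
                        const std::vector<float> &C, int M, int K, int N) {
  std::vector<float> Y(M * N);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = C[n];  // bias seeds the accumulator
      for (int k = 0; k < K; ++k) acc += A[m * K + k] * B[k * N + n];
      Y[m * N + n] = acc;
    }
  }
  return Y;
}

int main() {
  const auto Y = Gemm({1, 2, 3, 4}, {1, 0, 0, 1}, {5, 5}, 2, 2, 2);
  printf("%g %g %g %g\n", Y[0], Y[1], Y[2], Y[3]);  // 6 7 8 9
  return 0;
}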

@ -44,7 +44,6 @@ static const std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
{onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
{onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
std::set<std::string> SPECIAL_NODE = {"Gemm"};
FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) {
NoSupportOp::GetInstance()->SetFmkType("ONNX");
@ -215,11 +214,6 @@ STATUS OnnxModelParser::ConvertNodes(const onnx::GraphProto &onnx_graph, const F
MS_LOG(ERROR) << "convert " << onnx_node.op_type() << " quant param failed.";
continue;
}
if (IsSpecialOnnxNode(onnx_node)) {
auto status_node = ConvertSpecialOnnxNode(onnx_node, anf_graph, anf_nodes_map, primitive_c);
status = status == RET_OK ? status_node : status;
continue;
}
// build CNode
status = BuildCNode(onnx_node, anf_graph, anf_nodes_map, graph_inputs, primitive_c, root_node_name);
if (status != RET_OK) {
@ -1023,117 +1017,6 @@ STATUS OnnxModelParser::BuildCondGraph(const FuncGraphPtr &cond_graph, const Anf
return status;
}
STATUS OnnxModelParser::ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
ops::PrimitiveC *primitive_c) {
if (primitive_c == nullptr || anf_graph == nullptr) {
MS_LOG(ERROR) << "imitive_c is nullptr.";
return RET_NULL_PTR;
}
STATUS status = RET_OK;
if (onnx_node.op_type() == "Gemm") {
status = ConvertOnnxGemmNode(onnx_node, anf_graph, anf_nodes_map, primitive_c);
} else {
MS_LOG(ERROR) << "the node is not special node.";
status = RET_ERROR;
}
delete primitive_c;
return status;
}
STATUS OnnxModelParser::ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
ops::PrimitiveC *primitive_c) {
if (primitive_c == nullptr || anf_graph == nullptr) {
MS_LOG(ERROR) << "parameter has nullptr.";
return RET_NULL_PTR;
}
if (onnx_node.op_type() != "Gemm") {
MS_LOG(ERROR) << "this op is not gemm, it is " << onnx_node.op_type();
return RET_ERROR;
}
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr.";
return RET_NULL_PTR;
}
auto status = BuildCNodeForGemm(onnx_node, anf_graph, anf_nodes_map, primitive_c, "MatMul");
if (status != RET_OK) {
MS_LOG(ERROR) << "convert gemm node failed.";
return status;
}
status = BuildCNodeForGemm(onnx_node, anf_graph, anf_nodes_map, primitive_c, "BiasAdd");
if (status != RET_OK) {
MS_LOG(ERROR) << "convert gemm node failed.";
return status;
}
return RET_OK;
}
STATUS OnnxModelParser::BuildCNodeForGemm(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
ops::PrimitiveC *primitive_c, const std::string &name) {
if (primitive_c == nullptr || anf_graph == nullptr) {
MS_LOG(ERROR) << "parameter has nullptr.";
return RET_NULL_PTR;
}
auto value = primitive_c->GetAttr(name);
primitive_c->EraseAttr(name);
if (value == nullptr) {
MS_LOG(ERROR) << "op parse failed.";
return RET_NULL_PTR;
}
auto prim_ptr = value->cast<std::shared_ptr<ops::PrimitiveC>>();
if (prim_ptr == nullptr) {
MS_LOG(ERROR) << "primitive parse failed.";
return RET_NULL_PTR;
}
auto type_ptr = TypeIdToType(kTypeUnknown);
std::vector<int64_t> shape_vector;
std::vector<AnfNodePtr> op_inputs;
auto quant_params_holder = std::make_shared<QuantParamHolder>();
auto quant_params_holder_origin = primitive_c->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
if (name == "MatMul") {
for (int i = 0; i < 2; ++i) {
if (anf_nodes_map->find(onnx_node.input(i)) == anf_nodes_map->end()) {
MS_LOG(ERROR) << "op " << onnx_node.op_type() << " inputs get failed.";
return RET_ERROR;
} else {
op_inputs.push_back(anf_nodes_map->at(onnx_node.input(i)));
quant_params_holder->AddInputQuantParam(quant_params_holder_origin->input_quant_params().at(i));
}
}
quant_params_holder->AddOutputQuantParam(std::vector<schema::QuantParamT>(1));
auto new_cnode = anf_graph->NewCNode(prim_ptr, op_inputs);
if (new_cnode == nullptr) {
MS_LOG(ERROR) << "new cnode error";
return RET_ERROR;
}
new_cnode->set_fullname_with_scope("Gemm_MatMul_" + onnx_node.output(0));
new_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
anf_nodes_map->emplace("Gemm_MatMul_" + onnx_node.output(0), new_cnode);
} else {
if (anf_nodes_map->find("Gemm_MatMul_" + onnx_node.output(0)) == anf_nodes_map->end() ||
anf_nodes_map->find(onnx_node.input(2)) == anf_nodes_map->end()) {
MS_LOG(ERROR) << "op " << onnx_node.op_type() << " inputs get failed.";
return RET_ERROR;
}
op_inputs.push_back(anf_nodes_map->at("Gemm_MatMul_" + onnx_node.output(0)));
op_inputs.push_back(anf_nodes_map->at(onnx_node.input(2)));
quant_params_holder->AddInputQuantParam(std::vector<schema::QuantParamT>(1));
quant_params_holder->AddInputQuantParam(quant_params_holder_origin->input_quant_params().at(2));
quant_params_holder->AddOutputQuantParam(quant_params_holder_origin->output_quant_params().front());
auto new_cnode = anf_graph->NewCNode(prim_ptr, op_inputs);
if (new_cnode == nullptr) {
MS_LOG(ERROR) << "new cnode error";
return RET_ERROR;
}
new_cnode->set_fullname_with_scope("Gemm_BiasAdd_" + onnx_node.output(0));
new_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
anf_nodes_map->emplace(onnx_node.output(0), new_cnode);
}
return RET_OK;
}
STATUS OnnxModelParser::BuildParameterNodeForQuantParam(const void *data, const std::string &name, TypeId type) {
if (data == nullptr) {
MS_LOG(ERROR) << "value is nullptr.";
@ -1281,10 +1164,6 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_t
return RET_OK;
}
bool OnnxModelParser::IsSpecialOnnxNode(const onnx::NodeProto &onnx_node) {
return SPECIAL_NODE.find(onnx_node.op_type()) != SPECIAL_NODE.end();
}
TypeId OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type) {
auto iter = TYPE_MAP.find(onnx_type);
if (iter == TYPE_MAP.end()) {

@ -69,21 +69,11 @@ class OnnxModelParser : public ModelParser {
ops::PrimitiveC *primitive_c, std::string loop_name);
static STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, const CNodePtr &cnode);
static STATUS ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
ops::PrimitiveC *primitive_c);
static STATUS ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
ops::PrimitiveC *primitive_c);
static STATUS BuildCNodeForGemm(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
ops::PrimitiveC *primitive_c, const std::string &name);
STATUS ConvertOpQuantParams(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c);
STATUS ParseQuantParam(const onnx::NodeProto &onnx_node);
STATUS SetTensorQuantParam(const std::string &tensor_name, std::vector<QuantParamT> *quant_params);
STATUS SetTensorQuantParamFromNode(const std::string &tensor_name, std::vector<QuantParamT> *quant_params);
STATUS CopyTensorQuantParam(const std::string &tensor_name, QuantParamT *quant_param, bool scale_or_not);
static bool IsSpecialOnnxNode(const onnx::NodeProto &onnx_node);
STATUS ConvertLoopOnnxNode(const onnx::NodeProto &onnx_node,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
const std::string &root_node_name);

@ -0,0 +1,79 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/matmul_add_fusion.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore {
namespace opt {
namespace {
constexpr size_t AddInputSize = 3;
constexpr size_t MatMulInputSize = 3;
bool CheckAndGetMatMulIndex(const CNodePtr &cnode, size_t *index) {
MS_ASSERT(cnode != nullptr);
MS_ASSERT(index != nullptr);
if (cnode->size() != AddInputSize) {
return false;
}
size_t matmul_index = 0;
for (size_t i = 1; i < cnode->size(); ++i) {
if (CheckPrimitiveType(cnode->input(i), prim::kPrimMatMul)) {
auto matmul_cnode = cnode->input(i)->cast<CNodePtr>();
if (matmul_cnode->size() > MatMulInputSize) {
continue;
}
matmul_index = i;
break;
}
}
if (matmul_index == 0) {
return false;
}
*index = matmul_index;
return true;
}
} // namespace
bool MatMulAddFusion::Run(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNode>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (!CheckPrimitiveType(node, prim::kPrimAddFusion) && !CheckPrimitiveType(node, prim::kPrimBiasAdd)) {
continue;
}
size_t index = 0;
if (!CheckAndGetMatMulIndex(cnode, &index)) {
continue;
}
auto matmul_cnode = cnode->input(index)->cast<CNodePtr>();
auto bias_node = cnode->input(AddInputSize - index);
if (!utils::isa<Parameter>(bias_node) || !bias_node->cast<ParameterPtr>()->default_param()) {
continue;
}
matmul_cnode->add_input(bias_node);
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
matmul_cnode->set_fullname_with_scope(node->fullname_with_scope());
manager->Replace(node, matmul_cnode);
}
return false;
}
} // namespace opt
} // namespace mindspore
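To make the matching rule concrete: the pass fires only on an Add/BiasAdd with exactly two operands, one produced by a MatMul that does not yet carry a bias and the other a Parameter holding weight data; the bias is then appended as the MatMul's third input and the Add node is replaced. A toy re-implementation of that decision over a simplified node model (hypothetical types; 0-based operand indices rather than the CNode convention of counting the primitive as input 0):

#include <cstddef>
#include <cstdio>
#include <string>
#include <vector>

struct Node {
  std::string op;
  std::vector<Node *> inputs;
  bool is_const = false;  // stands in for Parameter-with-default_param
};

// Mirrors CheckAndGetMatMulIndex plus the bias check from Run, in toy form.
bool FusableAdd(const Node &add, size_t *matmul_index) {
  if (add.op != "Add" || add.inputs.size() != 2) return false;
  for (size_t i = 0; i < add.inputs.size(); ++i) {
    const Node *in = add.inputs[i];
    if (in->op == "MatMul" && in->inputs.size() == 2) {  // no bias yet
      if (add.inputs[1 - i]->is_const) {                 // other side is weight data
        *matmul_index = i;
        return true;
      }
    }
  }
  return false;
}

int main() {
  Node x{"Input"}, w{"Const"}, b{"Const"};
  w.is_const = b.is_const = true;
  Node matmul{"MatMul", {&x, &w}};
  Node add{"Add", {&matmul, &b}};
  size_t idx = 0;
  printf("fusable: %d (MatMul at operand %zu)\n", FusableAdd(add, &idx) ? 1 : 0, idx);
  return 0;
}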

@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_MATMUL_ADD_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_MATMUL_ADD_FUSION_H_
#include "backend/optimizer/common/optimizer.h"
#include "tools/converter/converter_context.h"
#include "backend/optimizer/common/pass.h"
namespace mindspore {
namespace opt {
class MatMulAddFusion : public Pass {
public:
MatMulAddFusion() : Pass("matmul_add_fusion") {}
~MatMulAddFusion() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_MATMUL_ADD_FUSION_H_