diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc
index 6d7e0bab37..8e40a1bccb 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc
@@ -45,7 +45,7 @@ int GatherFp16CPUKernel::Init() {
       reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
     Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum());
   }
-
+  (reinterpret_cast<GatherParameter *>(op_parameter_))->axis_ = *(reinterpret_cast<int *>(in_tensors_.at(2)->data_c()));
   if (!InferShapeDone()) {
     return RET_OK;
   }
diff --git a/mindspore/lite/src/runtime/kernel/npu/matmul_npu.cc b/mindspore/lite/src/runtime/kernel/npu/matmul_npu.cc
index 08a9de491f..63e6006c48 100644
--- a/mindspore/lite/src/runtime/kernel/npu/matmul_npu.cc
+++ b/mindspore/lite/src/runtime/kernel/npu/matmul_npu.cc
@@ -15,7 +15,9 @@
  */
 
 #include "src/runtime/kernel/npu/matmul_npu.h"
+#include <memory>
 #include "src/kernel_registry.h"
+#include "src/runtime/agent/npu/npu_converter_utils.h"
 
 using mindspore::kernel::KERNEL_ARCH::kNPU;
 using mindspore::lite::KernelRegistrar;
@@ -24,6 +26,11 @@ using mindspore::schema::PrimitiveType_MatMul;
 namespace mindspore::kernel {
 int MatMulNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
                                OpParameter *opParameter) {
+  if (inputs.size() == 3) {
+    if (inputs[2]->shape().size() != 1) {
+      return RET_ERROR;
+    }
+  }
   return RET_OK;
 }
 
@@ -33,7 +40,33 @@ int MatMulNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, con
   op_->set_input_x1(*npu_inputs[0]);
   op_->set_input_x2(*npu_inputs[1]);
   if (npu_inputs.size() == 3) {
-    op_->set_input_bias(*npu_inputs[2]);
+    matmul_parameter_->has_bias_ = true;
+    add_op_ = new (std::nothrow) hiai::op::Add(name_ + "_add");
+    if (add_op_ == nullptr) {
+      MS_LOG(ERROR) << "new add op failed.";
+      return RET_ERROR;
+    }
+    add_op_->set_input_x1(*op_);
+    auto bias_shape = inputs[2]->shape();
+    auto bias_tensor = std::make_shared<ge::Tensor>();
+    if (bias_tensor == nullptr) {
+      MS_LOG(ERROR) << "new bias_tensor failed.";
+      return RET_ERROR;
+    }
+    ge::TensorDesc bias_tensor_desc(lite::ConverterToNPUShape({1, bias_shape[0], 1, 1}), ge::FORMAT_NCHW,
+                                    lite::ConverterToNPUDataType(inputs[2]->data_type()));
+    if (outputs[0]->shape().size() == 2) {
+      bias_tensor_desc.SetShape(lite::ConverterToNPUShape({1, bias_shape[0]}));
+    }
+    bias_tensor->SetTensorDesc(bias_tensor_desc);
+    bias_tensor->SetData(reinterpret_cast<const uint8_t *>(inputs[2]->data_c()), inputs[2]->Size());
+    bias_ = new (std::nothrow) hiai::op::Const(name_ + "_bias");
+    if (bias_ == nullptr) {
+      MS_LOG(ERROR) << "new bias const failed.";
+      return RET_ERROR;
+    }
+    bias_->set_attr_value(bias_tensor);
+    add_op_->set_input_x2(*bias_);
   }
 
   op_->set_attr_transpose_x1(matmul_parameter_->a_transpose_);
@@ -41,13 +74,26 @@ int MatMulNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, con
   return RET_OK;
 }
 
-ge::Operator *mindspore::kernel::MatMulNPUKernel::GetNPUOp() { return this->op_; }
+ge::Operator *mindspore::kernel::MatMulNPUKernel::GetNPUOp() {
+  if (matmul_parameter_->has_bias_) {
+    return add_op_;
+  }
+  return op_;
+}
 
 MatMulNPUKernel::~MatMulNPUKernel() {
   if (op_ != nullptr) {
     delete op_;
     op_ = nullptr;
   }
+  if (add_op_ != nullptr) {
+    delete add_op_;
+    add_op_ = nullptr;
+  }
+  if (bias_ != nullptr) {
+    delete bias_;
+    bias_ = nullptr;
+  }
 }
 
 REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_MatMul, NPUKernelCreator<MatMulNPUKernel>)
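Note on the kernel change above: instead of feeding the third tensor into hiai::op::MatMul's bias input, a biased MatMul is now emitted as MatMul followed by a broadcast Add against a Const tensor. The Const bias is declared as {1, C, 1, 1} for NCHW graphs, or {1, C} when the output is 2-D, so it broadcasts across the batch/row dimension; IsSupport rejects non-1-D bias tensors precisely because this reshape trick only covers a per-column bias. The following self-contained C++ sketch (plain loops, no HiAI types; all names here are illustrative, not part of the patch) shows that this decomposition is numerically identical to a biased MatMul:

```cpp
// Standalone sketch: MatMul followed by a row-broadcast Add reproduces a
// biased MatMul, which is the decomposition the kernel builds with
// hiai::op::Add and a hiai::op::Const bias.
#include <cassert>
#include <vector>

// C = A(m x k) * B(k x n), then add bias(n) broadcast across the m rows.
std::vector<float> MatMulThenAdd(const std::vector<float> &a, const std::vector<float> &b,
                                 const std::vector<float> &bias, int m, int k, int n) {
  std::vector<float> c(m * n, 0.0f);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      for (int p = 0; p < k; ++p) {
        c[i * n + j] += a[i * k + p] * b[p * n + j];  // the MatMul op
      }
      c[i * n + j] += bias[j];  // the Add op; bias shaped {1, n} ({1, n, 1, 1} in NCHW)
    }
  }
  return c;
}

int main() {
  std::vector<float> a = {1, 2, 3, 4};  // 2x2
  std::vector<float> b = {5, 6, 7, 8};  // 2x2
  std::vector<float> bias = {10, 20};   // 1-D, as IsSupport requires
  auto c = MatMulThenAdd(a, b, bias, 2, 2, 2);
  // A*B = [[19, 22], [43, 50]]; adding the bias per column gives:
  assert(c[0] == 29 && c[1] == 42 && c[2] == 53 && c[3] == 70);
  return 0;
}
```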
diff --git a/mindspore/lite/src/runtime/kernel/npu/matmul_npu.h b/mindspore/lite/src/runtime/kernel/npu/matmul_npu.h
index 4b54d9e293..23dfb4ae04 100644
--- a/mindspore/lite/src/runtime/kernel/npu/matmul_npu.h
+++ b/mindspore/lite/src/runtime/kernel/npu/matmul_npu.h
@@ -39,6 +39,8 @@ class MatMulNPUKernel : public NPUKernel {
 
  private:
   hiai::op::MatMul *op_ = nullptr;
+  hiai::op::Add *add_op_ = nullptr;
+  hiai::op::Const *bias_ = nullptr;
   MatMulParameter *matmul_parameter_;
 };
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt
index ea67f57c84..9a9374f141 100644
--- a/mindspore/lite/test/CMakeLists.txt
+++ b/mindspore/lite/test/CMakeLists.txt
@@ -11,12 +11,12 @@ STRING(REPLACE " -fvisibility=hidden " " -fvisibility=default " CMAKE_C_FLAGS "$
 STRING(REPLACE " -fvisibility=hidden " " -fvisibility=default " CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
 
 if(ENABLE_CONVERTER)
-set(CCSRC_SRC
-        ## ccsrc
-        ${CCSRC_DIR}/backend/optimizer/common/pattern_engine.cc
-        ${CCSRC_DIR}/backend/optimizer/common/visit.cc
-        ${CCSRC_DIR}/backend/optimizer/common/optimizer.cc
-    )
+    set(CCSRC_SRC
+            ## ccsrc
+            ${CCSRC_DIR}/backend/optimizer/common/pattern_engine.cc
+            ${CCSRC_DIR}/backend/optimizer/common/visit.cc
+            ${CCSRC_DIR}/backend/optimizer/common/optimizer.cc
+            )
 else()
     set(TEST_LITE_SRC ${LITE_DIR}/src/common/log_adapter.cc)
     add_compile_definitions(USE_ANDROID_LOG)
@@ -38,10 +38,10 @@ file(GLOB KERNEL_OP_SRC
 file(GLOB KERNEL_OP_TRAIN_SRC
         ${LITE_DIR}/nnacl/fp32_grad/*.c
         ${LITE_DIR}/src/runtime/kernel/arm/fp32_grad/*.cc
-)
+        )
 if(SUPPORT_TRAIN)
-  list(APPEND KERNEL_OP_SRC ${KERNEL_OP_TRAIN_SRC})
+    list(APPEND KERNEL_OP_SRC ${KERNEL_OP_TRAIN_SRC})
 endif()
 
 if(PLATFORM_ARM64)
     # assembly
@@ -114,9 +114,9 @@ if(SUPPORT_GPU STREQUAL vulkan)
 endif()
 
 if(PLATFORM_ARM32 OR PLATFORM_ARM64)
-  if(ENABLE_CONVERTER)
-    set(BUILD_MINDDATA "off")
-  endif()
+    if(ENABLE_CONVERTER)
+        set(BUILD_MINDDATA "off")
+    endif()
 endif()
 ### runtime framework
 add_definitions(-DENABLE_V0)
@@ -189,19 +189,19 @@ if(ENABLE_MINDRT)
     include_directories(${CORE_DIR}/mindrt/)
     include_directories(${CORE_DIR}/mindrt/src/)
     set(TEST_LITE_SRC ${TEST_LITE_SRC}
-        ${LITE_DIR}/src/lite_mindrt.cc
-        ${LITE_DIR}/src/mindrt_executor.cc
-        ${CORE_DIR}/mindrt/src/litebus.cc
-        ${CORE_DIR}/mindrt/src/actor/actor.cc
-        ${CORE_DIR}/mindrt/src/actor/actormgr.cc
-        ${CORE_DIR}/mindrt/src/actor/actorpolicy.cc
-        ${CORE_DIR}/mindrt/src/actor/actorthread.cc
-        ${CORE_DIR}/mindrt/src/actor/aid.cc
-        ${CORE_DIR}/mindrt/src/async/async.cc
-        ${CORE_DIR}/mindrt/src/async/future.cc
-        ${CORE_DIR}/mindrt/src/async/uuid_base.cc
-        ${CORE_DIR}/mindrt/src/async/uuid_generator.cc
-    )
+            ${LITE_DIR}/src/lite_mindrt.cc
+            ${LITE_DIR}/src/mindrt_executor.cc
+            ${CORE_DIR}/mindrt/src/litebus.cc
+            ${CORE_DIR}/mindrt/src/actor/actor.cc
+            ${CORE_DIR}/mindrt/src/actor/actormgr.cc
+            ${CORE_DIR}/mindrt/src/actor/actorpolicy.cc
+            ${CORE_DIR}/mindrt/src/actor/actorthread.cc
+            ${CORE_DIR}/mindrt/src/actor/aid.cc
+            ${CORE_DIR}/mindrt/src/async/async.cc
+            ${CORE_DIR}/mindrt/src/async/future.cc
+            ${CORE_DIR}/mindrt/src/async/uuid_base.cc
+            ${CORE_DIR}/mindrt/src/async/uuid_generator.cc
+            )
 endif()
 
 
@@ -242,6 +242,7 @@ if(ENABLE_CONVERTER)
         ${LITE_DIR}/tools/optimizer/fusion/tf_lstm_cell_fusion.cc
         ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc
         ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
+        ${LITE_DIR}/tools/optimizer/fusion/matmul_add_fusion.cc
        ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
        ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
        ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc
@@ -286,16 +287,16 @@ else()
 endif()
 ### test src
 file(GLOB_RECURSE TEST_CASE_KERNEL_SRC
-    ${TEST_DIR}/ut/src/runtime/kernel/arm/common/*.cc
-    ${TEST_DIR}/ut/src/runtime/kernel/arm/fp32/*.cc
-    ${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc
-    ${TEST_DIR}/ut/src/runtime/kernel/arm/string/*.cc
-    ${TEST_DIR}/ut/nnacl/infer/*.cc
-)
+        ${TEST_DIR}/ut/src/runtime/kernel/arm/common/*.cc
+        ${TEST_DIR}/ut/src/runtime/kernel/arm/fp32/*.cc
+        ${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc
+        ${TEST_DIR}/ut/src/runtime/kernel/arm/string/*.cc
+        ${TEST_DIR}/ut/nnacl/infer/*.cc
+        )
 
 file(GLOB_RECURSE TEST_CASE_KERNEL_TRAIN_SRC
-    ${TEST_DIR}/ut/src/runtime/kernel/arm/fp32_grad/*.cc
-)
+        ${TEST_DIR}/ut/src/runtime/kernel/arm/fp32_grad/*.cc
+        )
 
 set(TEST_SRC
     ${TEST_LITE_SRC}
@@ -306,7 +307,7 @@ set(TEST_SRC
     ${TEST_DIR}/ut/src/infer_test.cc
     ${TEST_DIR}/ut/src/utils_test.cc
     ${TEST_DIR}/ut/src/scheduler_test.cc
-)
+    )
 
 if(ENABLE_CONVERTER)
     set(TEST_SRC
@@ -358,7 +359,7 @@ endif()
 
 if(ENABLE_FP16 AND SUPPORT_TRAIN)
     file(GLOB_RECURSE TEST_CASE_KERNEL_FP16_SRC_GRAD
-        ${TEST_DIR}/ut/src/runtime/kernel/arm/fp16_grad/*.cc)
+            ${TEST_DIR}/ut/src/runtime/kernel/arm/fp16_grad/*.cc)
     list(APPEND TEST_SRC ${TEST_CASE_KERNEL_FP16_SRC_GRAD})
 endif()
diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt
index 41a19824ea..70f5a8fe87 100644
--- a/mindspore/lite/tools/converter/CMakeLists.txt
+++ b/mindspore/lite/tools/converter/CMakeLists.txt
@@ -52,6 +52,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
         ../optimizer/fusion/tf_lstm_cell_fusion.cc
         ../optimizer/fusion/tf_bidirection_gru_fusion.cc
         ../optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
+        ../optimizer/fusion/matmul_add_fusion.cc
        ../optimizer/graph/weight_format_transform_pass.cc
        ../optimizer/graph/weight_format_hardcode_pass.cc
        ../optimizer/graph/clip_convert_activation_pass.cc
diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc
index 42d0545bea..28acec5da5 100644
--- a/mindspore/lite/tools/converter/anf_transform.cc
+++ b/mindspore/lite/tools/converter/anf_transform.cc
@@ -35,6 +35,7 @@
 #include "tools/optimizer/fusion/tf_lstm_cell_fusion.h"
 #include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
 #include "tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h"
+#include "tools/optimizer/fusion/matmul_add_fusion.h"
 #include "tools/optimizer/graph/primitive_adjust_pass.h"
 #include "tools/optimizer/graph/mindir_adjust_pass.h"
 #include "tools/optimizer/graph/redundant_op_remove_pass.h"
@@ -107,6 +108,9 @@ int AnfTransform::AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &opti
     fusion_pm->AddPass(remove_unused_transpose_pass);
   }
   fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>());
+  if (!config->trainModel) {
+    fusion_pm->AddPass(std::make_shared<opt::MatMulAddFusion>());
+  }
   optimizer->AddPassManager(fusion_pm);
   return RET_OK;
 }
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_gemm_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_gemm_parser.cc
deleted file mode 100644
index 82e7e2b719..0000000000
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_gemm_parser.cc
+++ /dev/null
@@ -1,48 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "tools/converter/parser/onnx/onnx_gemm_parser.h"
-#include <memory>
-#include <vector>
-#include "ops/make_tuple.h"
-
-namespace mindspore {
-namespace lite {
-ops::PrimitiveC *OnnxGemmParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
-  auto prim = std::make_unique<ops::MakeTuple>();
-
-  auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser("MatMul");
-  if (node_parser == nullptr) {
-    MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed.";
-    return nullptr;
-  }
-  auto *matmul_primitive = node_parser->Parse(onnx_graph, onnx_node);
-  prim->AddAttr("MatMul", std::shared_ptr<ops::PrimitiveC>(matmul_primitive));
-
-  node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser("BiasAdd");
-  if (node_parser == nullptr) {
-    MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed.";
-    return nullptr;
-  }
-  auto *bias_add_primitive = node_parser->Parse(onnx_graph, onnx_node);
-  prim->AddAttr("BiasAdd", std::shared_ptr<ops::PrimitiveC>(bias_add_primitive));
-
-  return prim.release();
-}
-
-OnnxNodeRegistrar g_onnxGemmParser("Gemm", new OnnxGemmParser());
-}  // namespace lite
-}  // namespace mindspore
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_gemm_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_gemm_parser.h
deleted file mode 100644
index 948deca088..0000000000
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_gemm_parser.h
+++ /dev/null
@@ -1,34 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H
-#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H
-
-#include "tools/converter/parser/onnx/onnx_node_parser.h"
-#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
-
-namespace mindspore {
-namespace lite {
-class OnnxGemmParser : public OnnxNodeParser {
- public:
-  OnnxGemmParser() : OnnxNodeParser("Gemm") {}
-  ~OnnxGemmParser() override = default;
-
-  ops::PrimitiveC *Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
-};
-}  // namespace lite
-}  // namespace mindspore
-#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc
index dbb798bf1a..02747c6e57 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc
@@ -46,5 +46,6 @@ ops::PrimitiveC *OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, con
 }
 
 OnnxNodeRegistrar g_onnxMatmulParser("MatMul", new OnnxMatmulParser());
+OnnxNodeRegistrar g_onnxGemmParser("Gemm", new OnnxMatmulParser());
 }  // namespace lite
 }  // namespace mindspore
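Context for the registrar change above: ONNX Gemm computes Y = alpha * op(A) * op(B) + beta * C, where op() applies the transA/transB attributes. With the default alpha = beta = 1 this is exactly a MatMul with transpose attributes followed by an add of C, so a dedicated Gemm parser is redundant. Registering OnnxMatmulParser for "Gemm" yields MatMul plus a trailing Add in the converted graph (handling of non-default alpha/beta is assumed to live in OnnxMatmulParser::Parse), and the new MatMulAddFusion pass below folds that Add back into the MatMul as a bias input. This also lets the hand-written Gemm special-casing in onnx_model_parser.cc be deleted wholesale, as the next diff shows.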
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
index 4986f123f9..a39985a3d8 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
@@ -44,7 +44,6 @@ static const std::unordered_map<onnx::TensorProto_DataType, TypeId> TYPE_MAP = {
   {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
   {onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
 
-std::set<std::string> SPECIAL_NODE = {"Gemm"};
 FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::string &weight_file,
                                     const QuantType &quant_type) {
   NoSupportOp::GetInstance()->SetFmkType("ONNX");
@@ -215,11 +214,6 @@ STATUS OnnxModelParser::ConvertNodes(const onnx::GraphProto &onnx_graph, const F
       MS_LOG(ERROR) << "convert " << onnx_node.op_type() << " quant param failed.";
       continue;
     }
-    if (IsSpecialOnnxNode(onnx_node)) {
-      auto status_node = ConvertSpecialOnnxNode(onnx_node, anf_graph, anf_nodes_map, primitive_c);
-      status = status == RET_OK ? status_node : status;
-      continue;
-    }
     // build CNode
     status = BuildCNode(onnx_node, anf_graph, anf_nodes_map, graph_inputs, primitive_c, root_node_name);
     if (status != RET_OK) {
@@ -1023,117 +1017,6 @@ STATUS OnnxModelParser::BuildCondGraph(const FuncGraphPtr &cond_graph, const Anf
   return status;
 }
 
-STATUS OnnxModelParser::ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph,
-                                               std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
-                                               ops::PrimitiveC *primitive_c) {
-  if (primitive_c == nullptr || anf_graph == nullptr) {
-    MS_LOG(ERROR) << "imitive_c is nullptr.";
-    return RET_NULL_PTR;
-  }
-  STATUS status = RET_OK;
-  if (onnx_node.op_type() == "Gemm") {
-    status = ConvertOnnxGemmNode(onnx_node, anf_graph, anf_nodes_map, primitive_c);
-  } else {
-    MS_LOG(ERROR) << "the node is not special node.";
-    status = RET_ERROR;
-  }
-  delete primitive_c;
-  return status;
-}
-
-STATUS OnnxModelParser::ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph,
-                                            std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
-                                            ops::PrimitiveC *primitive_c) {
-  if (primitive_c == nullptr || anf_graph == nullptr) {
-    MS_LOG(ERROR) << "parameter has nullptr.";
-    return RET_NULL_PTR;
-  }
-  if (onnx_node.op_type() != "Gemm") {
-    MS_LOG(ERROR) << "this op is not gemm, it is " << onnx_node.op_type();
-    return RET_ERROR;
-  }
-  if (primitive_c == nullptr) {
-    MS_LOG(ERROR) << "primitive_c is nullptr.";
-    return RET_NULL_PTR;
-  }
-  auto status = BuildCNodeForGemm(onnx_node, anf_graph, anf_nodes_map, primitive_c, "MatMul");
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "convert gemm node failed.";
-    return status;
-  }
-  status = BuildCNodeForGemm(onnx_node, anf_graph, anf_nodes_map, primitive_c, "BiasAdd");
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "convert gemm node failed.";
-    return status;
-  }
-  return RET_OK;
-}
-
-STATUS OnnxModelParser::BuildCNodeForGemm(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph,
-                                          std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
-                                          ops::PrimitiveC *primitive_c, const std::string &name) {
-  if (primitive_c == nullptr || anf_graph == nullptr) {
-    MS_LOG(ERROR) << "parameter has nullptr.";
-    return RET_NULL_PTR;
-  }
-  auto value = primitive_c->GetAttr(name);
-  primitive_c->EraseAttr(name);
-  if (value == nullptr) {
-    MS_LOG(ERROR) << "op parse failed.";
-    return RET_NULL_PTR;
-  }
-  auto prim_ptr = value->cast<std::shared_ptr<ops::PrimitiveC>>();
-  if (prim_ptr == nullptr) {
-    MS_LOG(ERROR) << "primitive parse failed.";
-    return RET_NULL_PTR;
-  }
-  auto type_ptr = TypeIdToType(kTypeUnknown);
-  std::vector<int64_t> shape_vector;
-  std::vector<AnfNodePtr> op_inputs;
-  auto quant_params_holder = std::make_shared<QuantParamHolder>();
-  auto quant_params_holder_origin = primitive_c->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
-  if (name == "MatMul") {
-    for (int i = 0; i < 2; ++i) {
-      if (anf_nodes_map->find(onnx_node.input(i)) == anf_nodes_map->end()) {
-        MS_LOG(ERROR) << "op " << onnx_node.op_type() << " inputs get failed.";
-        return RET_ERROR;
-      } else {
-        op_inputs.push_back(anf_nodes_map->at(onnx_node.input(i)));
-        quant_params_holder->AddInputQuantParam(quant_params_holder_origin->input_quant_params().at(i));
-      }
-    }
-    quant_params_holder->AddOutputQuantParam(std::vector<schema::QuantParamT>(1));
-    auto new_cnode = anf_graph->NewCNode(prim_ptr, op_inputs);
-    if (new_cnode == nullptr) {
-      MS_LOG(ERROR) << "new cnode error";
-      return RET_ERROR;
-    }
-    new_cnode->set_fullname_with_scope("Gemm_MatMul_" + onnx_node.output(0));
-    new_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
-    anf_nodes_map->emplace("Gemm_MatMul_" + onnx_node.output(0), new_cnode);
-  } else {
-    if (anf_nodes_map->find("Gemm_MatMul_" + onnx_node.output(0)) == anf_nodes_map->end() ||
-        anf_nodes_map->find(onnx_node.input(2)) == anf_nodes_map->end()) {
-      MS_LOG(ERROR) << "op " << onnx_node.op_type() << " inputs get failed.";
-      return RET_ERROR;
-    }
-    op_inputs.push_back(anf_nodes_map->at("Gemm_MatMul_" + onnx_node.output(0)));
-    op_inputs.push_back(anf_nodes_map->at(onnx_node.input(2)));
-    quant_params_holder->AddInputQuantParam(std::vector<schema::QuantParamT>(1));
-    quant_params_holder->AddInputQuantParam(quant_params_holder_origin->input_quant_params().at(2));
-    quant_params_holder->AddOutputQuantParam(quant_params_holder_origin->output_quant_params().front());
-    auto new_cnode = anf_graph->NewCNode(prim_ptr, op_inputs);
-    if (new_cnode == nullptr) {
-      MS_LOG(ERROR) << "new cnode error";
-      return RET_ERROR;
-    }
-    new_cnode->set_fullname_with_scope("Gemm_BiasAdd_" + onnx_node.output(0));
-    new_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
-    anf_nodes_map->emplace(onnx_node.output(0), new_cnode);
-  }
-  return RET_OK;
-}
-
 STATUS OnnxModelParser::BuildParameterNodeForQuantParam(const void *data, const std::string &name, TypeId type) {
   if (data == nullptr) {
     MS_LOG(ERROR) << "value is nullptr.";
@@ -1281,10 +1164,6 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_t
   return RET_OK;
 }
 
-bool OnnxModelParser::IsSpecialOnnxNode(const onnx::NodeProto &onnx_node) {
-  return SPECIAL_NODE.find(onnx_node.op_type()) != SPECIAL_NODE.end();
-}
-
 TypeId OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type) {
   auto iter = TYPE_MAP.find(onnx_type);
   if (iter == TYPE_MAP.end()) {
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h
index 122fac9a49..8deb07d289 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h
@@ -69,21 +69,11 @@ class OnnxModelParser : public ModelParser {
                                 ops::PrimitiveC *primitive_c, std::string loop_name);
   static STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
                                std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, const CNodePtr &cnode);
-  static STATUS ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
-                                       std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
-                                       ops::PrimitiveC *primitive_c);
-  static STATUS ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
-                                    std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
-                                    ops::PrimitiveC *primitive_c);
-  static STATUS BuildCNodeForGemm(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
-                                  std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
-                                  ops::PrimitiveC *primitive_c, const std::string &name);
   STATUS ConvertOpQuantParams(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c);
   STATUS ParseQuantParam(const onnx::NodeProto &onnx_node);
   STATUS SetTensorQuantParam(const std::string &tensor_name, std::vector<QuantParamT> *quant_params);
   STATUS SetTensorQuantParamFromNode(const std::string &tensor_name, std::vector<QuantParamT> *quant_params);
   STATUS CopyTensorQuantParam(const std::string &tensor_name, QuantParamT *quant_param, bool scale_or_not);
-  static bool IsSpecialOnnxNode(const onnx::NodeProto &onnx_node);
   STATUS ConvertLoopOnnxNode(const onnx::NodeProto &onnx_node,
                              std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
                              const std::string &root_node_name);
diff --git a/mindspore/lite/tools/optimizer/fusion/matmul_add_fusion.cc b/mindspore/lite/tools/optimizer/fusion/matmul_add_fusion.cc
new file mode 100644
index 0000000000..ec068e21a5
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/fusion/matmul_add_fusion.cc
@@ -0,0 +1,79 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tools/optimizer/fusion/matmul_add_fusion.h"
+#include "tools/optimizer/common/gllo_utils.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+constexpr size_t AddInputSize = 3;
+constexpr size_t MatMulInputSize = 3;
+bool CheckAndGetMatMulIndex(const CNodePtr &cnode, size_t *index) {
+  MS_ASSERT(cnode != nullptr);
+  MS_ASSERT(index != nullptr);
+  if (cnode->size() != AddInputSize) {
+    return false;
+  }
+  size_t matmul_index = 0;
+  for (size_t i = 1; i < cnode->size(); ++i) {
+    if (CheckPrimitiveType(cnode->input(i), prim::kPrimMatMul)) {
+      auto matmul_cnode = cnode->input(i)->cast<CNodePtr>();
+      if (matmul_cnode->size() > MatMulInputSize) {
+        continue;
+      }
+      matmul_index = i;
+      break;
+    }
+  }
+  if (matmul_index == 0) {
+    return false;
+  }
+  *index = matmul_index;
+  return true;
+}
+}  // namespace
+
+bool MatMulAddFusion::Run(const FuncGraphPtr &func_graph) {
+  MS_ASSERT(func_graph != nullptr);
+  auto node_list = TopoSort(func_graph->get_return());
+  for (auto &node : node_list) {
+    if (!utils::isa<CNodePtr>(node)) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    if (!CheckPrimitiveType(node, prim::kPrimAddFusion) && !CheckPrimitiveType(node, prim::kPrimBiasAdd)) {
+      continue;
+    }
+    size_t index = 0;
+    if (!CheckAndGetMatMulIndex(cnode, &index)) {
+      continue;
+    }
+    auto matmul_cnode = cnode->input(index)->cast<CNodePtr>();
+    auto bias_node = cnode->input(AddInputSize - index);
+    if (!utils::isa<Parameter>(bias_node) || !bias_node->cast<ParameterPtr>()->default_param()) {
+      continue;
+    }
+    matmul_cnode->add_input(bias_node);
+    auto manager = func_graph->manager();
+    MS_ASSERT(manager != nullptr);
+    matmul_cnode->set_fullname_with_scope(node->fullname_with_scope());
+    manager->Replace(node, matmul_cnode);
+  }
+  return false;
+}
+}  // namespace opt
+}  // namespace mindspore
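Taken together with the converter changes above, this pass restores the single fused node that the old Gemm path produced, but generically: whenever an AddFusion/BiasAdd consumes a bias-free MatMul and its other operand is a constant Parameter (a weight with default data), the bias is attached as the MatMul's third input and the Add node is replaced, e.g. AddFusion(MatMul(a, b), bias) becomes MatMul(a, b, bias). Since an Add cnode has exactly three inputs (the primitive at index 0 plus two operands) and CheckAndGetMatMulIndex accepts the MatMul at either operand position, the expression `AddInputSize - index` selects the opposite operand as the bias.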
diff --git a/mindspore/lite/tools/optimizer/fusion/matmul_add_fusion.h b/mindspore/lite/tools/optimizer/fusion/matmul_add_fusion.h
new file mode 100644
index 0000000000..5513353019
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/fusion/matmul_add_fusion.h
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_MATMUL_ADD_FUSION_H_
+#define MINDSPORE_LITE_SRC_PASS_FUSION_MATMUL_ADD_FUSION_H_
+
+#include "backend/optimizer/common/optimizer.h"
+#include "tools/converter/converter_context.h"
+#include "backend/optimizer/common/pass.h"
+
+namespace mindspore {
+namespace opt {
+class MatMulAddFusion : public Pass {
+ public:
+  MatMulAddFusion() : Pass("matmul_add_fusion") {}
+  ~MatMulAddFusion() override = default;
+  bool Run(const FuncGraphPtr &func_graph) override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_LITE_SRC_PASS_FUSION_MATMUL_ADD_FUSION_H_