diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc index e1a781272c..694ee4d236 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc @@ -580,27 +580,43 @@ void BufferFusion::MatchDepthwiseConvRelu(const CNodePtr &cnode, const session:: } } +void BufferFusion::MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input, + const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + std::vector output_used_num{SizeToInt(manager->node_users()[relu_input].size())}; + AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu_input); + std::unordered_set record{cnode, relu_input}; + candidate_fusion->push_back(record); + SetRecordFusionId(record); +} + void BufferFusion::MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) { MS_EXCEPTION_IF_NULL(candidate_fusion); std::vector node_list = TopoSort(kernel_graph.get_return()); for (auto &node : node_list) { - if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator.HasFusionIdAttr(node)) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator.HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { continue; } auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); if (AnfAlgo::GetCNodeName(cnode) == kBNTrainingReduceOpName) { MatchConvBnreduce(cnode, kernel_graph, candidate_fusion); - } else if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName || - AnfAlgo::GetCNodeName(cnode) == prim::kPrimRelu->name()) { - auto relu_input = cnode->input(1); - if (relu_input->isa() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimTensorAdd->name()) { - MatchBnupdateAddRelu(cnode, relu_input, kernel_graph, candidate_fusion); - } else if (relu_input->isa() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimTupleGetItem->name()) { - MatchBnupdateRelu(cnode, relu_input, kernel_graph, candidate_fusion); - } else if (relu_input->isa() && - AnfAlgo::GetCNodeName(relu_input) == prim::kPrimDepthwiseConv2dNative->name()) { - MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true); + } else if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { + auto eltwise_input = cnode->input(1); + if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimMatMul)) { + MatchMatmulEltwise(cnode, eltwise_input, kernel_graph, candidate_fusion); + } + if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimRelu)) { + if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) { + MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion); + } else if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTupleGetItem)) { + MatchBnupdateRelu(cnode, eltwise_input, kernel_graph, candidate_fusion); + } } } else if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimDepthwiseConv2dNative->name()) { MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, false); diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h index 11dafc2255..008d072ed3 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h @@ -53,6 +53,8 @@ class BufferFusion : public Pass { const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion); void MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion, bool is_order); + void MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion); void MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion); void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion); diff --git a/mindspore/ops/_op_impl/tbe/relu.py b/mindspore/ops/_op_impl/tbe/relu.py index 03cc381253..8a7d023afd 100644 --- a/mindspore/ops/_op_impl/tbe/relu.py +++ b/mindspore/ops/_op_impl/tbe/relu.py @@ -33,6 +33,7 @@ relu_op_info = TBERegOp("ReLU") \ .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ .dtype_format(DataType.F32_Default, DataType.F32_Default) \ .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \ .get_op_info() diff --git a/tests/ut/cpp/pre_activate/ascend/buffer_fusion/buffer_fusion_test_py.cc b/tests/ut/cpp/pre_activate/ascend/buffer_fusion/buffer_fusion_test.cc similarity index 77% rename from tests/ut/cpp/pre_activate/ascend/buffer_fusion/buffer_fusion_test_py.cc rename to tests/ut/cpp/pre_activate/ascend/buffer_fusion/buffer_fusion_test.cc index 68d4cf4033..aa548a5351 100644 --- a/tests/ut/cpp/pre_activate/ascend/buffer_fusion/buffer_fusion_test_py.cc +++ b/tests/ut/cpp/pre_activate/ascend/buffer_fusion/buffer_fusion_test.cc @@ -26,15 +26,15 @@ namespace mindspore { namespace opt { using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; -class TestHWBufferFusionPy : public BackendCommon { +class TestHWBufferFusion : public BackendCommon { public: - TestHWBufferFusionPy() : get_py_fun_("gtest_input.pre_activate.buffer_fusion_test", true) {} - ~TestHWBufferFusionPy() override = default; + TestHWBufferFusion() : get_py_fun_("gtest_input.pre_activate.buffer_fusion_test", true) {} + ~TestHWBufferFusion() override = default; UT::PyFuncGraphFetcher get_py_fun_; }; -TEST_F(TestHWBufferFusionPy, test_tbe_eltwise_fusion_1) { +TEST_F(TestHWBufferFusion, test_tbe_eltwise_fusion_1) { FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_eltwise_fusion_1", "before"); std::vector shp{2, 32, 224, 224}; auto x_abstract = std::make_shared(kFloat32, shp); @@ -90,7 +90,7 @@ TEST_F(TestHWBufferFusionPy, test_tbe_eltwise_fusion_1) { EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); } -TEST_F(TestHWBufferFusionPy, test_tbe_eltwise_fusion_2) { +TEST_F(TestHWBufferFusion, test_tbe_eltwise_fusion_2) { FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_eltwise_fusion_2", "before"); std::vector shp{32, 10}; std::vector shp_bias{10}; @@ -179,7 +179,7 @@ TEST_F(TestHWBufferFusionPy, test_tbe_eltwise_fusion_2) { EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); } -TEST_F(TestHWBufferFusionPy, test_tbe_reduce_eltwise_fusion) { +TEST_F(TestHWBufferFusion, test_tbe_reduce_eltwise_fusion) { FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_reduce_eltwise_fusion", "before"); std::vector shp{32, 10}; auto x_abstract = std::make_shared(kFloat32, shp); @@ -265,5 +265,71 @@ TEST_F(TestHWBufferFusionPy, test_tbe_reduce_eltwise_fusion) { FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_tbe_reduce_eltwise_fusion", "after"); EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); } + +TEST_F(TestHWBufferFusion, test_tbe_matmul_eltwise_fusion) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_matmul_eltwise_fusion", "before"); + std::vector x_shp{2048, 768}; + std::vector y_shp{768, 768}; + auto x_abstract = std::make_shared(kFloat32, x_shp); + auto y_abstract = std::make_shared(kFloat32, y_shp); + AbstractBasePtrList args_spec_list{x_abstract, y_abstract}; + auto kg = GetKernelGraph(g, args_spec_list); + + auto ret = kg->get_return(); + EXPECT_NE(ret->input(1), nullptr); + auto tuple = ret->input(1); + EXPECT_NE(tuple, nullptr); + auto cast = tuple->cast()->input(1); + EXPECT_NE(cast, nullptr); + auto relu = cast->cast()->input(1); + EXPECT_NE(relu, nullptr); + auto matmul = relu->cast()->input(1); + + KernelBuildInfoBuilder builder; + builder.SetInputsFormat({"NC1HWC0"}); + builder.SetOutputsFormat({"NC1HWC0"}); + builder.SetInputsDeviceType({kFloat32->type_id()}); + builder.SetOutputsDeviceType({kFloat32->type_id()}); + builder.SetKernelType(KernelType::TBE_KERNEL); + builder.SetFusionType(kernel::FusionType::ELEMWISE); + builder.SetProcessor(kernel::Processor::AICORE); + builder.SetKernelType(KernelType::TBE_KERNEL); + relu->set_kernel_info(std::make_shared()); + AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), relu.get()); + + KernelBuildInfoBuilder builder2; + builder2.SetInputsFormat({"NC1HWC0", "NC1HWC0"}); + builder2.SetOutputsFormat({"NC1HWC0"}); + builder2.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()}); + builder2.SetOutputsDeviceType({kFloat32->type_id()}); + builder2.SetKernelType(KernelType::TBE_KERNEL); + builder2.SetFusionType(kernel::FusionType::OPAQUE); + builder2.SetProcessor(kernel::Processor::AICORE); + builder2.SetKernelType(KernelType::TBE_KERNEL); + matmul->set_kernel_info(std::make_shared()); + AnfAlgo::SetSelectKernelBuildInfo(builder2.Build(), matmul.get()); + + KernelBuildInfoBuilder builder1; + builder1.SetInputsFormat({"NC1HWC0"}); + builder1.SetOutputsFormat({"NC1HWC0"}); + builder1.SetInputsDeviceType({kFloat32->type_id()}); + builder1.SetOutputsDeviceType({kFloat16->type_id()}); + builder1.SetKernelType(KernelType::TBE_KERNEL); + builder1.SetFusionType(kernel::FusionType::OPAQUE); + builder1.SetProcessor(kernel::Processor::AICORE); + builder1.SetKernelType(KernelType::TBE_KERNEL); + cast->set_kernel_info(std::make_shared()); + AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), cast.get()); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto buffer_fusion_pass = std::make_shared(); + pm->AddPass(buffer_fusion_pass); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(kg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_tbe_matmul_eltwise_fusion", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} } // namespace opt } // namespace mindspore diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/buffer_fusion_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/buffer_fusion_test.py index b4e4c2744e..12db1ce97b 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/buffer_fusion_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/buffer_fusion_test.py @@ -24,10 +24,12 @@ Reduce = P.ReduceOp() Biasadd = P.BiasAdd() Biasaddgrad = G.BiasAddGrad() Cast = P.Cast() +MatMul = P.MatMul() Fusion_relu_relu = Primitive('FusionOp_ReLU_ReLU') Fusion_biasadd = Primitive('FusionOp_ReLU_ReLU_ReLU_BiasAdd_ReLU_ReLU_ReLU') Fusion_biasaddgrad = Primitive('FusionOp_ReLU_ReLU_ReLU_BiasAddGrad_ReLU_ReLU_ReLU') +Fusion_matmul_relu = Primitive('FusionOp_MatMul_ReLU') Add = P.TensorAdd() Sub = P.Sub() @@ -133,3 +135,23 @@ def test_conv_singlein_fusion(tag): return tuple return fns[tag] + + +def test_tbe_matmul_eltwise_fusion(tag): + fns = FnDict() + + @fns + def before(x, y): + matmul = MatMul(x, y) + relu = Relu(matmul) + res = Cast(relu, mstype.float16) + return res + + @fns + def after(x, y): + fusion = Fusion_matmul_relu(x, y) + res = Cast(fusion) + tuple = make_tuple(res) + return tuple + + return fns[tag]