From 1c6c280da704491380bc15a55131868cfd7fa399 Mon Sep 17 00:00:00 2001 From: huanghui Date: Tue, 10 Nov 2020 20:19:45 +0800 Subject: [PATCH] fix unsorted_segment_sum_fission pass --- .../unsorted_segment_sum_fission.cc | 36 +++++++++++-------- .../unsorted_segment_sum_fission_test.cc | 12 ++++--- 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.cc index 150f8aa05a..9a129c7911 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.cc @@ -74,6 +74,26 @@ CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &unsort_segment_s AnfAlgo::SetNodeAttr(kAttrSize, MakeValue(Convert2Long(unsort_segment_sum_shape)), slice); return slice; } + +bool CheckInputs(const CNodePtr &origin_node) { + MS_EXCEPTION_IF_NULL(origin_node); + if (origin_node->size() != kUnsortedSegmentSumInputNum + 1) { + MS_LOG(DEBUG) << "UnsortedSegmentSum has wrong inputs num, not equal " << kUnsortedSegmentSumInputNum + << ". CNode= " << origin_node->DebugString(); + return false; + } + auto x_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0); + auto y_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 1); + if (x_shape.empty() || y_shape.empty()) { + return false; + } + if (x_shape[x_shape.size() - 1] != 1) { + MS_LOG(DEBUG) << "UnsortedSegmentSum is not need fission. The last value of input0's shape is " + << x_shape[x_shape.size() - 1]; + return false; + } + return x_shape.size() > y_shape.size(); +} } // namespace const BaseRef UnsortSegmentSumFission::DefinePattern() const { @@ -88,19 +108,7 @@ const AnfNodePtr UnsortSegmentSumFission::Process(const FuncGraphPtr &graph, con MS_EXCEPTION_IF_NULL(node); auto origin_node = node->cast(); MS_EXCEPTION_IF_NULL(origin_node); - if (origin_node->size() != kUnsortedSegmentSumInputNum + 1) { - MS_LOG(INFO) << "UnsortedSegmentSum has wrong inputs num, not equal " << kUnsortedSegmentSumInputNum - << ". CNode= " << origin_node->DebugString(); - return nullptr; - } - auto input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0); - if (input0_shape.size() < 2) { - MS_LOG(INFO) << "Input0's shape size less than 2, not optimize"; - return nullptr; - } - if (input0_shape[input0_shape.size() - 1] != 1) { - MS_LOG(INFO) << "UnsortedSegmentSum is not need fission. The last value of input0's shape is " - << input0_shape[input0_shape.size() - 1]; + if (!CheckInputs(origin_node)) { return nullptr; } size_t pad_dim_size; @@ -110,7 +118,7 @@ const AnfNodePtr UnsortSegmentSumFission::Process(const FuncGraphPtr &graph, con } else if (input_dtype == kNumberTypeFloat16) { pad_dim_size = 16; } else { - MS_LOG(INFO) << "UnsortedSegmentSum data type not in (float21, float16), no need change"; + MS_LOG(DEBUG) << "UnsortedSegmentSum data type not in (float32, float16), no need change"; return nullptr; } diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/unsorted_segment_sum_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/unsorted_segment_sum_fission_test.cc index 92a98338f9..ec820810cb 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/unsorted_segment_sum_fission_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/unsorted_segment_sum_fission_test.cc @@ -32,9 +32,11 @@ class TestHWUnsortedSegmentSumFission : public BackendCommon { TEST_F(TestHWUnsortedSegmentSumFission, test_fission) { FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_unsorted_segment_sum_fission", "before1"); EXPECT_NE(g, nullptr); - std::vector shp_x{16, 1}; + std::vector shp_x{3, 39, 1}; + std::vector shp_y{3}; auto x_abstract = std::make_shared(kFloat32, shp_x); - AbstractBasePtrList args_spec_list{x_abstract, x_abstract}; + auto y_abstract = std::make_shared(kInt32, shp_y); + AbstractBasePtrList args_spec_list{x_abstract, y_abstract}; auto kg = GetKernelGraph(g, args_spec_list); auto optimizer = std::make_shared(); @@ -50,9 +52,11 @@ TEST_F(TestHWUnsortedSegmentSumFission, test_fission) { TEST_F(TestHWUnsortedSegmentSumFission, test_no_fission) { FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_unsorted_segment_sum_fission", "before2"); EXPECT_NE(g, nullptr); - std::vector shp_x{16, 2}; + std::vector shp_x{3, 39, 2}; + std::vector shp_y{3}; auto x_abstract = std::make_shared(kFloat32, shp_x); - AbstractBasePtrList args_spec_list{x_abstract, x_abstract}; + auto y_abstract = std::make_shared(kInt32, shp_y); + AbstractBasePtrList args_spec_list{x_abstract, y_abstract}; auto kg = GetKernelGraph(g, args_spec_list); auto optimizer = std::make_shared();