!8431 Add more checks to the unsorted_segment_sum_fission pass to avoid an incorrect optimization that causes a runtime error

From: @irmo
Reviewed-by: 
Signed-off-by:
pull/8431/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 222a0bccf4

@ -74,6 +74,26 @@ CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &unsort_segment_s
AnfAlgo::SetNodeAttr(kAttrSize, MakeValue(Convert2Long(unsort_segment_sum_shape)), slice);
return slice;
}
// Decide whether the UnsortedSegmentSum node is a valid candidate for the
// fission optimization. Returns true only when:
//   - the node has exactly the expected number of inputs,
//   - both data (input 0) and segment-ids (input 1) shapes are known/non-empty,
//   - the trailing dimension of the data shape is 1 (otherwise the pad-based
//     rewrite does not apply),
//   - the data rank is strictly greater than the segment-ids rank.
bool CheckInputs(const CNodePtr &origin_node) {
  MS_EXCEPTION_IF_NULL(origin_node);
  // size() counts the primitive value node as well, hence the "+ 1".
  if (origin_node->size() != kUnsortedSegmentSumInputNum + 1) {
    MS_LOG(DEBUG) << "UnsortedSegmentSum has wrong inputs num, not equal " << kUnsortedSegmentSumInputNum
                  << ". CNode= " << origin_node->DebugString();
    return false;
  }
  const auto data_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0);
  const auto ids_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 1);
  // Bail out on scalar / unknown shapes; nothing to split in that case.
  if (data_shape.empty() || ids_shape.empty()) {
    return false;
  }
  // The fission rewrite only pays off when the innermost dim of input0 is 1.
  if (data_shape.back() != 1) {
    MS_LOG(DEBUG) << "UnsortedSegmentSum is not need fission. The last value of input0's shape is "
                  << data_shape.back();
    return false;
  }
  // Data must have strictly more dims than the segment ids.
  return data_shape.size() > ids_shape.size();
}
} // namespace
const BaseRef UnsortSegmentSumFission::DefinePattern() const {
@ -88,19 +108,7 @@ const AnfNodePtr UnsortSegmentSumFission::Process(const FuncGraphPtr &graph, con
MS_EXCEPTION_IF_NULL(node);
auto origin_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(origin_node);
if (origin_node->size() != kUnsortedSegmentSumInputNum + 1) {
MS_LOG(INFO) << "UnsortedSegmentSum has wrong inputs num, not equal " << kUnsortedSegmentSumInputNum
<< ". CNode= " << origin_node->DebugString();
return nullptr;
}
auto input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0);
if (input0_shape.size() < 2) {
MS_LOG(INFO) << "Input0's shape size less than 2, not optimize";
return nullptr;
}
if (input0_shape[input0_shape.size() - 1] != 1) {
MS_LOG(INFO) << "UnsortedSegmentSum is not need fission. The last value of input0's shape is "
<< input0_shape[input0_shape.size() - 1];
if (!CheckInputs(origin_node)) {
return nullptr;
}
size_t pad_dim_size;
@ -110,7 +118,7 @@ const AnfNodePtr UnsortSegmentSumFission::Process(const FuncGraphPtr &graph, con
} else if (input_dtype == kNumberTypeFloat16) {
pad_dim_size = 16;
} else {
MS_LOG(INFO) << "UnsortedSegmentSum data type not in (float21, float16), no need change";
MS_LOG(DEBUG) << "UnsortedSegmentSum data type not in (float32, float16), no need change";
return nullptr;
}

@ -32,9 +32,11 @@ class TestHWUnsortedSegmentSumFission : public BackendCommon {
TEST_F(TestHWUnsortedSegmentSumFission, test_fission) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_unsorted_segment_sum_fission", "before1");
EXPECT_NE(g, nullptr);
std::vector<int64_t> shp_x{16, 1};
std::vector<int64_t> shp_x{3, 39, 1};
std::vector<int64_t> shp_y{3};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
AbstractBasePtrList args_spec_list{x_abstract, x_abstract};
auto y_abstract = std::make_shared<abstract::AbstractTensor>(kInt32, shp_y);
AbstractBasePtrList args_spec_list{x_abstract, y_abstract};
auto kg = GetKernelGraph(g, args_spec_list);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
@ -50,9 +52,11 @@ TEST_F(TestHWUnsortedSegmentSumFission, test_fission) {
TEST_F(TestHWUnsortedSegmentSumFission, test_no_fission) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_unsorted_segment_sum_fission", "before2");
EXPECT_NE(g, nullptr);
std::vector<int64_t> shp_x{16, 2};
std::vector<int64_t> shp_x{3, 39, 2};
std::vector<int64_t> shp_y{3};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
AbstractBasePtrList args_spec_list{x_abstract, x_abstract};
auto y_abstract = std::make_shared<abstract::AbstractTensor>(kInt32, shp_y);
AbstractBasePtrList args_spec_list{x_abstract, y_abstract};
auto kg = GetKernelGraph(g, args_spec_list);
auto optimizer = std::make_shared<opt::GraphOptimizer>();

Loading…
Cancel
Save