|
|
|
@ -74,6 +74,26 @@ CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &unsort_segment_s
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrSize, MakeValue(Convert2Long(unsort_segment_sum_shape)), slice);
|
|
|
|
|
return slice;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool CheckInputs(const CNodePtr &origin_node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(origin_node);
|
|
|
|
|
if (origin_node->size() != kUnsortedSegmentSumInputNum + 1) {
|
|
|
|
|
MS_LOG(DEBUG) << "UnsortedSegmentSum has wrong inputs num, not equal " << kUnsortedSegmentSumInputNum
|
|
|
|
|
<< ". CNode= " << origin_node->DebugString();
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto x_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0);
|
|
|
|
|
auto y_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 1);
|
|
|
|
|
if (x_shape.empty() || y_shape.empty()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (x_shape[x_shape.size() - 1] != 1) {
|
|
|
|
|
MS_LOG(DEBUG) << "UnsortedSegmentSum is not need fission. The last value of input0's shape is "
|
|
|
|
|
<< x_shape[x_shape.size() - 1];
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return x_shape.size() > y_shape.size();
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
const BaseRef UnsortSegmentSumFission::DefinePattern() const {
|
|
|
|
@ -88,19 +108,7 @@ const AnfNodePtr UnsortSegmentSumFission::Process(const FuncGraphPtr &graph, con
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto origin_node = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(origin_node);
|
|
|
|
|
if (origin_node->size() != kUnsortedSegmentSumInputNum + 1) {
|
|
|
|
|
MS_LOG(INFO) << "UnsortedSegmentSum has wrong inputs num, not equal " << kUnsortedSegmentSumInputNum
|
|
|
|
|
<< ". CNode= " << origin_node->DebugString();
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0);
|
|
|
|
|
if (input0_shape.size() < 2) {
|
|
|
|
|
MS_LOG(INFO) << "Input0's shape size less than 2, not optimize";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
if (input0_shape[input0_shape.size() - 1] != 1) {
|
|
|
|
|
MS_LOG(INFO) << "UnsortedSegmentSum is not need fission. The last value of input0's shape is "
|
|
|
|
|
<< input0_shape[input0_shape.size() - 1];
|
|
|
|
|
if (!CheckInputs(origin_node)) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
size_t pad_dim_size;
|
|
|
|
@ -110,7 +118,7 @@ const AnfNodePtr UnsortSegmentSumFission::Process(const FuncGraphPtr &graph, con
|
|
|
|
|
} else if (input_dtype == kNumberTypeFloat16) {
|
|
|
|
|
pad_dim_size = 16;
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(INFO) << "UnsortedSegmentSum data type not in (float21, float16), no need change";
|
|
|
|
|
MS_LOG(DEBUG) << "UnsortedSegmentSum data type not in (float32, float16), no need change";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|