|
|
|
@ -97,6 +97,25 @@ void CheckInputs(const std::vector<AnfNodePtr> &fusion_inputs) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Different communication op in one segment cannot share the same input";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool CheckSegments(size_t segments, size_t communication_op_node_size, std::vector<size_t> *segment_index) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(segment_index);
|
|
|
|
|
if (segments >= communication_op_node_size) {
|
|
|
|
|
MS_LOG(INFO) << "fusion not changed: segment_num=" << segments
|
|
|
|
|
<< ", communication_op_node_size=" << communication_op_node_size;
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (segment_index->at(segments - 1) != communication_op_node_size - 1) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "the last segment index is invalid.";
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < segments - 1; ++i) {
|
|
|
|
|
if (segment_index->at(i) > segment_index->at(i + 1)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "illegal split: segment_index[" << i << "]=" << segment_index->at(i) << ", segment_index[ "
|
|
|
|
|
<< i + 1 << "]=" << segment_index->at(i + 1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num,
|
|
|
|
@ -137,22 +156,8 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic
|
|
|
|
|
segment_index->push_back(communication_op_node_size - 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (segments >= communication_op_node_size) {
|
|
|
|
|
MS_LOG(INFO) << "fusion not changed: segment_num=" << segments
|
|
|
|
|
<< ", communication_op_node_size=" << communication_op_node_size;
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (segment_index->at(segments - 1) != communication_op_node_size - 1) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "the last segment index is invalid.";
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < segments - 1; ++i) {
|
|
|
|
|
if (segment_index->at(i) > segment_index->at(i + 1)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "illegal split: segment_index[" << i << "]=" << segment_index->at(i) << ", segment_index[ "
|
|
|
|
|
<< i + 1 << "]=" << segment_index->at(i + 1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
*segment_num = segments;
|
|
|
|
|
return true;
|
|
|
|
|
return CheckSegments(segments, communication_op_node_size, segment_index);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr &func_graph,
|
|
|
|
|