|
|
|
@ -762,6 +762,39 @@ bool AscendStreamAssign::IsProcessedStream(uint32_t stream_id) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool AscendStreamAssign::IsAllOutGraphOut(const KernelGraphPtr &graph, const CNodePtr &cnode) {
|
|
|
|
|
auto cnode_out_num = AnfAlgo::GetOutputTensorNum(cnode);
|
|
|
|
|
auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
|
|
|
|
|
std::set<int> output_index_set;
|
|
|
|
|
// Assign Communicate Op Memory firstly.
|
|
|
|
|
for (const auto &node : nodes) {
|
|
|
|
|
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(item_with_index.first);
|
|
|
|
|
if (!item_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(item_with_index.first)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (item_with_index.first == cnode) {
|
|
|
|
|
output_index_set.insert(item_with_index.second);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "Node " << cnode->fullname_with_scope() << " has " << cnode_out_num
|
|
|
|
|
<< " outputs, in graph output num:" << output_index_set.size();
|
|
|
|
|
return cnode_out_num == output_index_set.size();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
vector<CNodePtr>::iterator AscendStreamAssign::FindGraphEnd(vector<CNodePtr>::iterator begin,
|
|
|
|
|
vector<CNodePtr>::iterator end) {
|
|
|
|
|
while (begin != end) {
|
|
|
|
|
if (AnfAlgo::HasNodeAttr(kAttrFpBpEnd, *begin)) {
|
|
|
|
|
MS_LOG(INFO) << "FpBp end op is " << (*begin)->fullname_with_scope();
|
|
|
|
|
return begin;
|
|
|
|
|
}
|
|
|
|
|
++begin;
|
|
|
|
|
}
|
|
|
|
|
return end;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// section5
|
|
|
|
|
void AscendStreamAssign::InsertEventForHcomParallel(const NotNull<KernelGraphPtr> &graph_ptr) {
|
|
|
|
|
MS_LOG(INFO) << "Start";
|
|
|
|
@ -780,15 +813,23 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull<KernelGraphPt
|
|
|
|
|
while (it != cnodes.end()) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(*it);
|
|
|
|
|
if (IsHcom(*it)) {
|
|
|
|
|
auto cur_hcom_node = *it;
|
|
|
|
|
CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it));
|
|
|
|
|
it = cnodes.insert(it + 1, send_cnode_ptr);
|
|
|
|
|
|
|
|
|
|
auto target = FindTargetOp(it, cnodes.end(), *(it - 1), true);
|
|
|
|
|
auto target = FindTargetOp(it, cnodes.end(), cur_hcom_node, true);
|
|
|
|
|
if (target == cnodes.end()) {
|
|
|
|
|
MS_LOG(WARNING) << "Hcom node:" << (*(it - 1))->fullname_with_scope()
|
|
|
|
|
<< ", can't find target for insert recv op, no insert send/recv";
|
|
|
|
|
it = cnodes.erase(it);
|
|
|
|
|
continue;
|
|
|
|
|
if (IsAllOutGraphOut(graph_ptr, cur_hcom_node)) {
|
|
|
|
|
// if hcom's all output is graph output, we need to insert send/recv to fpbp end in data sink mode
|
|
|
|
|
target = FindGraphEnd(it, cnodes.end());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (target == cnodes.end()) {
|
|
|
|
|
MS_LOG(WARNING) << "Hcom node:" << (*(it - 1))->fullname_with_scope()
|
|
|
|
|
<< ", can't find target for insert recv op, no insert send/recv";
|
|
|
|
|
it = cnodes.erase(it);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// deal recv op
|
|
|
|
@ -824,7 +865,7 @@ void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull<KernelGrap
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// get the input which located in the lastr exe orders
|
|
|
|
|
// get the input which located in the last exe orders
|
|
|
|
|
vector<CNodePtr> inputs_cnode = GetLastInputCnode(graph_ptr, cur_cnode_ptr);
|
|
|
|
|
if (inputs_cnode.empty()) {
|
|
|
|
|
cnodes.emplace_back(cur_cnode_ptr);
|
|
|
|
|