enable loop sink when no getnext in execution orders

pull/10391/head
laiyongqiang 5 years ago
parent 237faca57e
commit d417dddb24

@ -762,6 +762,39 @@ bool AscendStreamAssign::IsProcessedStream(uint32_t stream_id) {
return false;
}
// Checks whether every output of `cnode` is also an output of `graph`.
//
// The graph outputs are gathered with AnfAlgo::GetAllOutput(graph->output(), {TupleGetItem}); each one is
// resolved with AnfAlgo::VisitKernelWithReturnType, and the output index is recorded whenever the resolved
// node is `cnode` itself. The result is true only when the number of distinct recorded indices equals the
// output tensor number of `cnode`.
//
// graph: kernel graph whose outputs are inspected; must not be null.
// cnode: node whose outputs are looked up among the graph outputs; must not be null.
bool AscendStreamAssign::IsAllOutGraphOut(const KernelGraphPtr &graph, const CNodePtr &cnode) {
  // Both pointers are dereferenced below, so guard them with the same null check used for other nodes here.
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(cnode);
  auto cnode_out_num = AnfAlgo::GetOutputTensorNum(cnode);
  auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
  std::set<int> output_index_set;
  // Collect the distinct output indices of cnode that show up among the graph outputs.
  for (const auto &node : nodes) {
    auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
    MS_EXCEPTION_IF_NULL(item_with_index.first);
    if (!item_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(item_with_index.first)) {
      continue;
    }
    if (item_with_index.first == cnode) {
      output_index_set.insert(item_with_index.second);
    }
  }
  MS_LOG(INFO) << "Node " << cnode->fullname_with_scope() << " has " << cnode_out_num
               << " outputs, in graph output num:" << output_index_set.size();
  return cnode_out_num == output_index_set.size();
}
// Scans [begin, end) for the first node that carries the kAttrFpBpEnd attribute.
// Returns an iterator to that node (and logs its name), or `end` when no node in the range has it.
vector<CNodePtr>::iterator AscendStreamAssign::FindGraphEnd(vector<CNodePtr>::iterator begin,
                                                            vector<CNodePtr>::iterator end) {
  for (auto iter = begin; iter != end; ++iter) {
    // Skip every node that is not tagged as the FpBp end.
    if (!AnfAlgo::HasNodeAttr(kAttrFpBpEnd, *iter)) {
      continue;
    }
    MS_LOG(INFO) << "FpBp end op is " << (*iter)->fullname_with_scope();
    return iter;
  }
  return end;
}
// section5
void AscendStreamAssign::InsertEventForHcomParallel(const NotNull<KernelGraphPtr> &graph_ptr) {
MS_LOG(INFO) << "Start";
@ -780,15 +813,23 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull<KernelGraphPt
while (it != cnodes.end()) {
MS_EXCEPTION_IF_NULL(*it);
if (IsHcom(*it)) {
auto cur_hcom_node = *it;
CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it));
it = cnodes.insert(it + 1, send_cnode_ptr);
auto target = FindTargetOp(it, cnodes.end(), *(it - 1), true);
auto target = FindTargetOp(it, cnodes.end(), cur_hcom_node, true);
if (target == cnodes.end()) {
MS_LOG(WARNING) << "Hcom node:" << (*(it - 1))->fullname_with_scope()
<< ", can't find target for insert recv op, no insert send/recv";
it = cnodes.erase(it);
continue;
if (IsAllOutGraphOut(graph_ptr, cur_hcom_node)) {
// if hcom's all output is graph output, we need to insert send/recv to fpbp end in data sink mode
target = FindGraphEnd(it, cnodes.end());
}
if (target == cnodes.end()) {
MS_LOG(WARNING) << "Hcom node:" << (*(it - 1))->fullname_with_scope()
<< ", can't find target for insert recv op, no insert send/recv";
it = cnodes.erase(it);
continue;
}
}
// deal recv op
@ -824,7 +865,7 @@ void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull<KernelGrap
continue;
}
// get the input which located in the lastr exe orders
// get the input which located in the last exe orders
vector<CNodePtr> inputs_cnode = GetLastInputCnode(graph_ptr, cur_cnode_ptr);
if (inputs_cnode.empty()) {
cnodes.emplace_back(cur_cnode_ptr);

@ -212,6 +212,8 @@ class AscendStreamAssign {
std::map<CNodePtr, CNodePtr> event_map_{};
std::set<uint32_t> middle_active_streams_{};
// new policy end
// Returns true when the number of distinct output indices of `cnode` found among the outputs of `graph`
// equals the output tensor number of `cnode`.
bool IsAllOutGraphOut(const KernelGraphPtr &graph, const CNodePtr &cnode);
// Returns an iterator to the first node in [begin, end) that has the kAttrFpBpEnd attribute,
// or `end` when no node in the range has it.
vector<CNodePtr>::iterator FindGraphEnd(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end);
};
} // namespace ascend
} // namespace device

File diff suppressed because it is too large Load Diff

@ -86,7 +86,8 @@ class KernelAdjust {
void LoadSwitchInputs(std::vector<tensor::TensorPtr> *inputs);
void InsertProfilingKernel(const ProfilingTraceInfo &profiling_trace_info,
NotNull<session::KernelGraph *> kernel_graph_ptr);
bool ExitIndependent(const std::shared_ptr<session::KernelGraph> &graph_ptr);
bool ExistIndependent(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
bool ExistGetNext(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
};
} // namespace device
} // namespace mindspore

@ -315,6 +315,7 @@ constexpr auto kAttrOutputUsedNum = "output_used_num";
constexpr auto kAttrHasBias = "has_bias";
constexpr auto kAttrN = "n";
constexpr auto kAttrLabelForInsertStreamActive = "label_for_insert_stream_active";
// Node attribute key checked by AscendStreamAssign::FindGraphEnd to locate the "FpBp end" op.
constexpr auto kAttrFpBpEnd = "fpbp_end";
constexpr auto kAttrFusion = "fusion";
constexpr auto kAttrGroup = "group";
constexpr auto kAttrOp = "op";

Loading…
Cancel
Save