|
|
@ -521,6 +521,47 @@ std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) {
|
|
|
|
return output_nodes;
|
|
|
|
return output_nodes;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Find control_depend real input nodes.
|
|
|
|
|
|
|
|
void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result, std::set<AnfNodePtr> *visited) {
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(anf_node);
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(result);
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(visited);
|
|
|
|
|
|
|
|
if (visited->find(anf_node) != visited->end()) {
|
|
|
|
|
|
|
|
MS_LOG(WARNING) << "Node:" << anf_node->fullname_with_scope() << " has alreday been visited";
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
visited->insert(anf_node);
|
|
|
|
|
|
|
|
if (AnfAlgo::IsRealKernel(anf_node)) {
|
|
|
|
|
|
|
|
result->emplace_back(anf_node);
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (!anf_node->isa<CNode>()) {
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
auto cnode = anf_node->cast<CNodePtr>();
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
|
|
|
if (cnode->inputs().empty()) {
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << anf_node->DebugString();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
auto input0 = cnode->input(0);
|
|
|
|
|
|
|
|
if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
|
|
|
|
|
|
|
|
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
|
|
|
|
|
|
|
GetAllFatherRealNode(cnode->input(i), result, visited);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
} else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
|
|
|
|
|
|
|
|
if (cnode->inputs().size() != kTupleGetItemInputSize) {
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited);
|
|
|
|
|
|
|
|
} else if (IsPrimitive(input0, prim::kPrimDepend)) {
|
|
|
|
|
|
|
|
if (cnode->inputs().size() != kDependInputSize) {
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited);
|
|
|
|
|
|
|
|
GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// update the depend relations of control depend
|
|
|
|
// update the depend relations of control depend
|
|
|
|
void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) {
|
|
|
|
void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) {
|
|
|
|
for (const auto &node : depends) {
|
|
|
|
for (const auto &node : depends) {
|
|
|
@ -551,11 +592,24 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
|
|
|
|
if (depend_node->isa<Parameter>() && depend_mode == 1) {
|
|
|
|
if (depend_node->isa<Parameter>() && depend_mode == 1) {
|
|
|
|
depend_nodes = GetOutputNodes(depend_node);
|
|
|
|
depend_nodes = GetOutputNodes(depend_node);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (auto &first_node : prior_nodes) {
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> real_prior_nodes;
|
|
|
|
|
|
|
|
std::set<AnfNodePtr> prior_visited;
|
|
|
|
|
|
|
|
for (const auto &tmp : prior_nodes) {
|
|
|
|
|
|
|
|
GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> real_depend_nodes;
|
|
|
|
|
|
|
|
std::set<AnfNodePtr> depend_visited;
|
|
|
|
|
|
|
|
for (const auto &tmp : depend_nodes) {
|
|
|
|
|
|
|
|
GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (auto &first_node : real_prior_nodes) {
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) {
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) {
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (auto &second_node : depend_nodes) {
|
|
|
|
for (auto &second_node : real_depend_nodes) {
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) {
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) {
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|