!8090 refine remove_circle in ops_fusion

Merge pull request !8090 from lingyunli63/refine_remove_circle
pull/8090/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 3c2819c04e

@ -186,8 +186,8 @@ IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, GraphKernelI
} }
bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node, bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node,
std::set<AnfNodePtr> *cached_unconnected_set) { std::set<AnfNodePtr> *cached_unconnected_set, AnfNodePtr *circle_node) {
if (!check_node->isa<CNode>()) { if (!check_node->isa<CNode>() || !fused_op_set.count(check_node)) {
return false; return false;
} }
@ -209,6 +209,7 @@ bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &che
done.insert(node); done.insert(node);
if (fused_op_set.count(node)) { if (fused_op_set.count(node)) {
has_circle = true; has_circle = true;
*circle_node = node;
break; break;
} }
@ -242,15 +243,16 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bo
return EXCLUDE; return EXCLUDE;
}; };
for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) { for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) {
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set); AnfNodePtr circle_node;
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_node);
// delete the circle node and the node which depend on the circle node in fused op // delete the circle node and the node which depend on the circle node in fused op
if (has_circle) { if (has_circle) {
auto mng = (*iter)->func_graph()->manager(); auto mng = (*iter)->func_graph()->manager();
std::vector<AnfNodePtr> erase_nodes; std::vector<AnfNodePtr> erase_nodes;
if (is_backward) { if (is_backward) {
erase_nodes = DeepUsersSearch(*iter, include, mng); erase_nodes = DeepUsersSearch(circle_node, include, mng);
} else { } else {
erase_nodes = DeepLinkedGraphSearch(*iter, include); erase_nodes = DeepLinkedGraphSearch(circle_node, include);
} }
for (auto erase_node : erase_nodes) { for (auto erase_node : erase_nodes) {
fused_op_set.erase(erase_node); fused_op_set.erase(erase_node);

Loading…
Cancel
Save