!8090 refine remove_circle in ops_fusion

Merge pull request !8090 from lingyunli63/refine_remove_circle
pull/8090/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 3c2819c04e

@ -186,8 +186,8 @@ IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, GraphKernelI
}
bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node,
std::set<AnfNodePtr> *cached_unconnected_set) {
if (!check_node->isa<CNode>()) {
std::set<AnfNodePtr> *cached_unconnected_set, AnfNodePtr *circle_node) {
if (!check_node->isa<CNode>() || !fused_op_set.count(check_node)) {
return false;
}
@ -209,6 +209,7 @@ bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &che
done.insert(node);
if (fused_op_set.count(node)) {
has_circle = true;
*circle_node = node;
break;
}
@ -242,15 +243,16 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bo
return EXCLUDE;
};
for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) {
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set);
AnfNodePtr circle_node;
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_node);
// delete the circle node and the node which depend on the circle node in fused op
if (has_circle) {
auto mng = (*iter)->func_graph()->manager();
std::vector<AnfNodePtr> erase_nodes;
if (is_backward) {
erase_nodes = DeepUsersSearch(*iter, include, mng);
erase_nodes = DeepUsersSearch(circle_node, include, mng);
} else {
erase_nodes = DeepLinkedGraphSearch(*iter, include);
erase_nodes = DeepLinkedGraphSearch(circle_node, include);
}
for (auto erase_node : erase_nodes) {
fused_op_set.erase(erase_node);

Loading…
Cancel
Save