|
|
|
@ -186,8 +186,8 @@ IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, GraphKernelI
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node,
|
|
|
|
|
std::set<AnfNodePtr> *cached_unconnected_set) {
|
|
|
|
|
if (!check_node->isa<CNode>()) {
|
|
|
|
|
std::set<AnfNodePtr> *cached_unconnected_set, AnfNodePtr *circle_node) {
|
|
|
|
|
if (!check_node->isa<CNode>() || !fused_op_set.count(check_node)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -209,6 +209,7 @@ bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &che
|
|
|
|
|
done.insert(node);
|
|
|
|
|
if (fused_op_set.count(node)) {
|
|
|
|
|
has_circle = true;
|
|
|
|
|
*circle_node = node;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -242,15 +243,16 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bo
|
|
|
|
|
return EXCLUDE;
|
|
|
|
|
};
|
|
|
|
|
for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) {
|
|
|
|
|
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set);
|
|
|
|
|
AnfNodePtr circle_node;
|
|
|
|
|
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_node);
|
|
|
|
|
// delete the circle node and the node which depend on the circle node in fused op
|
|
|
|
|
if (has_circle) {
|
|
|
|
|
auto mng = (*iter)->func_graph()->manager();
|
|
|
|
|
std::vector<AnfNodePtr> erase_nodes;
|
|
|
|
|
if (is_backward) {
|
|
|
|
|
erase_nodes = DeepUsersSearch(*iter, include, mng);
|
|
|
|
|
erase_nodes = DeepUsersSearch(circle_node, include, mng);
|
|
|
|
|
} else {
|
|
|
|
|
erase_nodes = DeepLinkedGraphSearch(*iter, include);
|
|
|
|
|
erase_nodes = DeepLinkedGraphSearch(circle_node, include);
|
|
|
|
|
}
|
|
|
|
|
for (auto erase_node : erase_nodes) {
|
|
|
|
|
fused_op_set.erase(erase_node);
|
|
|
|
|