remove multiple circles

pull/8389/head
lingyunli63 4 years ago
parent 5708bae7e7
commit dc95c63c03

@ -39,6 +39,27 @@
namespace mindspore {
namespace opt {
namespace {
std::vector<AnfNodePtr> DeepLinkedGraphSearch(const std::vector<AnfNodePtr> &roots, const IncludeFunc &include) {
std::vector<AnfNodePtr> inputs;
for (auto &root : roots) {
auto tmp = DeepLinkedGraphSearch(root, include);
inputs.insert(inputs.end(), tmp.begin(), tmp.end());
}
return inputs;
}
std::vector<AnfNodePtr> DeepUsersSearch(const std::vector<AnfNodePtr> &roots, const IncludeFunc &include,
const FuncGraphManagerPtr &mng) {
std::vector<AnfNodePtr> users;
for (auto &root : roots) {
auto tmp = DeepUsersSearch(root, include, mng);
users.insert(users.end(), tmp.begin(), tmp.end());
}
return users;
}
} // namespace
bool IsBasicFuseOp(const AnfNodePtr &node, bool is_before_kernel_select) {
#if ENABLE_D
std::vector<PrimitivePtr> basic_ops = {
@ -186,31 +207,33 @@ IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, GraphKernelI
}
bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node,
std::set<AnfNodePtr> *cached_unconnected_set, AnfNodePtr *circle_node) {
std::set<AnfNodePtr> *cached_unconnected_set, std::vector<AnfNodePtr> *circle_nodes) {
if (!check_node->isa<CNode>() || !fused_op_set.count(check_node)) {
return false;
}
circle_nodes->clear();
std::set<AnfNodePtr> cached_done_set;
auto cnode = check_node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
// there is a input not in fused_op_set, but the input depends on the fused_op_set
bool has_circle = false;
for (auto input : inputs) {
if (input->isa<CNode>() && !fused_op_set.count(input)) {
bool has_circle = false;
std::set<AnfNodePtr> done;
std::vector<AnfNodePtr> todos = {input};
while (!todos.empty()) {
auto node = todos.back();
todos.pop_back();
if (done.count(node) || cached_unconnected_set->count(node)) {
if (done.count(node) || cached_unconnected_set->count(node) || cached_done_set.count(node)) {
continue;
}
done.insert(node);
if (fused_op_set.count(node)) {
has_circle = true;
*circle_node = node;
break;
circle_nodes->push_back(node);
continue;
}
if (node->isa<CNode>()) {
@ -224,13 +247,15 @@ bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &che
}
if (has_circle) {
return true;
cached_done_set.insert(done.begin(), done.end());
} else {
cached_unconnected_set->insert(done.begin(), done.end());
}
cached_unconnected_set->insert(done.begin(), done.end());
done.clear();
}
}
return false;
return !circle_nodes->empty();
}
std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bool is_backward) {
@ -242,17 +267,19 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bo
}
return EXCLUDE;
};
std::vector<AnfNodePtr> circle_nodes;
for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) {
AnfNodePtr circle_node;
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_node);
circle_nodes.clear();
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_nodes);
// delete the circle node and the node which depend on the circle node in fused op
if (has_circle) {
auto mng = (*iter)->func_graph()->manager();
std::vector<AnfNodePtr> erase_nodes;
if (is_backward) {
erase_nodes = DeepUsersSearch(circle_node, include, mng);
erase_nodes = DeepUsersSearch(circle_nodes, include, mng);
} else {
erase_nodes = DeepLinkedGraphSearch(circle_node, include);
erase_nodes = DeepLinkedGraphSearch(circle_nodes, include);
}
for (auto erase_node : erase_nodes) {
fused_op_set.erase(erase_node);

Loading…
Cancel
Save