|
|
|
@ -39,6 +39,27 @@
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace opt {
|
|
|
|
|
namespace {
|
|
|
|
|
std::vector<AnfNodePtr> DeepLinkedGraphSearch(const std::vector<AnfNodePtr> &roots, const IncludeFunc &include) {
|
|
|
|
|
std::vector<AnfNodePtr> inputs;
|
|
|
|
|
for (auto &root : roots) {
|
|
|
|
|
auto tmp = DeepLinkedGraphSearch(root, include);
|
|
|
|
|
inputs.insert(inputs.end(), tmp.begin(), tmp.end());
|
|
|
|
|
}
|
|
|
|
|
return inputs;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> DeepUsersSearch(const std::vector<AnfNodePtr> &roots, const IncludeFunc &include,
|
|
|
|
|
const FuncGraphManagerPtr &mng) {
|
|
|
|
|
std::vector<AnfNodePtr> users;
|
|
|
|
|
for (auto &root : roots) {
|
|
|
|
|
auto tmp = DeepUsersSearch(root, include, mng);
|
|
|
|
|
users.insert(users.end(), tmp.begin(), tmp.end());
|
|
|
|
|
}
|
|
|
|
|
return users;
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
bool IsBasicFuseOp(const AnfNodePtr &node, bool is_before_kernel_select) {
|
|
|
|
|
#if ENABLE_D
|
|
|
|
|
std::vector<PrimitivePtr> basic_ops = {
|
|
|
|
@ -186,31 +207,33 @@ IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, GraphKernelI
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node,
|
|
|
|
|
std::set<AnfNodePtr> *cached_unconnected_set, AnfNodePtr *circle_node) {
|
|
|
|
|
std::set<AnfNodePtr> *cached_unconnected_set, std::vector<AnfNodePtr> *circle_nodes) {
|
|
|
|
|
if (!check_node->isa<CNode>() || !fused_op_set.count(check_node)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
circle_nodes->clear();
|
|
|
|
|
|
|
|
|
|
std::set<AnfNodePtr> cached_done_set;
|
|
|
|
|
auto cnode = check_node->cast<CNodePtr>();
|
|
|
|
|
const auto &inputs = cnode->inputs();
|
|
|
|
|
// there is a input not in fused_op_set, but the input depends on the fused_op_set
|
|
|
|
|
bool has_circle = false;
|
|
|
|
|
for (auto input : inputs) {
|
|
|
|
|
if (input->isa<CNode>() && !fused_op_set.count(input)) {
|
|
|
|
|
bool has_circle = false;
|
|
|
|
|
std::set<AnfNodePtr> done;
|
|
|
|
|
std::vector<AnfNodePtr> todos = {input};
|
|
|
|
|
while (!todos.empty()) {
|
|
|
|
|
auto node = todos.back();
|
|
|
|
|
todos.pop_back();
|
|
|
|
|
if (done.count(node) || cached_unconnected_set->count(node)) {
|
|
|
|
|
if (done.count(node) || cached_unconnected_set->count(node) || cached_done_set.count(node)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
done.insert(node);
|
|
|
|
|
if (fused_op_set.count(node)) {
|
|
|
|
|
has_circle = true;
|
|
|
|
|
*circle_node = node;
|
|
|
|
|
break;
|
|
|
|
|
circle_nodes->push_back(node);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (node->isa<CNode>()) {
|
|
|
|
@ -224,13 +247,15 @@ bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &che
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (has_circle) {
|
|
|
|
|
return true;
|
|
|
|
|
cached_done_set.insert(done.begin(), done.end());
|
|
|
|
|
} else {
|
|
|
|
|
cached_unconnected_set->insert(done.begin(), done.end());
|
|
|
|
|
}
|
|
|
|
|
cached_unconnected_set->insert(done.begin(), done.end());
|
|
|
|
|
done.clear();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
return !circle_nodes->empty();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bool is_backward) {
|
|
|
|
@ -242,17 +267,19 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bo
|
|
|
|
|
}
|
|
|
|
|
return EXCLUDE;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> circle_nodes;
|
|
|
|
|
for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) {
|
|
|
|
|
AnfNodePtr circle_node;
|
|
|
|
|
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_node);
|
|
|
|
|
circle_nodes.clear();
|
|
|
|
|
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_nodes);
|
|
|
|
|
// delete the circle node and the node which depend on the circle node in fused op
|
|
|
|
|
if (has_circle) {
|
|
|
|
|
auto mng = (*iter)->func_graph()->manager();
|
|
|
|
|
std::vector<AnfNodePtr> erase_nodes;
|
|
|
|
|
if (is_backward) {
|
|
|
|
|
erase_nodes = DeepUsersSearch(circle_node, include, mng);
|
|
|
|
|
erase_nodes = DeepUsersSearch(circle_nodes, include, mng);
|
|
|
|
|
} else {
|
|
|
|
|
erase_nodes = DeepLinkedGraphSearch(circle_node, include);
|
|
|
|
|
erase_nodes = DeepLinkedGraphSearch(circle_nodes, include);
|
|
|
|
|
}
|
|
|
|
|
for (auto erase_node : erase_nodes) {
|
|
|
|
|
fused_op_set.erase(erase_node);
|
|
|
|
|