|
|
|
@ -1,5 +1,5 @@
|
|
|
|
|
/**
|
|
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
|
|
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
|
|
|
*
|
|
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
* you may not use this file except in compliance with the License.
|
|
|
|
@ -34,12 +34,12 @@ namespace mindspore {
|
|
|
|
|
namespace opt {
|
|
|
|
|
namespace {
|
|
|
|
|
AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) {
|
|
|
|
|
auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(anf_node->func_graph());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
|
|
|
auto func_graph = anf_node->func_graph();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
|
auto cnode = anf_node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
TraceGuard guard(std::make_shared<TraceOpt>(cnode->debug_info()));
|
|
|
|
|
CNodePtr node = kernel_graph->NewCNode(cnode->inputs());
|
|
|
|
|
CNodePtr node = func_graph->NewCNode(cnode->inputs());
|
|
|
|
|
node->set_abstract(cnode->abstract());
|
|
|
|
|
node->set_forward(cnode->forward().first, cnode->forward().second);
|
|
|
|
|
node->set_inputs_value(cnode->inputs_value());
|
|
|
|
@ -90,19 +90,38 @@ bool ShapeOpsSplitter::Process(const FuncGraphPtr &func_graph) {
|
|
|
|
|
changed = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
mng->RemoveRoots();
|
|
|
|
|
mng->KeepRoots({func_graph});
|
|
|
|
|
if (changed) {
|
|
|
|
|
mng->RemoveRoots();
|
|
|
|
|
mng->KeepRoots({func_graph});
|
|
|
|
|
}
|
|
|
|
|
return changed;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
|
auto mng = func_graph->manager();
|
|
|
|
|
if (mng == nullptr) {
|
|
|
|
|
mng = Manage(func_graph, true);
|
|
|
|
|
func_graph->set_manager(mng);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto todos = TopoSort(func_graph->get_return());
|
|
|
|
|
bool result = false;
|
|
|
|
|
bool changed;
|
|
|
|
|
do {
|
|
|
|
|
changed = Process(func_graph);
|
|
|
|
|
result |= changed;
|
|
|
|
|
} while (changed);
|
|
|
|
|
for (const auto &anf_node : todos) {
|
|
|
|
|
if (AnfAlgo::IsGraphKernel(anf_node)) {
|
|
|
|
|
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(anf_node);
|
|
|
|
|
bool changed = false;
|
|
|
|
|
do {
|
|
|
|
|
changed = Process(sub_graph);
|
|
|
|
|
result = result || changed;
|
|
|
|
|
} while (changed);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (result) {
|
|
|
|
|
mng->RemoveRoots();
|
|
|
|
|
mng->KeepRoots({func_graph});
|
|
|
|
|
}
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
} // namespace opt
|
|
|
|
|