|
|
|
@ -18,6 +18,7 @@
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <tuple>
|
|
|
|
|
#include <set>
|
|
|
|
|
#include <list>
|
|
|
|
|
#include "operator/ops.h"
|
|
|
|
|
#include "ir/meta_tensor.h"
|
|
|
|
|
#include "ir/anf.h"
|
|
|
|
@ -160,7 +161,7 @@ void ClearRunOpMemoryResource(const KernelGraphPtr &kernel_graph) {
|
|
|
|
|
std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &anf_nodes) {
|
|
|
|
|
std::vector<CNodePtr> cnodes = {};
|
|
|
|
|
size_t i = 0;
|
|
|
|
|
for (auto anf : anf_nodes) {
|
|
|
|
|
for (const auto &anf : anf_nodes) {
|
|
|
|
|
MS_LOG(INFO) << "apply_list[" << i++ << "] = " << anf->DebugString();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(anf);
|
|
|
|
|
if (anf->isa<CNode>()) {
|
|
|
|
@ -192,6 +193,8 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of
|
|
|
|
|
// graph. For example, call input = (prim, graph, kernel1, kernel2), then real_input = [kernel1, kernel2]
|
|
|
|
|
void UpdateRealInput(KernelGraph *graph) {
|
|
|
|
|
auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall);
|
|
|
|
|
auto bind_call_partial_with_parameter = [&](const std::vector<AnfNodePtr> ¶meters,
|
|
|
|
@ -239,6 +242,15 @@ void UpdateRealInput(KernelGraph *graph) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void RecurseToUpdateCallRealInput(KernelGraph *graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
MS_LOG(INFO) << "start graph id:" << graph->graph_id();
|
|
|
|
|
graph->UpdateCallRealInput();
|
|
|
|
|
for (auto &child_graph : graph->child_graph_order()) {
|
|
|
|
|
RecurseToUpdateCallRealInput(child_graph.get());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
|
|
|
@ -254,7 +266,7 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
|
|
|
|
MS_LOG(INFO) << "start";
|
|
|
|
|
auto graph = ConstructKernelGraph(func_graph);
|
|
|
|
|
// split switch
|
|
|
|
|
SplitGraph(graph);
|
|
|
|
|
SplitGraphs(graph);
|
|
|
|
|
// insert goto labels and label_sets
|
|
|
|
|
LinkChildGraphs(NOT_NULL(graph));
|
|
|
|
|
// resource initialize
|
|
|
|
@ -1366,7 +1378,7 @@ void AscendSession::SyncInitialTenosrToDevice() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
KernelGraphPtr AscendSession::SplitKernelGraph(const KernelGraphPtr &new_kernel_graph,
|
|
|
|
|
KernelGraphPtr AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
|
|
|
|
|
const std::vector<CNodePtr> &list) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(new_kernel_graph);
|
|
|
|
|
MS_LOG(INFO) << "start split kernel graph:" << new_kernel_graph->graph_id();
|
|
|
|
@ -1376,9 +1388,6 @@ KernelGraphPtr AscendSession::SplitKernelGraph(const KernelGraphPtr &new_kernel_
|
|
|
|
|
for (auto &input : anf_node->inputs()) {
|
|
|
|
|
(void)has_output_nodes.insert(input);
|
|
|
|
|
}
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) {
|
|
|
|
|
new_kernel_graph->set_return(anf_node->cast<CNodePtr>());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id();
|
|
|
|
|
// create new parameter from cnode
|
|
|
|
@ -1386,6 +1395,7 @@ KernelGraphPtr AscendSession::SplitKernelGraph(const KernelGraphPtr &new_kernel_
|
|
|
|
|
auto cnode = anf_node->cast<CNodePtr>();
|
|
|
|
|
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
|
|
|
|
|
auto input = cnode->inputs()[input_idx];
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input);
|
|
|
|
|
if (!input->isa<CNode>()) {
|
|
|
|
|
cnode->set_input(input_idx, input);
|
|
|
|
|
continue;
|
|
|
|
@ -1417,6 +1427,12 @@ KernelGraphPtr AscendSession::SplitKernelGraph(const KernelGraphPtr &new_kernel_
|
|
|
|
|
return new_kernel_graph;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AscendSession::SplitGraphs(const KernelGraphPtr &root_graph) {
|
|
|
|
|
SplitGraph(root_graph);
|
|
|
|
|
// replace the real input if the real input is a call
|
|
|
|
|
RecurseToUpdateCallRealInput(root_graph.get());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
|
|
|
|
|
MS_LOG(INFO) << "start,graph_id:" << graph->graph_id();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
@ -1426,6 +1442,7 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
|
|
|
|
|
// get child list from current graph
|
|
|
|
|
std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(*graph, apply_list);
|
|
|
|
|
auto bind_new_call_to_new_graph = [&](std::vector<CNodePtr> child_graph_list) -> AnfNodePtr {
|
|
|
|
|
// if child graph list only has a call ,then return the exist call
|
|
|
|
|
if (child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall)) {
|
|
|
|
|
return child_graph_list[0];
|
|
|
|
|
}
|
|
|
|
@ -1440,22 +1457,22 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
|
|
|
|
|
for (auto &child_graph_node : child_graph_list) {
|
|
|
|
|
AnfAlgo::SetGraphId(child_graph->graph_id(), child_graph_node.get());
|
|
|
|
|
}
|
|
|
|
|
SplitKernelGraph(child_graph, child_graph_list);
|
|
|
|
|
ConstructSplitedGraph(child_graph, child_graph_list);
|
|
|
|
|
auto new_call = graph->NewCNode(new_call_input);
|
|
|
|
|
AnfAlgo::SetNodeAttr("graph id", MakeValue(graph->graph_id()), new_call);
|
|
|
|
|
return new_call;
|
|
|
|
|
};
|
|
|
|
|
if (child_graph_lists.size() > 1) {
|
|
|
|
|
std::list<AnfNodePtr> depend_input = {};
|
|
|
|
|
for (size_t call_index = 0; call_index < child_graph_lists.size(); call_index++) {
|
|
|
|
|
auto call_node = bind_new_call_to_new_graph(child_graph_lists[call_index]);
|
|
|
|
|
if (call_index == 0) {
|
|
|
|
|
depend_input.push_front(call_node);
|
|
|
|
|
}
|
|
|
|
|
depend_input.push_front(graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name()))));
|
|
|
|
|
auto depend = graph->NewCNode(std::vector<AnfNodePtr>(depend_input.begin(), depend_input.end()));
|
|
|
|
|
auto new_return_primitive =
|
|
|
|
|
graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name())));
|
|
|
|
|
graph->set_return(graph->NewCNode({new_return_primitive, call_node}));
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
InsertDependToGraph(graph->graph_id(), call_node);
|
|
|
|
|
}
|
|
|
|
|
graph->set_return(graph->NewCNode({new_return_primitive, depend}));
|
|
|
|
|
}
|
|
|
|
|
graph->UpdateChildGraphOrder();
|
|
|
|
|
UpdateRealInput(graph.get());
|
|
|
|
|