|
|
|
@ -206,39 +206,40 @@ static std::vector<std::vector<CNodePtr>> GetChildList(const std::vector<CNodePt
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void BindCallArgsWithParameter(const std::vector<AnfNodePtr> ¶meters, const std::vector<AnfNodePtr> &args,
|
|
|
|
|
KernelGraph *child_graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(child_graph);
|
|
|
|
|
MS_LOG(INFO) << "start bind parameter of child graph:" << child_graph->graph_id();
|
|
|
|
|
if (args.empty()) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
if (parameters.size() != args.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "graph:" << child_graph->graph_id() << " parameters size:" << parameters.size()
|
|
|
|
|
<< " and args size:" << args.size() << " not equal!";
|
|
|
|
|
}
|
|
|
|
|
child_graph->SetExecOrderByDefault();
|
|
|
|
|
for (size_t i = 0; i < parameters.size(); i++) {
|
|
|
|
|
if (args[i] == parameters[i]) {
|
|
|
|
|
child_graph->SetRealInput(parameters[i], args[i]);
|
|
|
|
|
MS_LOG(INFO) << "Parameter and arg are same";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
// if arg is a parameter ,then reuse this parameter
|
|
|
|
|
if (args[i]->isa<Parameter>()) {
|
|
|
|
|
MS_LOG(INFO) << "Parameter:" << parameters[i]->DebugString() << " of graph:" << child_graph->graph_id()
|
|
|
|
|
<< " reuse parameter:" << args[i]->DebugString()
|
|
|
|
|
<< " of graph:" << AnfAlgo::GetGraphId(args[i].get());
|
|
|
|
|
child_graph->ReplaceNode(parameters[i], args[i]);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
child_graph->SetRealInput(parameters[i], args[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// If a call has kernel inputs, it is a child graph split from ME, so these kernel inputs should be set as the real
// inputs of the graph. For example, call input = (prim, graph, kernel1, kernel2), then real_input = [kernel1, kernel2].
|
|
|
|
|
static void UpdateRealInput(NotNull<KernelGraphPtr> graph) {
|
|
|
|
|
auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall);
|
|
|
|
|
auto bind_call_arg_with_parameter = [&](const std::vector<AnfNodePtr> ¶meters,
|
|
|
|
|
const std::vector<AnfNodePtr> &args, KernelGraph *child_graph) -> void {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(child_graph);
|
|
|
|
|
MS_LOG(INFO) << "start bind parameter of child graph:" << child_graph->graph_id();
|
|
|
|
|
if (args.empty()) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
if (parameters.size() != args.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "graph:" << child_graph->graph_id() << " parameters size:" << parameters.size()
|
|
|
|
|
<< " and args size:" << args.size() << " not equal!";
|
|
|
|
|
}
|
|
|
|
|
child_graph->SetExecOrderByDefault();
|
|
|
|
|
for (size_t i = 0; i < parameters.size(); i++) {
|
|
|
|
|
if (args[i] == parameters[i]) {
|
|
|
|
|
child_graph->SetRealInput(parameters[i], args[i]);
|
|
|
|
|
MS_LOG(INFO) << "Parameter and arg are same";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
// if arg is a parameter ,then reuse this parameter
|
|
|
|
|
if (args[i]->isa<Parameter>()) {
|
|
|
|
|
MS_LOG(INFO) << "Parameter:" << parameters[i]->DebugString() << " of graph:" << child_graph->graph_id()
|
|
|
|
|
<< " reuse parameter:" << args[i]->DebugString()
|
|
|
|
|
<< " of graph:" << AnfAlgo::GetGraphId(args[i].get());
|
|
|
|
|
child_graph->ReplaceNode(parameters[i], args[i]);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
child_graph->SetRealInput(parameters[i], args[i]);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
for (auto &call_node : call_nodes) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(call_node);
|
|
|
|
|
auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node);
|
|
|
|
@ -247,7 +248,7 @@ static void UpdateRealInput(NotNull<KernelGraphPtr> graph) {
|
|
|
|
|
std::vector<AnfNodePtr> real_args =
|
|
|
|
|
std::vector<AnfNodePtr>(call_node->inputs().begin() + 2, call_node->inputs().end());
|
|
|
|
|
std::vector<AnfNodePtr> child_inputs = child_graphs[0]->inputs();
|
|
|
|
|
bind_call_arg_with_parameter(child_inputs, real_args, child_graphs[0].get());
|
|
|
|
|
BindCallArgsWithParameter(child_inputs, real_args, child_graphs[0].get());
|
|
|
|
|
call_node->set_inputs(std::vector<AnfNodePtr>(call_node->inputs().begin(), call_node->inputs().begin() + 2));
|
|
|
|
|
} else if (child_graphs.size() == 2) {
|
|
|
|
|
auto get_partial_args = [&](size_t input_index) -> std::vector<AnfNodePtr> {
|
|
|
|
@ -264,8 +265,8 @@ static void UpdateRealInput(NotNull<KernelGraphPtr> graph) {
|
|
|
|
|
std::vector<AnfNodePtr>(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2));
|
|
|
|
|
return ret;
|
|
|
|
|
};
|
|
|
|
|
bind_call_arg_with_parameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get());
|
|
|
|
|
bind_call_arg_with_parameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get());
|
|
|
|
|
BindCallArgsWithParameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get());
|
|
|
|
|
BindCallArgsWithParameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -1429,10 +1430,7 @@ void AscendSession::SyncInitialTenosrToDevice() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
|
|
|
|
|
const std::vector<CNodePtr> &list) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(new_kernel_graph);
|
|
|
|
|
MS_LOG(INFO) << "start contruct splited kernel graph:" << new_kernel_graph->graph_id();
|
|
|
|
|
static void ConstructSplitedGraphOutput(const KernelGraphPtr &new_kernel_graph, const std::vector<CNodePtr> &list) {
|
|
|
|
|
// count the output of every anf node
|
|
|
|
|
std::set<AnfNodePtr> has_output_nodes;
|
|
|
|
|
for (auto &anf_node : list) {
|
|
|
|
@ -1440,6 +1438,28 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt
|
|
|
|
|
(void)has_output_nodes.insert(input);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto make_tuple_primitve = NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()));
|
|
|
|
|
std::vector<AnfNodePtr> make_tuple_inputs = {make_tuple_primitve};
|
|
|
|
|
int output_idx = 0;
|
|
|
|
|
for (auto &anf_node : list) {
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) {
|
|
|
|
|
new_kernel_graph->set_return(anf_node);
|
|
|
|
|
}
|
|
|
|
|
if (has_output_nodes.find(anf_node) == has_output_nodes.end()) {
|
|
|
|
|
MS_LOG(INFO) << "output[" << output_idx++ << "]:" << anf_node->DebugString();
|
|
|
|
|
make_tuple_inputs.push_back(anf_node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (new_kernel_graph->get_return() == nullptr) {
|
|
|
|
|
new_kernel_graph->set_output(new_kernel_graph->NewCNode(make_tuple_inputs));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
|
|
|
|
|
const std::vector<CNodePtr> &list) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(new_kernel_graph);
|
|
|
|
|
MS_LOG(INFO) << "start contruct splited kernel graph:" << new_kernel_graph->graph_id();
|
|
|
|
|
MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id();
|
|
|
|
|
std::vector<AnfNodePtr> call_node_inputs;
|
|
|
|
|
std::vector<AnfNodePtr> new_graph_inputs;
|
|
|
|
@ -1479,22 +1499,9 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph_inputs);
|
|
|
|
|
graph_inputs->clear();
|
|
|
|
|
std::copy(new_graph_inputs.begin(), new_graph_inputs.end(), std::back_inserter(*graph_inputs));
|
|
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "Construct output of kernel graph:" << new_kernel_graph->graph_id();
|
|
|
|
|
auto make_tuple_primitve = NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()));
|
|
|
|
|
std::vector<AnfNodePtr> make_tuple_inputs = {make_tuple_primitve};
|
|
|
|
|
int output_idx = 0;
|
|
|
|
|
for (auto &anf_node : list) {
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) {
|
|
|
|
|
new_kernel_graph->set_return(anf_node);
|
|
|
|
|
}
|
|
|
|
|
if (has_output_nodes.find(anf_node) == has_output_nodes.end()) {
|
|
|
|
|
MS_LOG(INFO) << "output[" << output_idx++ << "]:" << anf_node->DebugString();
|
|
|
|
|
make_tuple_inputs.push_back(anf_node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (new_kernel_graph->get_return() == nullptr) {
|
|
|
|
|
new_kernel_graph->set_output(new_kernel_graph->NewCNode(make_tuple_inputs));
|
|
|
|
|
}
|
|
|
|
|
ConstructSplitedGraphOutput(new_kernel_graph, list);
|
|
|
|
|
MS_LOG(INFO) << "end";
|
|
|
|
|
return call_node_inputs;
|
|
|
|
|
}
|
|
|
|
@ -1516,6 +1523,30 @@ void AscendSession::SplitGraphs(NotNull<KernelGraphPtr> root_graph) {
|
|
|
|
|
RecurseToUpdateCallRealInput(root_graph, NOT_NULL(&memo));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Produces the call node that stands for one split segment (`child_graph_list`) inside `graph`.
//
// When the segment is exactly one existing call node, that node is returned unchanged. Otherwise a new kernel
// graph is created, every node of the segment is assigned the new graph's id, the segment is handed to
// ConstructSplitedGraph, and a new call node is created in `graph` whose inputs are
// (call primitive, value node of the new graph, arguments returned by ConstructSplitedGraph...).
AnfNodePtr AscendSession::BindNewCallToNewGraph(NotNull<KernelGraphPtr> graph,
                                                const std::vector<CNodePtr> &child_graph_list) {
  // A segment made of a single call already is the call we need.
  const bool is_single_call =
    child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall);
  if (is_single_call) {
    return child_graph_list[0];
  }

  // Fresh kernel graph that will own the nodes of this segment.
  auto new_graph = NewKernelGraph();
  MS_EXCEPTION_IF_NULL(new_graph);

  // The value node referencing the new graph becomes the second input of the call.
  auto new_graph_value_node = graph->NewValueNode(NewValueNode(new_graph));
  auto call_primitive_node = NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name()));
  std::vector<AnfNodePtr> call_inputs = {call_primitive_node, new_graph_value_node};

  // Move every node of the segment under the new graph's id.
  for (const auto &segment_node : child_graph_list) {
    AnfAlgo::SetGraphId(new_graph->graph_id(), segment_node.get());
  }

  // The arguments reported by ConstructSplitedGraph are appended after the two fixed inputs.
  auto extra_call_args = ConstructSplitedGraph(new_graph, child_graph_list);
  call_inputs.insert(call_inputs.end(), extra_call_args.begin(), extra_call_args.end());

  auto call_cnode = graph->NewCNode(call_inputs);
  // Record the id of the graph that contains the call on the call node itself.
  AnfAlgo::SetNodeAttr("graph id", MakeValue(graph->graph_id()), call_cnode);
  return call_cnode;
}
|
|
|
|
|
|
|
|
|
|
void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims) {
|
|
|
|
|
MS_LOG(INFO) << "start,graph_id:" << graph->graph_id();
|
|
|
|
|
auto apply_list = GetCNodes(TopoSort(graph->get_return()));
|
|
|
|
@ -1523,32 +1554,10 @@ void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<Pri
|
|
|
|
|
AscendControlParser::UpdateChildGraphOrder(graph);
|
|
|
|
|
// get child list from current graph
|
|
|
|
|
std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(apply_list, cut_prims);
|
|
|
|
|
auto bind_new_call_to_new_graph = [&](std::vector<CNodePtr> child_graph_list) -> AnfNodePtr {
|
|
|
|
|
// if child graph list only has a call ,then return the exist call
|
|
|
|
|
if (child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall)) {
|
|
|
|
|
return child_graph_list[0];
|
|
|
|
|
}
|
|
|
|
|
// create new child graph
|
|
|
|
|
auto child_graph = NewKernelGraph();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(child_graph);
|
|
|
|
|
// create new value node to bind child graph
|
|
|
|
|
auto graph_value_node = graph->NewValueNode(NewValueNode(child_graph));
|
|
|
|
|
std::vector<AnfNodePtr> new_call_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())),
|
|
|
|
|
graph_value_node};
|
|
|
|
|
// set the graph id of all node of child graph
|
|
|
|
|
for (auto &child_graph_node : child_graph_list) {
|
|
|
|
|
AnfAlgo::SetGraphId(child_graph->graph_id(), child_graph_node.get());
|
|
|
|
|
}
|
|
|
|
|
auto call_node_args = ConstructSplitedGraph(child_graph, child_graph_list);
|
|
|
|
|
std::copy(call_node_args.begin(), call_node_args.end(), std::back_inserter(new_call_input));
|
|
|
|
|
auto new_call = graph->NewCNode(new_call_input);
|
|
|
|
|
AnfAlgo::SetNodeAttr("graph id", MakeValue(graph->graph_id()), new_call);
|
|
|
|
|
return new_call;
|
|
|
|
|
};
|
|
|
|
|
if (child_graph_lists.size() > 1) {
|
|
|
|
|
std::list<AnfNodePtr> depend_input = {};
|
|
|
|
|
for (size_t call_index = 0; call_index < child_graph_lists.size(); call_index++) {
|
|
|
|
|
auto call_node = bind_new_call_to_new_graph(child_graph_lists[call_index]);
|
|
|
|
|
auto call_node = BindNewCallToNewGraph(graph, child_graph_lists[call_index]);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(call_node);
|
|
|
|
|
// if call node is the last call of true graph,no need create child graph after that
|
|
|
|
|
auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast<CNodePtr>());
|
|
|
|
@ -1605,6 +1614,5 @@ void AscendSession::RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const Not
|
|
|
|
|
RecurseCompileGraph(NOT_NULL(child_graph), memo);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace session
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|