|
|
@ -156,6 +156,89 @@ void ClearRunOpMemoryResource(const KernelGraphPtr &kernel_graph) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &anf_nodes) {
|
|
|
|
|
|
|
|
std::vector<CNodePtr> cnodes = {};
|
|
|
|
|
|
|
|
size_t i = 0;
|
|
|
|
|
|
|
|
for (const auto anf : anf_nodes) {
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "apply_list[" << i++ << "] = " << anf->DebugString();
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(anf);
|
|
|
|
|
|
|
|
if (anf->isa<CNode>()) {
|
|
|
|
|
|
|
|
cnodes.push_back(anf->cast<CNodePtr>());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return std::move(cnodes);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, const std::vector<CNodePtr> &cnodes) {
|
|
|
|
|
|
|
|
size_t after_call_index = 0;
|
|
|
|
|
|
|
|
std::vector<std::vector<CNodePtr>> ret;
|
|
|
|
|
|
|
|
for (size_t i = 0; i < cnodes.size(); i++) {
|
|
|
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimCall) && !AnfAlgo::IsSwitchCall(cnodes[i])) {
|
|
|
|
|
|
|
|
auto call_kernel_graph = AnfAlgo::GetCallNodeKernelGraph(cnodes[i]);
|
|
|
|
|
|
|
|
// if graph is the true branch of while,no need split graph
|
|
|
|
|
|
|
|
if (call_kernel_graph.size() == 1 && call_kernel_graph[0] == cur_graph.parent_graph()) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
auto prev_call_list = std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.begin() + i);
|
|
|
|
|
|
|
|
auto call_list = std::vector<CNodePtr>(1, cnodes[i]);
|
|
|
|
|
|
|
|
after_call_index = i + 1;
|
|
|
|
|
|
|
|
ret.push_back(prev_call_list);
|
|
|
|
|
|
|
|
ret.push_back(call_list);
|
|
|
|
|
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimReturn)) {
|
|
|
|
|
|
|
|
ret.push_back(std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.end()));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return std::move(ret);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void UpdateRealInput(KernelGraph *graph) {
|
|
|
|
|
|
|
|
auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall);
|
|
|
|
|
|
|
|
auto bind_call_partial_with_parameter = [&](const std::vector<AnfNodePtr> ¶meters,
|
|
|
|
|
|
|
|
const std::vector<AnfNodePtr> &args, KernelGraph *child_graph) -> void {
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(child_graph);
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "start bind parameter of child graph:" << child_graph->graph_id();
|
|
|
|
|
|
|
|
if (args.empty()) {
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (parameters.size() != args.size()) {
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "graph:" << child_graph->graph_id() << " parameters size:" << parameters.size()
|
|
|
|
|
|
|
|
<< " and args size:" << args.size() << " not equal!";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
for (size_t i = 0; i < parameters.size(); i++) {
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "bind paramreter:" << parameters[i]->DebugString() << " ,arg:" << args[i]->DebugString();
|
|
|
|
|
|
|
|
child_graph->SetRealInput(parameters[i], args[i]);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
for (auto &call_node : call_nodes) {
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(call_node);
|
|
|
|
|
|
|
|
auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node);
|
|
|
|
|
|
|
|
if (child_graphs.size() == 1) {
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(child_graphs[0]);
|
|
|
|
|
|
|
|
bind_call_partial_with_parameter(
|
|
|
|
|
|
|
|
child_graphs[0]->inputs(), std::vector<AnfNodePtr>(call_node->inputs().begin() + 2, call_node->inputs().end()),
|
|
|
|
|
|
|
|
child_graphs[0].get());
|
|
|
|
|
|
|
|
call_node->set_inputs(std::vector<AnfNodePtr>(call_node->inputs().begin(), call_node->inputs().begin() + 2));
|
|
|
|
|
|
|
|
} else if (child_graphs.size() == 2) {
|
|
|
|
|
|
|
|
auto get_partial_args = [&](size_t input_index) -> std::vector<AnfNodePtr> {
|
|
|
|
|
|
|
|
auto switch_node = call_node->input(1);
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(switch_node);
|
|
|
|
|
|
|
|
auto switch_cnode = switch_node->cast<CNodePtr>();
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(switch_cnode);
|
|
|
|
|
|
|
|
auto partial = switch_cnode->input(input_index);
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(partial);
|
|
|
|
|
|
|
|
auto partial_cnode = partial->cast<CNodePtr>();
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(partial_cnode);
|
|
|
|
|
|
|
|
auto ret = std::vector<AnfNodePtr>(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end());
|
|
|
|
|
|
|
|
partial_cnode->set_inputs(
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr>(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2));
|
|
|
|
|
|
|
|
return std::move(ret);
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
bind_call_partial_with_parameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get());
|
|
|
|
|
|
|
|
bind_call_partial_with_parameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
} // namespace
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
|
|
GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
|
|
|
GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
|
|
@ -171,7 +254,7 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
|
|
|
MS_LOG(INFO) << "start";
|
|
|
|
MS_LOG(INFO) << "start";
|
|
|
|
auto graph = ConstructKernelGraph(func_graph);
|
|
|
|
auto graph = ConstructKernelGraph(func_graph);
|
|
|
|
// split switch
|
|
|
|
// split switch
|
|
|
|
SplitSwitch(graph.get());
|
|
|
|
SplitGraph(graph);
|
|
|
|
// insert goto labels and label_sets
|
|
|
|
// insert goto labels and label_sets
|
|
|
|
LinkChildGraphs(graph.get());
|
|
|
|
LinkChildGraphs(graph.get());
|
|
|
|
// resource initialize
|
|
|
|
// resource initialize
|
|
|
@ -1297,5 +1380,107 @@ void AscendSession::SyncInitialTenosrToDevice() {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
KernelGraphPtr AscendSession::SplitKernelGraph(const KernelGraphPtr &new_kernel_graph,
|
|
|
|
|
|
|
|
const std::vector<CNodePtr> &list) {
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(new_kernel_graph);
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "start split kernel graph:" << new_kernel_graph->graph_id();
|
|
|
|
|
|
|
|
// count the output of every anf node
|
|
|
|
|
|
|
|
std::set<AnfNodePtr> has_output_nodes;
|
|
|
|
|
|
|
|
for (auto &anf_node : list) {
|
|
|
|
|
|
|
|
for (auto &input : anf_node->inputs()) {
|
|
|
|
|
|
|
|
(void)has_output_nodes.insert(input);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) {
|
|
|
|
|
|
|
|
new_kernel_graph->set_return(anf_node->cast<CNodePtr>());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id();
|
|
|
|
|
|
|
|
// create new parameter from cnode
|
|
|
|
|
|
|
|
for (auto &anf_node : list) {
|
|
|
|
|
|
|
|
auto cnode = anf_node->cast<CNodePtr>();
|
|
|
|
|
|
|
|
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
|
|
|
|
|
|
|
|
auto input = cnode->inputs()[input_idx];
|
|
|
|
|
|
|
|
if (!input->isa<CNode>()) {
|
|
|
|
|
|
|
|
cnode->set_input(input_idx, input);
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) {
|
|
|
|
|
|
|
|
auto new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get());
|
|
|
|
|
|
|
|
cnode->set_input(input_idx, new_parameter);
|
|
|
|
|
|
|
|
new_kernel_graph->SetRealInput(new_parameter, input);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "Construct output of kernel graph:" << new_kernel_graph->graph_id();
|
|
|
|
|
|
|
|
auto make_tuple_primitve = NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()));
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> make_tuple_inputs = {make_tuple_primitve};
|
|
|
|
|
|
|
|
int output_idx = 0;
|
|
|
|
|
|
|
|
for (auto &anf_node : list) {
|
|
|
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) {
|
|
|
|
|
|
|
|
new_kernel_graph->set_return(anf_node);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (has_output_nodes.find(anf_node) == has_output_nodes.end()) {
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "output[" << output_idx++ << "]:" << anf_node->DebugString();
|
|
|
|
|
|
|
|
make_tuple_inputs.push_back(anf_node);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (new_kernel_graph->get_return() == nullptr) {
|
|
|
|
|
|
|
|
new_kernel_graph->set_output(new_kernel_graph->NewCNode(make_tuple_inputs));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "end";
|
|
|
|
|
|
|
|
return new_kernel_graph;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "start,graph_id:" << graph->graph_id();
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
|
|
|
auto apply_list = GetCNodes(TopoSort(graph->get_return()));
|
|
|
|
|
|
|
|
// update the root graph child graph order
|
|
|
|
|
|
|
|
graph->UpdateChildGraphOrder();
|
|
|
|
|
|
|
|
// get child list from current graph
|
|
|
|
|
|
|
|
std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(*graph, apply_list);
|
|
|
|
|
|
|
|
auto bind_new_call_to_new_graph = [&](std::vector<CNodePtr> child_graph_list) -> AnfNodePtr {
|
|
|
|
|
|
|
|
if (child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall)) {
|
|
|
|
|
|
|
|
return child_graph_list[0];
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
// create new child graph
|
|
|
|
|
|
|
|
auto child_graph = NewKernelGraph();
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(child_graph);
|
|
|
|
|
|
|
|
// create new value node to bind child graph
|
|
|
|
|
|
|
|
auto graph_value_node = graph->NewValueNode(NewValueNode(child_graph));
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> new_call_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())),
|
|
|
|
|
|
|
|
graph_value_node};
|
|
|
|
|
|
|
|
// set the graph id of all node of child graph
|
|
|
|
|
|
|
|
for (auto &child_graph_node : child_graph_list) {
|
|
|
|
|
|
|
|
AnfAlgo::SetGraphId(child_graph->graph_id(), child_graph_node.get());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
SplitKernelGraph(child_graph, child_graph_list);
|
|
|
|
|
|
|
|
auto new_call = graph->NewCNode(new_call_input);
|
|
|
|
|
|
|
|
AnfAlgo::SetNodeAttr("graph id", MakeValue(graph->graph_id()), new_call);
|
|
|
|
|
|
|
|
return new_call;
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
if (child_graph_lists.size() > 1) {
|
|
|
|
|
|
|
|
for (size_t call_index = 0; call_index < child_graph_lists.size(); call_index++) {
|
|
|
|
|
|
|
|
auto call_node = bind_new_call_to_new_graph(child_graph_lists[call_index]);
|
|
|
|
|
|
|
|
if (call_index == 0) {
|
|
|
|
|
|
|
|
auto new_return_primitive =
|
|
|
|
|
|
|
|
graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name())));
|
|
|
|
|
|
|
|
graph->set_return(graph->NewCNode({new_return_primitive, call_node}));
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
InsertDependToGraph(graph->graph_id(), call_node);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
graph->UpdateChildGraphOrder();
|
|
|
|
|
|
|
|
UpdateRealInput(graph.get());
|
|
|
|
|
|
|
|
auto graph_name = std::string("./kernel-graph-").append(std::to_string(graph->graph_id()));
|
|
|
|
|
|
|
|
DumpIR(graph_name, graph);
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "split graph[" << graph->graph_id() << "] end";
|
|
|
|
|
|
|
|
// recurse to split child graph
|
|
|
|
|
|
|
|
for (auto &child_graph : graph->child_graph_order()) {
|
|
|
|
|
|
|
|
SplitGraph(child_graph);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
} // namespace session
|
|
|
|
} // namespace session
|
|
|
|
} // namespace mindspore
|
|
|
|
} // namespace mindspore
|
|
|
|