!1483 Clean code of pr1459

Merge pull request !1483 from zhoufeng/code-clean
pull/1483/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit d953b2b5ab

@ -19,6 +19,17 @@
#include "session/ascend_control_parser.h" #include "session/ascend_control_parser.h"
#include "session/anf_runtime_algorithm.h" #include "session/anf_runtime_algorithm.h"
static constexpr size_t kCNodePrim = 0;
static constexpr size_t kCNodeCallArg = 1;
static constexpr size_t kCNodeSwitchCond = 1;
static constexpr size_t kCNodeSwitchTrue = 2;
static constexpr size_t kCNodeSwitchFalse = 3;
static constexpr size_t kCNodeSwitchLength = 4;
static constexpr size_t kCNodePartialLength = 2;
static constexpr size_t kCNodePartialFunc = 1;
static constexpr size_t kCNodeSwitchLayerBranch = 2;
static constexpr size_t kCNodeSwitchLayerLength = 3;
namespace mindspore { namespace mindspore {
namespace session { namespace session {
@ -61,7 +72,7 @@ void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
ChildGraphDataAssign(graph_id_map); ChildGraphDataAssign(graph_id_map);
} }
CNodePtr AscendControlParser::GetNextRealKernel(std::vector<CNodePtr> list, size_t start) { CNodePtr AscendControlParser::GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) {
for (size_t i = start; i < list.size() - 1; ++i) { for (size_t i = start; i < list.size() - 1; ++i) {
if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) { if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) {
return list[i]; return list[i];
@ -83,11 +94,11 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr
memo->insert(kg.get()); memo->insert(kg.get());
// 2. args replace placeholder // 2. args replace placeholder
LinkParentGraph(kg, last_node, last_label, memo); LinkParentGraph(kg, last_node, last_label);
// 3. topological sort // 3. topological sort
kg->SetExecOrderByDefault(); kg->SetExecOrderByDefault();
std::vector<CNodePtr> nodes = kg->execution_order(); const std::vector<CNodePtr> &nodes = kg->execution_order();
if (nodes.empty()) { if (nodes.empty()) {
MS_LOG(EXCEPTION) << "KernelGraph " << kg->ToString() << " has no cnodes!"; MS_LOG(EXCEPTION) << "KernelGraph " << kg->ToString() << " has no cnodes!";
} }
@ -149,9 +160,9 @@ void AscendControlParser::InsertControlDependToGraph(NotNull<KernelGraphPtr> kg,
} }
void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node, void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node,
const CNodePtr &last_label, NotNull<std::set<KernelGraphPtr> *> memo) { const CNodePtr &last_label) {
auto origin_return = kg->get_return(); auto origin_return = kg->get_return();
std::vector<AnfNodePtr> origin_return_inputs = origin_return->inputs(); const std::vector<AnfNodePtr> &origin_return_inputs = origin_return->inputs();
// if entry graph, replace return with make_tuple // if entry graph, replace return with make_tuple
if (from_graph_call_node == nullptr || last_label == nullptr) { if (from_graph_call_node == nullptr || last_label == nullptr) {
MS_LOG(INFO) << kg->ToString() << " is entry graph."; MS_LOG(INFO) << kg->ToString() << " is entry graph.";
@ -173,7 +184,7 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP
MS_LOG(INFO) << "process call func " << cur_node->DebugString(); MS_LOG(INFO) << "process call func " << cur_node->DebugString();
// 1 get kernel graph // 1 get kernel graph
auto origin_inputs = cur_node->inputs(); const std::vector<AnfNodePtr> &origin_inputs = cur_node->inputs();
std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName))}; std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName))};
if (!IsValueNode<KernelGraph>(origin_inputs[kCNodeCallArg])) { if (!IsValueNode<KernelGraph>(origin_inputs[kCNodeCallArg])) {
MS_LOG(WARNING) << "Node " << cur_node->DebugString(10) << " index " << kCNodeCallArg << " is not a ValueNode"; MS_LOG(WARNING) << "Node " << cur_node->DebugString(10) << " index " << kCNodeCallArg << " is not a ValueNode";
@ -217,15 +228,14 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod
InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node)); InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node));
} }
// 3 recurse sub graph // 3 recurse sub graph
auto origin_switch_inputs = cur_node->inputs(); const std::vector<AnfNodePtr> &origin_switch_inputs = cur_node->inputs();
std::vector<AnfNodePtr> new_switch_inputs = { std::vector<AnfNodePtr> new_switch_inputs = {
std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)), std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)),
origin_switch_inputs[kCNodeSwitchCond]}; origin_switch_inputs[kCNodeSwitchCond]};
for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) { for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) {
// 3.1 branch kernel graph and args // 3.1 branch kernel graph and args
CNodePtr partial;
KernelGraphPtr branch_fg; KernelGraphPtr branch_fg;
std::tie(partial, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); std::tie(std::ignore, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
// 3.2 recurse sub graph // 3.2 recurse sub graph
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
new_switch_inputs.push_back(branch_label); new_switch_inputs.push_back(branch_label);
@ -249,9 +259,9 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
auto branch_tuple = cur_node->input(kCNodeSwitchLayerBranch); auto branch_tuple = cur_node->input(kCNodeSwitchLayerBranch);
MS_EXCEPTION_IF_NULL(branch_tuple); MS_EXCEPTION_IF_NULL(branch_tuple);
if (!branch_tuple->isa<CNode>()) { if (!branch_tuple->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength; MS_LOG(EXCEPTION) << branch_tuple->DebugString() << " is not a CNode";
} }
auto branch_partial = utils::cast<CNodePtr>(branch_tuple)->inputs(); const std::vector<AnfNodePtr> &branch_partial = utils::cast<CNodePtr>(branch_tuple)->inputs();
// 1 return label // 1 return label
auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))}); auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
// 2 add depend relationship // 2 add depend relationship
@ -260,15 +270,14 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node)); InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node));
} }
// 3 recurse sub graph // 3 recurse sub graph
auto origin_switch_inputs = cur_node->inputs(); const std::vector<AnfNodePtr> &origin_switch_inputs = cur_node->inputs();
std::vector<AnfNodePtr> new_switch_inputs = { std::vector<AnfNodePtr> new_switch_inputs = {
std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)), std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)),
origin_switch_inputs[kCNodeSwitchCond]}; origin_switch_inputs[kCNodeSwitchCond]};
for (size_t i = 0; i < branch_partial.size(); ++i) { for (size_t i = 0; i < branch_partial.size(); ++i) {
// 3.1 branch kernel graph and args // 3.1 branch kernel graph and args
CNodePtr partial;
KernelGraphPtr branch_fg; KernelGraphPtr branch_fg;
std::tie(partial, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); std::tie(std::ignore, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
// 3.2 recurse sub graph // 3.2 recurse sub graph
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
new_switch_inputs.push_back(branch_label); new_switch_inputs.push_back(branch_label);
@ -315,18 +324,6 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul
InsertDependToGraph(kg, NOT_NULL(assign_node)); InsertDependToGraph(kg, NOT_NULL(assign_node));
} }
NotNull<AnfNodePtr> AscendControlParser::GetRealInput(NotNull<KernelGraphPtr> from_graph,
NotNull<KernelGraphPtr> to_graph, NotNull<AnfNodePtr> param) {
std::set<AnfNodePtr> args_list = to_graph->GetRealInput(param);
for (auto arg : args_list) {
if (arg->func_graph() == from_graph.get()) {
return NOT_NULL(arg);
}
}
MS_LOG(EXCEPTION) << to_graph->ToString() << " input " << param->DebugString() << " not from "
<< from_graph->ToString();
}
void AscendControlParser::LinkArgsToParam(NotNull<KernelGraphPtr> to_graph, NotNull<KernelGraphPtr> target_graph, void AscendControlParser::LinkArgsToParam(NotNull<KernelGraphPtr> to_graph, NotNull<KernelGraphPtr> target_graph,
NotNull<AnfNodePtr> arg, NotNull<AnfNodePtr> param) { NotNull<AnfNodePtr> arg, NotNull<AnfNodePtr> param) {
if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple) && IsPrimitiveCNode(param, prim::kPrimMakeTuple)) { if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple) && IsPrimitiveCNode(param, prim::kPrimMakeTuple)) {
@ -369,10 +366,10 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(const CNodePtr &cur_labe
return {}; return {};
} }
memo->insert(graph.get()); memo->insert(graph.get());
const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order = graph->child_graph_order();
graph->SetExecOrderByDefault(); graph->SetExecOrderByDefault();
std::vector<CNodePtr> cnodes = graph->execution_order(); const std::vector<CNodePtr> &cnodes = graph->execution_order();
std::map<uint32_t, CNodePtr> label_map; std::map<uint32_t, CNodePtr> label_map;
std::map<CNodePtr, std::vector<uint32_t>> label_switch_map; std::map<CNodePtr, std::vector<uint32_t>> label_switch_map;
std::tie(label_map, label_switch_map) = GetLabelNode(cnodes); std::tie(label_map, label_switch_map) = GetLabelNode(cnodes);
@ -388,10 +385,10 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(const CNodePtr &cur_labe
std::find_if(label_map.begin(), label_map.end(), std::find_if(label_map.begin(), label_map.end(),
[node](const std::map<uint32_t, CNodePtr>::value_type iter) { return iter.second == node; }); [node](const std::map<uint32_t, CNodePtr>::value_type iter) { return iter.second == node; });
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) {
if (!CheckLabelIndex(label_iter->first, 0, label_iter->second, graph)) { if (label_iter == label_map.end() || !CheckLabelIndex(label_iter->first, 0, label_iter->second, graph)) {
MS_LOG(EXCEPTION) << "Check label index fail"; MS_LOG(EXCEPTION) << "Check label index fail";
} }
auto child_graph = graph->child_graph_order()[label_iter->first]; auto child_graph = child_graph_order[label_iter->first];
if (child_graph == graph->parent_graph()) { if (child_graph == graph->parent_graph()) {
continue; continue;
} }
@ -407,7 +404,7 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(const CNodePtr &cur_labe
if (!CheckLabelIndex(label_iter->first + i, label_list[i], label_iter->second, graph)) { if (!CheckLabelIndex(label_iter->first + i, label_list[i], label_iter->second, graph)) {
MS_LOG(EXCEPTION) << "Check label index fail"; MS_LOG(EXCEPTION) << "Check label index fail";
} }
auto child_graph = graph->child_graph_order()[label_iter->first + i]; auto child_graph = child_graph_order[label_iter->first + i];
if (child_graph == graph->parent_graph()) { if (child_graph == graph->parent_graph()) {
continue; continue;
} }
@ -426,10 +423,11 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(const CNodePtr &cur_labe
bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label, bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label,
NotNull<KernelGraphPtr> graph) { NotNull<KernelGraphPtr> graph) {
const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order = graph->child_graph_order();
// check index and child order size // check index and child order size
if (graph->child_graph_order().size() <= static_cast<size_t>(order_index)) { if (child_graph_order.size() <= IntToSize(order_index)) {
MS_LOG(EXCEPTION) << "Child graph order is wrong, graph " << graph->ToString() << " child graph size " MS_LOG(EXCEPTION) << "Child graph order is wrong, graph " << graph->ToString() << " child graph size "
<< graph->child_graph_order().size() << " goto index " << order_index; << child_graph_order.size() << " goto index " << order_index;
} }
if (AnfAlgo::CheckPrimitiveType(cur_label, prim::kPrimLabelGoto)) { if (AnfAlgo::CheckPrimitiveType(cur_label, prim::kPrimLabelGoto)) {
@ -443,7 +441,7 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i
label_index = label_goto_index; label_index = label_goto_index;
} }
// get start_label_set_index of child graph // get start_label_set_index of child graph
auto child_graph = graph->child_graph_order()[order_index]; auto child_graph = child_graph_order[order_index];
MS_EXCEPTION_IF_NULL(child_graph); MS_EXCEPTION_IF_NULL(child_graph);
auto start_label_set = child_graph->get_start_label(); auto start_label_set = child_graph->get_start_label();
if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, start_label_set)) { if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, start_label_set)) {
@ -468,8 +466,7 @@ std::tuple<std::map<uint32_t, CNodePtr>, std::map<CNodePtr, std::vector<uint32_t
uint32_t index = 0; uint32_t index = 0;
for (auto &node : nodes) { for (auto &node : nodes) {
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) {
label_map[index] = node; label_map[index++] = node;
++index;
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) {
if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, node)) { if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, node)) {
MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list"; MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list";
@ -479,8 +476,7 @@ std::tuple<std::map<uint32_t, CNodePtr>, std::map<CNodePtr, std::vector<uint32_t
std::vector<uint32_t> label_list = GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrLabelSwitchList)); std::vector<uint32_t> label_list = GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrLabelSwitchList));
label_switch_map.insert({node, label_list}); label_switch_map.insert({node, label_list});
for (size_t i = 0; i < label_list.size(); ++i) { for (size_t i = 0; i < label_list.size(); ++i) {
label_map[index] = node; label_map[index++] = node;
++index;
} }
} }
} }

@ -49,16 +49,15 @@ class AscendControlParser {
NotNull<std::set<KernelGraphPtr> *> memo); NotNull<std::set<KernelGraphPtr> *> memo);
static void LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node, static void LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node,
const CNodePtr &last_label, NotNull<std::set<KernelGraphPtr> *> memo); const CNodePtr &last_label);
static std::tuple<CNodePtr, KernelGraphPtr> ParsePartial(NotNull<AnfNodePtr> node); static std::tuple<CNodePtr, KernelGraphPtr> ParsePartial(NotNull<AnfNodePtr> node);
static void LinkArgsToParam(NotNull<KernelGraphPtr> to_graph, NotNull<KernelGraphPtr> target_graph, static void LinkArgsToParam(NotNull<KernelGraphPtr> to_graph, NotNull<KernelGraphPtr> target_graph,
NotNull<AnfNodePtr> arg, NotNull<AnfNodePtr> param); NotNull<AnfNodePtr> arg, NotNull<AnfNodePtr> param);
static NotNull<AnfNodePtr> GetRealInput(NotNull<KernelGraphPtr> from_graph, NotNull<KernelGraphPtr> to_graph,
NotNull<AnfNodePtr> param);
static void InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to); static void InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
static CNodePtr GetNextRealKernel(std::vector<CNodePtr> list, size_t start); static CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start);
// root graph order // root graph order
static std::tuple<std::map<uint32_t, CNodePtr>, std::map<CNodePtr, std::vector<uint32_t>>> GetLabelNode( static std::tuple<std::map<uint32_t, CNodePtr>, std::map<CNodePtr, std::vector<uint32_t>>> GetLabelNode(
@ -67,20 +66,7 @@ class AscendControlParser {
NotNull<KernelGraphPtr> graph); NotNull<KernelGraphPtr> graph);
static std::vector<CNodePtr> RecurseGraph(const CNodePtr &cur_label_goto, const CNodePtr &end_label_goto, static std::vector<CNodePtr> RecurseGraph(const CNodePtr &cur_label_goto, const CNodePtr &end_label_goto,
NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo); NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
static constexpr size_t kCNodePrim = 0;
static constexpr size_t kCNodeCallArg = 1;
static constexpr size_t kCNodeSwitchCond = 1;
static constexpr size_t kCNodeSwitchTrue = 2;
static constexpr size_t kCNodeSwitchFalse = 3;
static constexpr size_t kCNodeSwitchLength = 4;
static constexpr size_t kCNodePartialLength = 2;
static constexpr size_t kCNodePartialFunc = 1;
static constexpr size_t kCNodeSwitchLayerCond = 1;
static constexpr size_t kCNodeSwitchLayerBranch = 2;
static constexpr size_t kCNodeSwitchLayerLength = 3;
}; };
} // namespace session } // namespace session
} // namespace mindspore } // namespace mindspore

@ -256,7 +256,6 @@ static void UpdateRealInput(KernelGraph *graph) {
void RecurseToUpdateCallRealInput(KernelGraph *graph) { void RecurseToUpdateCallRealInput(KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "start graph id:" << graph->graph_id(); MS_LOG(INFO) << "start graph id:" << graph->graph_id();
graph->UpdateCallRealInput();
for (auto &child_graph : graph->child_graph_order()) { for (auto &child_graph : graph->child_graph_order()) {
if (child_graph == graph->parent_graph()) { if (child_graph == graph->parent_graph()) {
MS_LOG(INFO) << "Child graph:" << child_graph->graph_id() MS_LOG(INFO) << "Child graph:" << child_graph->graph_id()
@ -265,6 +264,8 @@ void RecurseToUpdateCallRealInput(KernelGraph *graph) {
} }
RecurseToUpdateCallRealInput(child_graph.get()); RecurseToUpdateCallRealInput(child_graph.get());
} }
// this action should from bottom to top
graph->UpdateCallRealInput();
} }
} // namespace } // namespace
@ -280,27 +281,20 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL
GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
MS_LOG(INFO) << "start"; MS_LOG(INFO) << "start";
auto graph = ConstructKernelGraph(func_graph); auto graph = ConstructKernelGraph(func_graph);
MS_LOG(INFO) << "graph input size:" << graph->inputs().size();
// split switch // split switch
SplitGraphs(graph); SplitGraphs(graph);
MS_LOG(INFO) << "graph input size:" << graph->inputs().size();
// insert goto labels and label_sets // insert goto labels and label_sets
LinkChildGraphs(NOT_NULL(graph)); LinkChildGraphs(NOT_NULL(graph));
MS_LOG(INFO) << "graph input size:" << graph->inputs().size();
// resource initialize // resource initialize
InitRuntimeResource(); InitRuntimeResource();
// assign label // assign label
AssignLabel(NOT_NULL(graph)); AssignLabel(NOT_NULL(graph));
MS_LOG(INFO) << "graph input size:" << graph->inputs().size();
// recurse compile child graph // recurse compile child graph
RecurseCompileGraph(graph); RecurseCompileGraph(graph);
MS_LOG(INFO) << "graph input size:" << graph->inputs().size();
// root graph valiate,include genearte execute order and so on // root graph valiate,include genearte execute order and so on
RootGraphExecutorValidate(NOT_NULL(graph)); RootGraphExecutorValidate(NOT_NULL(graph));
MS_LOG(INFO) << "graph input size:" << graph->inputs().size();
// adjust kernel // adjust kernel
AdjustKernel(graph); AdjustKernel(graph);
MS_LOG(INFO) << "graph input size:" << graph->inputs().size();
// assign stream // assign stream
AssignStream(graph); AssignStream(graph);
// build kernel // build kernel
@ -313,7 +307,6 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
LoadTask(graph); LoadTask(graph);
// return the graph id to backend // return the graph id to backend
auto graph_id = graph->graph_id(); auto graph_id = graph->graph_id();
MS_LOG(INFO) << "Compile graph " << graph_id << " success";
return graph_id; return graph_id;
} }

@ -606,10 +606,6 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf
break; break;
} }
} }
MS_LOG(INFO) << "Inputs of graph id:" << graph_id();
for (size_t i = 0; i < inputs().size(); i++) {
MS_LOG(INFO) << "[" << i << "]:" << inputs()[i]->DebugString();
}
} }
// update front to backend map // update front to backend map
FrontBackendlMapUpdate(old_anf_node, new_anf_node); FrontBackendlMapUpdate(old_anf_node, new_anf_node);
@ -713,6 +709,9 @@ void KernelGraph::UpdateCallRealInput() {
MS_LOG(INFO) << "paramter: " << parameter->DebugString() MS_LOG(INFO) << "paramter: " << parameter->DebugString()
<< " insert real input:" << new_real_input->DebugString(); << " insert real input:" << new_real_input->DebugString();
(void)real_inputs.insert(new_real_input); (void)real_inputs.insert(new_real_input);
if (new_real_input->isa<Parameter>()) {
ReplaceNode(parameter, new_real_input);
}
} }
} }
} }

Loading…
Cancel
Save