diff --git a/mindspore/ccsrc/device/ascend/ascend_label_assign.cc b/mindspore/ccsrc/device/ascend/ascend_label_assign.cc
index 7af615f448..2db81a1725 100644
--- a/mindspore/ccsrc/device/ascend/ascend_label_assign.cc
+++ b/mindspore/ccsrc/device/ascend/ascend_label_assign.cc
@@ -102,7 +102,7 @@ static void AssignLabelForGotoSwitch(NotNull<std::shared_ptr<session::KernelGraph>> graph,
   memo->insert(graph.get());
 
   MS_LOG(INFO) << "Process label goto/switch for " << graph->ToString();
-  graph->SetExecOrderByDefault();
+  auto nodes = graph->execution_order();
 
   auto end_goto = graph->get_end_goto();
   if (end_goto != nullptr) {
@@ -128,6 +128,7 @@ static void AssignLabelForGotoSwitch(NotNull<std::shared_ptr<session::KernelGraph>> graph,
   for (auto &cg : graph->child_graph_order()) {
     AssignLabelForGotoSwitch(NOT_NULL(cg), memo);
   }
+  graph->SetExecOrderByDefault();
 }
 
 void AscendLabelAssign::AssignLabel(NotNull<std::shared_ptr<session::KernelGraph>> graph) {
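Note on the ascend_label_assign.cc change: SetExecOrderByDefault() used to run before the goto/switch rewrite; it is now deferred until after every child graph has been processed, so the regenerated execution order reflects the labels assigned during the recursion. A minimal sketch of the memoized post-order pattern involved, using illustrative types rather than MindSpore's actual API:

// A sketch only; Graph and Process are illustrative names.
#include <memory>
#include <set>
#include <vector>

struct Graph {
  std::vector<std::shared_ptr<Graph>> children;
};

void Process(const std::shared_ptr<Graph> &g, std::set<Graph *> *memo) {
  if (!memo->insert(g.get()).second) {
    return;  // a child graph shared by several parents is processed once
  }
  // ... rewrite this graph's goto/switch nodes here ...
  for (const auto &child : g->children) {
    Process(child, memo);
  }
  // Rebuild the execution order last, after all children are linked;
  // this mirrors moving SetExecOrderByDefault() to the end above.
}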
diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.h b/mindspore/ccsrc/session/anf_runtime_algorithm.h
index ae8b450e2f..8205619793 100644
--- a/mindspore/ccsrc/session/anf_runtime_algorithm.h
+++ b/mindspore/ccsrc/session/anf_runtime_algorithm.h
@@ -199,7 +199,6 @@ class AnfRuntimeAlgorithm {
   static bool IsScalarInput(const CNodePtr &cnode, size_t index);
   static bool IsScalarOutput(const CNodePtr &cnode, size_t index);
   static void ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list);
-  static bool IsWhileTrueGraph(const KernelGraphPtr &child_graph);
   // get fix output precision of cnode.
   static TypeId GetCNodeOutputPrecision(const AnfNodePtr &node);
   // get fix output precision from prev node, input_idx is the input index of current node related to prev node.
diff --git a/mindspore/ccsrc/session/ascend_control_parser.cc b/mindspore/ccsrc/session/ascend_control_parser.cc
index 166d4cc97a..18e71d74e3 100644
--- a/mindspore/ccsrc/session/ascend_control_parser.cc
+++ b/mindspore/ccsrc/session/ascend_control_parser.cc
@@ -19,6 +19,7 @@
 #include <memory>
 #include "session/anf_runtime_algorithm.h"
 #include "utils/union_find_set.h"
+#include "device/ascend/ascend_label_assign.h"
 
 static constexpr size_t kCNodePrim = 0;
 static constexpr size_t kCNodeCallArg = 1;
@@ -35,17 +36,25 @@ namespace mindspore {
 namespace session {
 static CNodePtr GetJumpNode(NotNull<KernelGraphPtr> parent_graph, NotNull<KernelGraphPtr> child_graph) {
   auto &nodes = parent_graph->execution_order();
+  CNodePtr last_jump_node = nullptr;
   for (auto &node : nodes) {
-    if (IsPrimitiveCNode(node, prim::kPrimLabelGoto) && child_graph->get_start_label() == node->input(kCNodeCallArg)) {
-      return node;
-    } else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch) &&
-               (child_graph->get_start_label() == node->input(kCNodeSwitchFalse) ||
-                child_graph->get_start_label() == node->input(kCNodeSwitchTrue))) {
-      return node;
+    if (IsPrimitiveCNode(node, prim::kPrimLabelGoto)) {
+      if (child_graph->get_start_label() == node->input(kCNodeCallArg)) {
+        return node;
+      }
+      last_jump_node = node;
+    } else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch)) {
+      if (child_graph->get_start_label() == node->input(kCNodeSwitchFalse) ||
+          child_graph->get_start_label() == node->input(kCNodeSwitchTrue)) {
+        return node;
+      }
+      last_jump_node = node;
     }
   }
-  MS_LOG(INFO) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString();
-  return nullptr;
+  if (last_jump_node == nullptr) {
+    MS_LOG(EXCEPTION) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString();
+  }
+  return last_jump_node;
 }
 
 static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set,
@@ -90,6 +99,9 @@ static void UnionParentParameter(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set) {
     if (!arg->isa<Parameter>()) {
       continue;
     }
+    if (kg->unreuse_args().find(arg) != kg->unreuse_args().end()) {
+      continue;
+    }
     union_find_set->Union(arg, para);
   }
 }
@@ -133,24 +145,28 @@ static void RecursiveReplaceNode(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> main_parameter,
   }
 }
 
+static AnfNodePtr GetMainParameter(NotNull<KernelGraphPtr> root_kg, const AnfNodePtr key,
+                                   const std::set<AnfNodePtr> &parameter_reuse_set) {
+  AnfNodePtr main_parameter = key;
+  std::set<AnfNodePtr> root_inputs_set;
+  const auto &root_inputs_vector = root_kg->inputs();
+  root_inputs_set.insert(root_inputs_vector.begin(), root_inputs_vector.end());
+  for (auto &node : parameter_reuse_set) {
+    if (root_inputs_set.find(node) != root_inputs_set.end()) {
+      main_parameter = node;
+      break;
+    }
+  }
+  return main_parameter;
+}
+
 static void ReuseParameter(NotNull<KernelGraphPtr> root_kg, NotNull<UnionFindSet<AnfNodePtr> *> parameter_set) {
   auto parameter_reuse_sets = parameter_set->GetSets();
   for (auto &[key, parameter_reuse_set] : parameter_reuse_sets) {
     if (parameter_reuse_set.size() <= 1) {
       continue;
     }
-
-    AnfNodePtr main_parameter = key;
-    std::set<AnfNodePtr> root_inputs_set;
-    const auto &root_inputs_vector = root_kg->inputs();
-    root_inputs_set.insert(root_inputs_vector.begin(), root_inputs_vector.end());
-    for (auto &node : parameter_reuse_set) {
-      if (root_inputs_set.find(node) != root_inputs_set.end()) {
-        main_parameter = node;
-        break;
-      }
-    }
-
+    auto main_parameter = GetMainParameter(root_kg, key, parameter_reuse_set);
     std::set<AnfNodePtr> memo;
     RecursiveReplaceNode(root_kg, NOT_NULL(main_parameter), parameter_reuse_set, NOT_NULL(&memo));
   }
@@ -168,6 +184,7 @@ CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) {
 void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
   std::set<KernelGraphPtr> memo;
   (void)ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo));
+  device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kg);
   std::map<uint32_t, KernelGraphPtr> graph_id_map;
   for (auto &g : memo) {
     MS_EXCEPTION_IF_NULL(g);
@@ -177,12 +194,13 @@ void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
     }
     graph_id_map[g->graph_id()] = g;
   }
+
+  // Insert Assign
+  ChildGraphDataAssign(graph_id_map);
   // Make UnionFindSet
   UnionFindSet<AnfNodePtr> parameter_set = MakeUnionFindSet(kg);
   // Reuse Parameter
   ReuseParameter(kg, NOT_NULL(&parameter_set));
-  // Insert Assign
-  ChildGraphDataAssign(graph_id_map);
 }
 
 void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
@@ -193,6 +211,7 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map) {
   for (auto &iter : graph_id_map) {
     auto &kg = iter.second;
+    MS_LOG(INFO) << "Data assign graph:" << kg->graph_id();
     MS_EXCEPTION_IF_NULL(kg);
     std::set<std::pair<AnfNodePtr, AnfNodePtr>> memo;
     const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs = kg->real_inputs();
@@ -206,8 +225,14 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map) {
-        if (arg->isa<Parameter>()) {
+        auto unreuse_args_map = kg->unreuse_args();
+        auto unreuse_arg_iter = unreuse_args_map.find(arg);
+        if (unreuse_arg_iter == unreuse_args_map.end()) {
+          MS_EXCEPTION_IF_NULL(arg);
           MS_EXCEPTION_IF_NULL(parameter);
+          if (!arg->isa<Parameter>()) {
+            MS_LOG(EXCEPTION) << "Reused arg must be parameter, arg:" << arg->DebugString() << ".";
+          }
           MS_LOG(DEBUG) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString()
                         << ", arg:" << arg->DebugString();
           continue;
@@ -220,6 +245,7 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map) {
       }
     }
+    kg->SetExecOrderByDefault();
   }
 }
@@ -353,7 +379,6 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
   cur_node->set_inputs(new_inputs);
   cur_node->set_abstract(nullptr);
   MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString();
@@ -394,7 +419,6 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
   cur_node->set_inputs(new_switch_inputs);
   cur_node->set_abstract(nullptr);
   MS_LOG(INFO) << "Succeed processing switch func " << cur_node->DebugString();
@@ -477,6 +501,16 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph,
     auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i]));
     if (assign_node != nullptr) {
       auto jump_node = GetJumpNode(from_graph, to_graph);
+      const auto &from_graph_exe_order = from_graph->execution_order();
+      auto jump_node_iter = std::find(from_graph_exe_order.begin(), from_graph_exe_order.end(), jump_node);
+      if (jump_node_iter == from_graph_exe_order.end()) {
+        MS_EXCEPTION_IF_NULL(jump_node);
+        MS_LOG(EXCEPTION) << "Can't find node:" << jump_node->DebugString() << " in graph:" << from_graph->graph_id();
+      }
+      // insert assign between jump_node -1 and jump_node
+      if (jump_node_iter != from_graph_exe_order.begin()) {
+        InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node));
+      }
       if (jump_node != nullptr) {
         InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node));
       }
@@ -501,8 +535,6 @@ AnfNodePtr AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
   auto assign_node = kg->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(assign_node);
   assign_node->set_abstract(to->abstract());
-  // append the assign at the end of from graph
-  InsertDependToGraph(kg, NOT_NULL(assign_node));
   return assign_node;
 }
@@ -527,7 +559,6 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph,
   std::vector<CNodePtr> execution_order;
   uint32_t child_order_index = 0;
-
   for (auto &node : cnodes) {
     execution_order.push_back(node);
     if (node == graph->get_end_goto()) {
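The parameter-reuse pass above builds a union-find over formal parameters and the arguments bound to them (MakeUnionFindSet/UnionParentParameter), then ReuseParameter rewrites every resulting group to a single main parameter, preferring a root-graph input. A minimal union-find with path compression conveys the idea; this is an illustrative sketch, not the interface of utils/union_find_set.h:

#include <map>
#include <set>

template <typename T>
class SimpleUnionFind {
 public:
  void Union(const T &a, const T &b) { parent_[Find(a)] = Find(b); }

  T Find(const T &x) {
    auto it = parent_.find(x);
    if (it == parent_.end()) {
      parent_[x] = x;
      return x;
    }
    if (it->second != x) {
      it->second = Find(it->second);  // path compression
    }
    return parent_[x];
  }

  // Group every element under its representative, analogous to GetSets().
  std::map<T, std::set<T>> GetSets() {
    std::map<T, std::set<T>> sets;
    for (const auto &kv : parent_) {
      sets[Find(kv.first)].insert(kv.first);
    }
    return sets;
  }

 private:
  std::map<T, T> parent_;
};

GetSets() yields one set per representative, which is the shape ReuseParameter iterates over when it picks a main parameter and replaces the rest.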
diff --git a/mindspore/ccsrc/session/ascend_control_parser.h b/mindspore/ccsrc/session/ascend_control_parser.h
index 2b383d7b14..82479fa527 100644
--- a/mindspore/ccsrc/session/ascend_control_parser.h
+++ b/mindspore/ccsrc/session/ascend_control_parser.h
@@ -23,6 +23,7 @@
 #include "session/kernel_graph.h"
 #include "utils/base_ref.h"
 #include "utils/contract.h"
+#include "utils/union_find_set.h"
 
 namespace mindspore {
 namespace session {
diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc
index 01b5691029..c6ef41f2ba 100644
--- a/mindspore/ccsrc/session/ascend_session.cc
+++ b/mindspore/ccsrc/session/ascend_session.cc
@@ -202,7 +202,8 @@ static std::vector<std::shared_ptr<KernelGraph>> GetChildList(const std::vector<AnfNodePtr> &cnodes) {
 static void BindCallArgsWithParameter(const std::vector<AnfNodePtr> &parameters, const std::vector<AnfNodePtr> &args,
-                                      KernelGraph *child_graph) {
+                                      const KernelGraphPtr &graph, KernelGraphPtr child_graph,
+                                      const NotNull<std::set<KernelGraphPtr> *> memo) {
   MS_EXCEPTION_IF_NULL(child_graph);
   MS_LOG(INFO) << "Start bind parameter of child graph:" << child_graph->graph_id();
   if (args.empty()) {
@@ -214,18 +215,25 @@ static void BindCallArgsWithParameter(const std::vector<AnfNodePtr> &parameters,
   }
   child_graph->SetExecOrderByDefault();
   for (size_t i = 0; i < parameters.size(); i++) {
+    MS_LOG(INFO) << "parameters[" << i << "]" << parameters[i]->DebugString() << ",args[" << i << "]"
+                 << args[i]->DebugString();
     if (args[i] == parameters[i]) {
-      child_graph->SetRealInput(parameters[i], args[i]);
       MS_LOG(INFO) << "Parameter and arg are same.";
       continue;
     }
     child_graph->SetRealInput(parameters[i], args[i]);
+    if (memo->find(child_graph) != memo->end() || !args[i]->isa<Parameter>()) {
+      MS_LOG(INFO) << "Add unreused arg,graph:" << graph->graph_id();
+      child_graph->AddUnreuseArgs(args[i], graph);
+    }
   }
 }
 
 // if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of
 // graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2]
-static void UpdateRealInput(NotNull<KernelGraphPtr> graph, bool split_flag) {
+static void UpdateRealInput(NotNull<KernelGraphPtr> graph, bool split_flag,
+                            const NotNull<std::set<KernelGraphPtr> *> memo) {
+  MS_EXCEPTION_IF_NULL(memo.get());
   auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall);
   for (auto &call_node : call_nodes) {
     MS_EXCEPTION_IF_NULL(call_node);
@@ -235,7 +243,7 @@ static void UpdateRealInput(NotNull<KernelGraphPtr> graph, bool split_flag) {
       std::vector<AnfNodePtr> real_args =
         std::vector<AnfNodePtr>(call_node->inputs().begin() + 2, call_node->inputs().end());
       std::vector<AnfNodePtr> child_inputs = child_graphs[0]->inputs();
-      BindCallArgsWithParameter(child_inputs, real_args, child_graphs[0].get());
+      BindCallArgsWithParameter(child_inputs, real_args, graph, child_graphs[0], memo);
       if (split_flag) {
         call_node->set_inputs(std::vector<AnfNodePtr>(call_node->inputs().begin(), call_node->inputs().begin() + 2));
       }
@@ -256,8 +264,8 @@ static void UpdateRealInput(NotNull<KernelGraphPtr> graph, bool split_flag) {
       }
       return ret;
     };
-    BindCallArgsWithParameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get());
-    BindCallArgsWithParameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get());
+    BindCallArgsWithParameter(child_graphs[0]->inputs(), get_partial_args(2), graph, child_graphs[0], memo);
+    BindCallArgsWithParameter(child_graphs[1]->inputs(), get_partial_args(3), graph, child_graphs[1], memo);
   }
 }
@@ -306,8 +314,6 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
   LinkChildGraphs(NOT_NULL(root_graph));
   // resource initialize
   InitRuntimeResource();
-  // assign label
-  AssignLabel(NOT_NULL(root_graph));
   // recurse compile child root_graph
   std::set<KernelGraphPtr> memo;
   RecurseCompileGraph(NOT_NULL(root_graph), NOT_NULL(&memo));
@@ -665,12 +671,6 @@ void AscendSession::AssignStream(NotNull<KernelGraphPtr> kernel_graph) const {
   MS_LOG(INFO) << "Finish!";
 }
 
-void AscendSession::AssignLabel(NotNull<KernelGraphPtr> kernel_graph) const {
-  MS_LOG(INFO) << "Start!";
-  device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kernel_graph);
-  MS_LOG(INFO) << "Finish!";
-}
-
 void AscendSession::BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
   MS_LOG(INFO) << "Start!";
   struct timeval start_time, end_time;
@@ -1582,14 +1582,17 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
       auto input = cnode->inputs()[input_idx];
       MS_EXCEPTION_IF_NULL(input);
       AnfNodePtr new_parameter = nullptr;
+      // check whether input has been put into args of call, if mulptiple use of one parameter or cnode, only set one
+      // parameter in graph inputs and one arg in call node
+      auto call_input_it = std::find(call_node_inputs.begin(), call_node_inputs.end(), input);
+      if (call_input_it != call_node_inputs.end()) {
+        cnode->set_input(input_idx, new_graph_inputs[std::distance(call_node_inputs.begin(), call_input_it)]);
+        continue;
+      }
       // value node consider move to new graph
       if (input->isa<ValueNode>()) {
         cnode->set_input(input_idx, input);
         continue;
-      } else if (input->isa<Parameter>()) {
-        // parameter reuse and should attention mulptiple use of one parameter
-        cnode->set_input(input_idx, input);
-        new_parameter = input;
       } else if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) {
         // if is cnode and not in current child graph
         new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get());
@@ -1598,12 +1601,8 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
         // if is a cnode and in current graph
         continue;
       }
-      // if mulptiple use of one parameter or cnode, only set one parameter in graph inputs and one arg in call node
-      // args
-      if (std::find(call_node_inputs.begin(), call_node_inputs.end(), new_parameter) == call_node_inputs.end()) {
-        new_graph_inputs.push_back(new_parameter);
-        call_node_inputs.push_back(input);
-      }
+      new_graph_inputs.push_back(new_parameter);
+      call_node_inputs.push_back(input);
     }
   }
   // set graph inputs of new graph
@@ -1631,7 +1630,7 @@ void AscendSession::SplitGraphs(NotNull<KernelGraphPtr> root_graph) {
   // if root graph output is a call node ,the root graph is condition graph of 'if' sentence
   auto root_graph_output = AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0).first;
   if (AnfAlgo::CheckPrimitiveType(root_graph_output, prim::kPrimCall)) {
-    SplitGraph(root_graph, {prim::kPrimReturn});
+    SplitGraph(root_graph, {prim::kPrimReturn}, NOT_NULL(&memo));
     for (auto &child_graph : root_graph->child_graph_order()) {
       RecurseSplitGraph(NOT_NULL(child_graph), NOT_NULL(&memo));
     }
@@ -1672,7 +1671,8 @@ AnfNodePtr AscendSession::BindNewCallToNewGraph(NotNull<KernelGraphPtr> graph,
   return new_call;
 }
 
-void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims) {
+void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims,
+                               const NotNull<std::set<KernelGraphPtr> *> memo) {
   MS_LOG(INFO) << "Start,graph_id:" << graph->graph_id();
   bool split_flag = false;
   auto apply_list = GetCNodes(TopoSort(graph->get_return()));
@@ -1710,14 +1710,13 @@ void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims) {
-  UpdateRealInput(graph, split_flag);
+  UpdateRealInput(graph, split_flag, memo);
   MS_LOG(INFO) << "Split graph[" << graph->graph_id() << "] end";
-  // recurse to split child graph
 }
 
 void AscendSession::RecurseSplitGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo) {
   memo->insert(graph.get());
-  SplitGraph(graph, {prim::kPrimCall});
+  SplitGraph(graph, {prim::kPrimCall}, memo);
   for (auto &child_graph : graph->child_graph_order()) {
     if (memo->find(child_graph) == memo->end()) {
       RecurseSplitGraph(NOT_NULL(child_graph), memo);
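In ascend_session.cc, BindCallArgsWithParameter now marks an argument as unreused when the child graph was already visited (a back edge, e.g. a while-loop body) or when the argument is not a Parameter; such arguments later receive an explicit Assign instead of sharing the parameter's memory. The condition reduces to the following predicate (a sketch of the logic only, not MindSpore's API):

// Mirrors memo->find(child_graph) != memo->end() || !args[i]->isa<Parameter>()
// from the hunk above; the parameter names here are illustrative.
bool ArgNeedsExplicitAssign(bool child_graph_already_visited, bool arg_is_parameter) {
  return child_graph_already_visited || !arg_is_parameter;
}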
diff --git a/mindspore/ccsrc/session/ascend_session.h b/mindspore/ccsrc/session/ascend_session.h
index 4774015457..531860c379 100755
--- a/mindspore/ccsrc/session/ascend_session.h
+++ b/mindspore/ccsrc/session/ascend_session.h
@@ -77,7 +77,6 @@ class AscendSession : public SessionBasic {
   void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
   void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
   void AssignStream(NotNull<KernelGraphPtr> kernel_graph) const;
-  void AssignLabel(NotNull<KernelGraphPtr> kernel_graph) const;
   void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
   void MemoryAlloc(KernelGraph *kernel_graph) const;
   void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const;
@@ -100,7 +99,8 @@ class AscendSession : public SessionBasic {
   void SetFinalGraphOutput(const ValuePtr &value);
   void SetFinalGraphOutput(const VectorRef &vec_output);
 
-  void SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims);
+  void SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims,
+                  const NotNull<std::set<KernelGraphPtr> *> memo);
   // split graphs with recurse from root graph
   void SplitGraphs(NotNull<KernelGraphPtr> root_graph);
   void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs);
diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc
index 7b53afac2a..f79363fd03 100644
--- a/mindspore/ccsrc/session/kernel_graph.cc
+++ b/mindspore/ccsrc/session/kernel_graph.cc
@@ -103,6 +103,23 @@ AnfNodePtr MakeValueNode(const AnfNodePtr &node) {
   AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
   return new_value_node;
 }
+
+bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) {
+  if (left == right) {
+    return true;
+  }
+  if (left == nullptr || right == nullptr) {
+    return false;
+  }
+  if (!IsPrimitiveCNode(left, GetCNodePrimitive(right))) {
+    return false;
+  }
+  if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, left) && AnfAlgo::HasNodeAttr(kAttrLabelIndex, right)) {
+    return AnfAlgo::GetNodeAttr<uint32_t>(left, kAttrLabelIndex) ==
+           AnfAlgo::GetNodeAttr<uint32_t>(right, kAttrLabelIndex);
+  }
+  return false;
+}
 }  // namespace
 std::vector<AnfNodePtr> KernelGraph::outputs() const {
   auto graph_output = output();
@@ -219,6 +236,19 @@ void KernelGraph::SetExecOrderByDefault() {
     if (node == start_label_ || node == end_goto_) {
       continue;
     }
+
+    if (IsSameLabel(node, end_goto_)) {
+      end_goto_ = node;
+      MS_LOG(INFO) << "Replace end_goto_ in kernel graph:" << graph_id();
+      continue;
+    }
+
+    if (IsSameLabel(node, start_label_)) {
+      start_label_ = node;
+      MS_LOG(INFO) << "Replace start_label_ in kernel graph:" << graph_id();
+      continue;
+    }
+
     re_order.push_back(node);
   }
   if (end_goto_ != nullptr) {
@@ -751,10 +781,9 @@ void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodePtr> new_anf_node) {
       node_output_edges_[new_anf_node] = it->second;
-      (void)node_output_edges_.erase(old_anf_node);
     }
   }
+  // if change the ir of graph, regenerate execution order of graph
+  SetExecOrderByDefault();
   // update graph inputs in child graph
   auto it_real_inputs = std::find_if(real_inputs_.begin(), real_inputs_.end(),
                                      [&old_anf_node](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool {
@@ -770,7 +799,7 @@ void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodePtr> new_anf_node) {
-      MS_LOG(WARNING) << new_anf_node->DebugString() << " already exist in real inputs, will be rewrited.";
+      MS_LOG(WARNING) << new_anf_node->DebugString() << " Already exist in real inputs, will be rewrited.";
       iter->second = old_args;
     } else {
       real_inputs_.emplace_back(new_anf_node, old_args);
@@ -827,6 +856,10 @@ void KernelGraph::SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg) {
   }
 }
 
+void KernelGraph::AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr<KernelGraph> &from_graph) {
+  unreuse_args_[arg] = from_graph;
+}
+
 void KernelGraph::UpdateCallRealInput() {
   MS_LOG(INFO) << "Update graph id: " << graph_id_;
   std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> real_inputs_map;
@@ -839,6 +872,17 @@ void KernelGraph::UpdateCallRealInput() {
       // if real input is a call node ,find the child graph output act as the new real input
       auto tmp_real_input = GetCallRealOutputs(real_input);
       std::copy(tmp_real_input.begin(), tmp_real_input.end(), std::back_inserter(new_real_inputs));
+      // replace the call in unreuse_args_
+      auto unreuse_arg_it = unreuse_args_.find(real_input);
+      if (unreuse_arg_it != unreuse_args_.end()) {
+        auto old_graph = unreuse_arg_it->second;
+        for (auto new_real_input : new_real_inputs) {
+          // if call reference graph output is parameter, it will be allowed to reuse
+          if (!new_real_input->isa<Parameter>()) {
+            unreuse_args_[new_real_input] = old_graph;
+          }
+        }
+      }
     }
     real_inputs_map.emplace_back(parameter, new_real_inputs);
   }
diff --git a/mindspore/ccsrc/session/kernel_graph.h b/mindspore/ccsrc/session/kernel_graph.h
index d6a67f3f02..6861d43de0 100644
--- a/mindspore/ccsrc/session/kernel_graph.h
+++ b/mindspore/ccsrc/session/kernel_graph.h
@@ -130,6 +130,9 @@ class KernelGraph : public FuncGraph {
   // get real inputs
   const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs() const { return real_inputs_; }
   void SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg);
+  // mark unreused args
+  void AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr<KernelGraph> &from_graph);
+  const std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> &unreuse_args() const { return unreuse_args_; }
   // used to dump ir
   std::string ToString() const override;
   // update the real input if the node is a call
@@ -205,6 +208,7 @@ class KernelGraph : public FuncGraph {
   std::shared_ptr<KernelGraph> parent_graph_;
   // record real parameters,inputs_ is the formal parameters
   std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> real_inputs_;
+  std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> unreuse_args_;
 
   CNodePtr start_label_;
   CNodePtr end_goto_;
diff --git a/tests/st/control/test_ascend_control_sink.py b/tests/st/control/test_ascend_control_sink.py
index b38668cd25..39af571c14 100644
--- a/tests/st/control/test_ascend_control_sink.py
+++ b/tests/st/control/test_ascend_control_sink.py
@@ -99,6 +99,19 @@ class ControlIfbyIfbyIf(nn.Cell):
         return out
 
 
+class ControlSimpleWhile(nn.Cell):
+    def __init__(self):
+        super().__init__()
+        self.addn = op.AddN()
+
+    def construct(self, x, y, input_data):
+        out = input_data
+        while x:
+            out = self.addn([input_data, input_data, input_data])
+            x = y
+        return out
+
+
 class ControlMixedWhileIf(nn.Cell):
     def __init__(self):
         super().__init__()
@@ -204,6 +217,22 @@ def test_if_by_if_by_if():
     assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
 
 
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_simple_while():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    x = np.array(True).astype(np.bool)
+    y = np.array(False).astype(np.bool)
+    input_shape = (127, 7, 53, 31)
+    input_data = np.random.randn(*input_shape).astype(np.float32)
+    net = ControlSimpleWhile()
+    output = net(Tensor(x), Tensor(y), Tensor(input_data))
+    expect = input_data * 3
+    assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
+
+
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training