!2761 Record unreused arg in kernel graph

Merge pull request !2761 from chenfei_mindspore/split-real-inputs-to-reuse-args-and-not-reuse-args
pull/2761/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit f6b6ef2796

@ -102,7 +102,7 @@ static void AssignLabelForGotoSwitch(NotNull<std::shared_ptr<session::KernelGrap
memo->insert(graph.get());
MS_LOG(INFO) << "Process label goto/switch for " << graph->ToString();
graph->SetExecOrderByDefault();
auto nodes = graph->execution_order();
auto end_goto = graph->get_end_goto();
if (end_goto != nullptr) {
@ -128,6 +128,7 @@ static void AssignLabelForGotoSwitch(NotNull<std::shared_ptr<session::KernelGrap
for (auto &cg : graph->child_graph_order()) {
AssignLabelForGotoSwitch(NOT_NULL(cg), memo);
}
graph->SetExecOrderByDefault();
}
void AscendLabelAssign::AssignLabel(NotNull<std::shared_ptr<session::KernelGraph>> graph) {

@ -199,7 +199,6 @@ class AnfRuntimeAlgorithm {
static bool IsScalarInput(const CNodePtr &cnode, size_t index);
static bool IsScalarOutput(const CNodePtr &cnode, size_t index);
static void ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list);
static bool IsWhileTrueGraph(const KernelGraphPtr &child_graph);
// get fix output precision of cnode.
static TypeId GetCNodeOutputPrecision(const AnfNodePtr &node);
// get fix output precision from prev node, input_idx is the input index of current node related to prev node.

@ -19,6 +19,7 @@
#include <memory>
#include "session/anf_runtime_algorithm.h"
#include "utils/union_find_set.h"
#include "device/ascend/ascend_label_assign.h"
static constexpr size_t kCNodePrim = 0;
static constexpr size_t kCNodeCallArg = 1;
@ -35,17 +36,25 @@ namespace mindspore {
namespace session {
// Find the CNode in parent_graph's execution order that jumps into child_graph.
// Prefers an exact match: a LabelGoto whose target, or a LabelSwitch one of
// whose branch labels, equals child_graph's start label. If no exact match is
// found, falls back to the last LabelGoto/LabelSwitch seen; raises if the
// parent graph contains no jump node at all.
// NOTE(review): this span contained interleaved pre- and post-merge diff lines
// (duplicate returns for the same conditions plus a contradictory
// log-and-return-nullptr tail); this is the reconstructed post-merge version.
static CNodePtr GetJumpNode(NotNull<KernelGraphPtr> parent_graph, NotNull<KernelGraphPtr> child_graph) {
  auto &nodes = parent_graph->execution_order();
  CNodePtr last_jump_node = nullptr;
  for (auto &node : nodes) {
    if (IsPrimitiveCNode(node, prim::kPrimLabelGoto)) {
      if (child_graph->get_start_label() == node->input(kCNodeCallArg)) {
        return node;
      }
      last_jump_node = node;
    } else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch)) {
      if (child_graph->get_start_label() == node->input(kCNodeSwitchFalse) ||
          child_graph->get_start_label() == node->input(kCNodeSwitchTrue)) {
        return node;
      }
      last_jump_node = node;
    }
  }
  if (last_jump_node == nullptr) {
    MS_LOG(EXCEPTION) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString();
  }
  return last_jump_node;
}
static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set,
@ -90,6 +99,9 @@ static void UnionParentParameter(NotNull<KernelGraphPtr> kg, const NotNull<Union
if (!arg->isa<Parameter>()) {
continue;
}
if (kg->unreuse_args().find(arg) != kg->unreuse_args().end()) {
continue;
}
union_find_set->Union(arg, para);
}
}
@ -133,24 +145,28 @@ static void RecursiveReplaceNode(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr>
}
}
// Select the representative ("main") parameter for a reuse set.
// Prefer a set member that is also an input of the root graph; if none of the
// members appears in the root inputs, keep `key` as the representative.
static AnfNodePtr GetMainParameter(NotNull<KernelGraphPtr> root_kg, const AnfNodePtr key,
                                   const std::set<AnfNodePtr> &parameter_reuse_set) {
  const auto &root_inputs = root_kg->inputs();
  // Build a lookup set once so each membership test is O(log n).
  const std::set<AnfNodePtr> root_input_lookup(root_inputs.begin(), root_inputs.end());
  for (const auto &candidate : parameter_reuse_set) {
    if (root_input_lookup.count(candidate) != 0) {
      return candidate;
    }
  }
  return key;
}
// For every parameter-reuse set with more than one member, pick the main
// parameter (via GetMainParameter) and replace all other members with it
// throughout the graph hierarchy rooted at root_kg.
// NOTE(review): this span contained both the removed inline main-parameter
// search and the added GetMainParameter call (a redefinition of
// `main_parameter`); this is the reconstructed post-merge version.
static void ReuseParameter(NotNull<KernelGraphPtr> root_kg, NotNull<UnionFindSet<AnfNodePtr> *> parameter_set) {
  auto parameter_reuse_sets = parameter_set->GetSets();
  for (auto &[key, parameter_reuse_set] : parameter_reuse_sets) {
    // A singleton set has nothing to unify.
    if (parameter_reuse_set.size() <= 1) {
      continue;
    }
    auto main_parameter = GetMainParameter(root_kg, key, parameter_reuse_set);
    std::set<KernelGraphPtr> memo;
    RecursiveReplaceNode(root_kg, NOT_NULL(main_parameter), parameter_reuse_set, NOT_NULL(&memo));
  }
}
@ -168,6 +184,7 @@ CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) {
void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
std::set<KernelGraphPtr> memo;
(void)ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo));
device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kg);
std::map<uint32_t, KernelGraphPtr> graph_id_map;
for (auto &g : memo) {
MS_EXCEPTION_IF_NULL(g);
@ -177,12 +194,13 @@ void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
}
graph_id_map[g->graph_id()] = g;
}
// Insert Assign
ChildGraphDataAssign(graph_id_map);
// Make UnionFindSet
UnionFindSet<AnfNodePtr> parameter_set = MakeUnionFindSet(kg);
// Reuse Parameter
ReuseParameter(kg, NOT_NULL(&parameter_set));
// Insert Assign
ChildGraphDataAssign(graph_id_map);
}
void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
@ -193,6 +211,7 @@ void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map) {
for (auto &iter : graph_id_map) {
auto &kg = iter.second;
MS_LOG(INFO) << "Data assign graph:" << kg->graph_id();
MS_EXCEPTION_IF_NULL(kg);
std::set<std::pair<AnfNodePtr, AnfNodePtr>> memo;
const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs = kg->real_inputs();
@ -206,8 +225,14 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr
} else {
memo.emplace(parameter, arg);
}
if (arg->isa<Parameter>()) {
auto unreuse_args_map = kg->unreuse_args();
auto unreuse_arg_iter = unreuse_args_map.find(arg);
if (unreuse_arg_iter == unreuse_args_map.end()) {
MS_EXCEPTION_IF_NULL(arg);
MS_EXCEPTION_IF_NULL(parameter);
if (!arg->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "Reused arg must be parameter, arg:" << arg->DebugString() << ".";
}
MS_LOG(DEBUG) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString()
<< ", arg:" << arg->DebugString();
continue;
@ -220,6 +245,7 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr
NOT_NULL(parameter));
}
}
kg->SetExecOrderByDefault();
}
}
@ -353,7 +379,6 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP
// 5 recurse sub graph
CNodePtr sub_label = ProcessKernelGraph(NOT_NULL(call_kg), cur_node, back_label, memo);
new_inputs.push_back(sub_label);
new_inputs.insert(new_inputs.end(), origin_inputs.begin(), origin_inputs.end());
cur_node->set_inputs(new_inputs);
cur_node->set_abstract(nullptr);
MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString();
@ -394,7 +419,6 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod
}
std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]);
new_switch_inputs.insert(new_switch_inputs.end(), origin_switch_inputs.begin(), origin_switch_inputs.end());
cur_node->set_inputs(new_switch_inputs);
cur_node->set_abstract(nullptr);
MS_LOG(INFO) << "Succeed processing switch func " << cur_node->DebugString();
@ -477,6 +501,16 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> fr
auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i]));
if (assign_node != nullptr) {
auto jump_node = GetJumpNode(from_graph, to_graph);
const auto &from_graph_exe_order = from_graph->execution_order();
auto jump_node_iter = std::find(from_graph_exe_order.begin(), from_graph_exe_order.end(), jump_node);
if (jump_node_iter == from_graph_exe_order.end()) {
MS_EXCEPTION_IF_NULL(jump_node);
MS_LOG(EXCEPTION) << "Can't find node:" << jump_node->DebugString() << " in graph:" << from_graph->graph_id();
}
// insert assign between jump_node -1 and jump_node
if (jump_node_iter != from_graph_exe_order.begin()) {
InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node));
}
if (jump_node != nullptr) {
InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node));
}
@ -501,8 +535,6 @@ AnfNodePtr AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg,
auto assign_node = kg->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(assign_node);
assign_node->set_abstract(to->abstract());
// append the assign at the end of from graph
InsertDependToGraph(kg, NOT_NULL(assign_node));
return assign_node;
}
@ -527,7 +559,6 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr>
std::vector<CNodePtr> execution_order;
uint32_t child_order_index = 0;
for (auto &node : cnodes) {
execution_order.push_back(node);
if (node == graph->get_end_goto()) {

@ -23,6 +23,7 @@
#include "session/kernel_graph.h"
#include "utils/base_ref.h"
#include "utils/contract.h"
#include "utils/union_find_set.h"
namespace mindspore {
namespace session {

@ -202,7 +202,8 @@ static std::vector<std::vector<CNodePtr>> GetChildList(const std::vector<CNodePt
}
static void BindCallArgsWithParameter(const std::vector<AnfNodePtr> &parameters, const std::vector<AnfNodePtr> &args,
KernelGraph *child_graph) {
const KernelGraphPtr &graph, KernelGraphPtr child_graph,
const NotNull<std::set<KernelGraphPtr> *> memo) {
MS_EXCEPTION_IF_NULL(child_graph);
MS_LOG(INFO) << "Start bind parameter of child graph:" << child_graph->graph_id();
if (args.empty()) {
@ -214,18 +215,25 @@ static void BindCallArgsWithParameter(const std::vector<AnfNodePtr> &parameters,
}
child_graph->SetExecOrderByDefault();
for (size_t i = 0; i < parameters.size(); i++) {
MS_LOG(INFO) << "parameters[" << i << "]" << parameters[i]->DebugString() << ",args[" << i << "]"
<< args[i]->DebugString();
if (args[i] == parameters[i]) {
child_graph->SetRealInput(parameters[i], args[i]);
MS_LOG(INFO) << "Parameter and arg are same.";
continue;
}
child_graph->SetRealInput(parameters[i], args[i]);
if (memo->find(child_graph) != memo->end() || !args[i]->isa<Parameter>()) {
MS_LOG(INFO) << "Add unreused arg,graph:" << graph->graph_id();
child_graph->AddUnreuseArgs(args[i], graph);
}
}
}
// If a call has kernel inputs, it's a child graph split from ME, so these kernel inputs should be set as real inputs of
// the graph. For example, if call input = (prim, graph, kernel1, kernel2), then real_input = [kernel1, kernel2].
static void UpdateRealInput(NotNull<KernelGraphPtr> graph, bool split_flag) {
static void UpdateRealInput(NotNull<KernelGraphPtr> graph, bool split_flag,
const NotNull<std::set<KernelGraphPtr> *> memo) {
MS_EXCEPTION_IF_NULL(memo.get());
auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall);
for (auto &call_node : call_nodes) {
MS_EXCEPTION_IF_NULL(call_node);
@ -235,7 +243,7 @@ static void UpdateRealInput(NotNull<KernelGraphPtr> graph, bool split_flag) {
std::vector<AnfNodePtr> real_args =
std::vector<AnfNodePtr>(call_node->inputs().begin() + 2, call_node->inputs().end());
std::vector<AnfNodePtr> child_inputs = child_graphs[0]->inputs();
BindCallArgsWithParameter(child_inputs, real_args, child_graphs[0].get());
BindCallArgsWithParameter(child_inputs, real_args, graph, child_graphs[0], memo);
if (split_flag) {
call_node->set_inputs(std::vector<AnfNodePtr>(call_node->inputs().begin(), call_node->inputs().begin() + 2));
}
@ -256,8 +264,8 @@ static void UpdateRealInput(NotNull<KernelGraphPtr> graph, bool split_flag) {
}
return ret;
};
BindCallArgsWithParameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get());
BindCallArgsWithParameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get());
BindCallArgsWithParameter(child_graphs[0]->inputs(), get_partial_args(2), graph, child_graphs[0], memo);
BindCallArgsWithParameter(child_graphs[1]->inputs(), get_partial_args(3), graph, child_graphs[1], memo);
}
}
}
@ -306,8 +314,6 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
LinkChildGraphs(NOT_NULL(root_graph));
// resource initialize
InitRuntimeResource();
// assign label
AssignLabel(NOT_NULL(root_graph));
// recurse compile child root_graph
std::set<KernelGraphPtr> memo;
RecurseCompileGraph(NOT_NULL(root_graph), NOT_NULL(&memo));
@ -665,12 +671,6 @@ void AscendSession::AssignStream(NotNull<KernelGraphPtr> kernel_graph) const {
MS_LOG(INFO) << "Finish!";
}
void AscendSession::AssignLabel(NotNull<KernelGraphPtr> kernel_graph) const {
MS_LOG(INFO) << "Start!";
device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kernel_graph);
MS_LOG(INFO) << "Finish!";
}
void AscendSession::BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
MS_LOG(INFO) << "Start!";
struct timeval start_time, end_time;
@ -1582,14 +1582,17 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt
auto input = cnode->inputs()[input_idx];
MS_EXCEPTION_IF_NULL(input);
AnfNodePtr new_parameter = nullptr;
// Check whether the input has already been put into the args of the call. If one parameter or cnode is used multiple
// times, only set one parameter in the graph inputs and one arg in the call node.
auto call_input_it = std::find(call_node_inputs.begin(), call_node_inputs.end(), input);
if (call_input_it != call_node_inputs.end()) {
cnode->set_input(input_idx, new_graph_inputs[std::distance(call_node_inputs.begin(), call_input_it)]);
continue;
}
// value node consider move to new graph
if (input->isa<ValueNode>()) {
cnode->set_input(input_idx, input);
continue;
} else if (input->isa<Parameter>()) {
// Parameters are reused; pay attention to multiple uses of one parameter.
cnode->set_input(input_idx, input);
new_parameter = input;
} else if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) {
// if is cnode and not in current child graph
new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get());
@ -1598,12 +1601,8 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt
// if is a cnode and in current graph
continue;
}
// if mulptiple use of one parameter or cnode, only set one parameter in graph inputs and one arg in call node
// args
if (std::find(call_node_inputs.begin(), call_node_inputs.end(), new_parameter) == call_node_inputs.end()) {
new_graph_inputs.push_back(new_parameter);
call_node_inputs.push_back(input);
}
new_graph_inputs.push_back(new_parameter);
call_node_inputs.push_back(input);
}
}
// set graph inputs of new graph
@ -1631,7 +1630,7 @@ void AscendSession::SplitGraphs(NotNull<KernelGraphPtr> root_graph) {
// if root graph output is a call node ,the root graph is condition graph of 'if' sentence
auto root_graph_output = AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0).first;
if (AnfAlgo::CheckPrimitiveType(root_graph_output, prim::kPrimCall)) {
SplitGraph(root_graph, {prim::kPrimReturn});
SplitGraph(root_graph, {prim::kPrimReturn}, NOT_NULL(&memo));
for (auto &child_graph : root_graph->child_graph_order()) {
RecurseSplitGraph(NOT_NULL(child_graph), NOT_NULL(&memo));
}
@ -1672,7 +1671,8 @@ AnfNodePtr AscendSession::BindNewCallToNewGraph(NotNull<KernelGraphPtr> graph,
return new_call;
}
void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims) {
void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims,
const NotNull<std::set<KernelGraphPtr> *> memo) {
MS_LOG(INFO) << "Start,graph_id:" << graph->graph_id();
bool split_flag = false;
auto apply_list = GetCNodes(TopoSort(graph->get_return()));
@ -1710,14 +1710,13 @@ void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<Pri
split_flag = true;
}
AscendControlParser::UpdateChildGraphOrder(graph);
UpdateRealInput(graph, split_flag);
UpdateRealInput(graph, split_flag, memo);
MS_LOG(INFO) << "Split graph[" << graph->graph_id() << "] end";
// recurse to split child graph
}
void AscendSession::RecurseSplitGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo) {
memo->insert(graph.get());
SplitGraph(graph, {prim::kPrimCall});
SplitGraph(graph, {prim::kPrimCall}, memo);
for (auto &child_graph : graph->child_graph_order()) {
if (memo->find(child_graph) == memo->end()) {
RecurseSplitGraph(NOT_NULL(child_graph), memo);

@ -77,7 +77,6 @@ class AscendSession : public SessionBasic {
void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void AssignStream(NotNull<KernelGraphPtr> kernel_graph) const;
void AssignLabel(NotNull<KernelGraphPtr> kernel_graph) const;
void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void MemoryAlloc(KernelGraph *kernel_graph) const;
void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const;
@ -100,7 +99,8 @@ class AscendSession : public SessionBasic {
void SetFinalGraphOutput(const ValuePtr &value);
void SetFinalGraphOutput(const VectorRef &vec_output);
void SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims);
void SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims,
const NotNull<std::set<KernelGraphPtr> *> memo);
// split graphs with recurse from root graph
void SplitGraphs(NotNull<KernelGraphPtr> root_graph);
void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs);

@ -103,6 +103,23 @@ AnfNodePtr MakeValueNode(const AnfNodePtr &node) {
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
return new_value_node;
}
// Decide whether two CNodes denote the same label node.
// Identical pointers always match; otherwise both must be non-null, share the
// same primitive, and both carry equal kAttrLabelIndex attributes.
bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) {
  if (left == right) {
    return true;
  }
  if (left == nullptr || right == nullptr) {
    return false;
  }
  if (!IsPrimitiveCNode(left, GetCNodePrimitive(right))) {
    return false;
  }
  // Without a label index on both sides there is nothing comparable.
  if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, left) || !AnfAlgo::HasNodeAttr(kAttrLabelIndex, right)) {
    return false;
  }
  return AnfAlgo::GetNodeAttr<uint32_t>(left, kAttrLabelIndex) ==
         AnfAlgo::GetNodeAttr<uint32_t>(right, kAttrLabelIndex);
}
} // namespace
std::vector<AnfNodePtr> KernelGraph::outputs() const {
auto graph_output = output();
@ -219,6 +236,19 @@ void KernelGraph::SetExecOrderByDefault() {
if (node == start_label_ || node == end_goto_) {
continue;
}
if (IsSameLabel(node, end_goto_)) {
end_goto_ = node;
MS_LOG(INFO) << "Replace end_goto_ in kernel graph:" << graph_id();
continue;
}
if (IsSameLabel(node, start_label_)) {
start_label_ = node;
MS_LOG(INFO) << "Replace start_label_ in kernel graph:" << graph_id();
continue;
}
re_order.push_back(node);
}
if (end_goto_ != nullptr) {
@ -751,10 +781,9 @@ void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodeP
}
// update front to backend map
FrontBackendlMapUpdate(old_anf_node, new_anf_node);
// update output depend relations
node_output_edges_[new_anf_node.get()] = it->second;
(void)node_output_edges_.erase(old_anf_node);
}
// if change the ir of graph, regenerate execution order of graph
SetExecOrderByDefault();
// update graph inputs in child graph
auto it_real_inputs = std::find_if(real_inputs_.begin(), real_inputs_.end(),
[&old_anf_node](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool {
@ -770,7 +799,7 @@ void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodeP
return n.first == new_anf_node.get();
});
if (iter != real_inputs_.end()) {
MS_LOG(WARNING) << new_anf_node->DebugString() << " already exist in real inputs, will be rewrited.";
MS_LOG(WARNING) << new_anf_node->DebugString() << " Already exist in real inputs, will be rewrited.";
iter->second = old_args;
} else {
real_inputs_.emplace_back(new_anf_node, old_args);
@ -827,6 +856,10 @@ void KernelGraph::SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &ar
}
}
// Mark `arg` as not reusable across graphs, remembering which graph it came
// from. A later registration for the same arg overwrites the earlier graph.
void KernelGraph::AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr<KernelGraph> &from_graph) {
  (void)unreuse_args_.insert_or_assign(arg, from_graph);
}
void KernelGraph::UpdateCallRealInput() {
MS_LOG(INFO) << "Update graph id: " << graph_id_;
std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> real_inputs_map;
@ -839,6 +872,17 @@ void KernelGraph::UpdateCallRealInput() {
// if real input is a call node ,find the child graph output act as the new real input
auto tmp_real_input = GetCallRealOutputs(real_input);
std::copy(tmp_real_input.begin(), tmp_real_input.end(), std::back_inserter(new_real_inputs));
// replace the call in unreuse_args_
auto unreuse_arg_it = unreuse_args_.find(real_input);
if (unreuse_arg_it != unreuse_args_.end()) {
auto old_graph = unreuse_arg_it->second;
for (auto new_real_input : new_real_inputs) {
// if call reference graph output is parameter, it will be allowed to reuse
if (!new_real_input->isa<Parameter>()) {
unreuse_args_[new_real_input] = old_graph;
}
}
}
}
real_inputs_map.emplace_back(parameter, new_real_inputs);
}

@ -130,6 +130,9 @@ class KernelGraph : public FuncGraph {
// get real inputs
const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs() const { return real_inputs_; }
void SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg);
// mark unreused args
void AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr<KernelGraph> &from_graph);
const std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> &unreuse_args() const { return unreuse_args_; }
// used to dump ir
std::string ToString() const override;
// update the real input if the node is a call
@ -205,6 +208,7 @@ class KernelGraph : public FuncGraph {
std::shared_ptr<KernelGraph> parent_graph_;
// record real parameters,inputs_ is the formal parameters
std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> real_inputs_;
std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> unreuse_args_;
CNodePtr start_label_;
CNodePtr end_goto_;

@ -99,6 +99,19 @@ class ControlIfbyIfbyIf(nn.Cell):
return out
class ControlSimpleWhile(nn.Cell):
    """Network with a single data-dependent while loop.

    Each iteration overwrites ``out`` with ``AddN([input_data] * 3)`` and then
    replaces the loop condition ``x`` with ``y``, so the loop runs until the
    current value of ``x`` is falsy.
    """

    def __init__(self):
        super().__init__()
        # AddN sums all tensors in its input tuple element-wise.
        self.addn = op.AddN()

    def construct(self, x, y, input_data):
        out = input_data
        while x:
            out = self.addn([input_data, input_data, input_data])
            # The condition becomes y after the first iteration.
            x = y
        return out
class ControlMixedWhileIf(nn.Cell):
def __init__(self):
super().__init__()
@ -204,6 +217,22 @@ def test_if_by_if_by_if():
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_simple_while():
    """Run ControlSimpleWhile with x=True, y=False on Ascend.

    The while body executes exactly once (x becomes y == False), so the
    output is AddN of three copies of the input, i.e. input_data * 3.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    # np.bool was deprecated in NumPy 1.20 and removed in 1.24; np.bool_ is
    # the supported boolean scalar type and behaves identically here.
    x = np.array(True).astype(np.bool_)
    y = np.array(False).astype(np.bool_)
    input_shape = (127, 7, 53, 31)
    input_data = np.random.randn(*input_shape).astype(np.float32)
    net = ControlSimpleWhile()
    output = net(Tensor(x), Tensor(y), Tensor(input_data))
    expect = input_data * 3
    assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training

Loading…
Cancel
Save