|
|
|
@ -1,5 +1,5 @@
|
|
|
|
|
/**
|
|
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
|
|
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
|
|
|
*
|
|
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
* you may not use this file except in compliance with the License.
|
|
|
|
@ -84,7 +84,7 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<PrimitivePtr> adapter_convert_ops = {prim::kPrimDepend, prim::kPrimControlDepend, prim::kPrimLoad};
|
|
|
|
|
std::vector<PrimitivePtr> adapter_convert_ops = {prim::kPrimDepend, prim::kPrimLoad};
|
|
|
|
|
for (auto &item : adapter_convert_ops) {
|
|
|
|
|
if (IsPrimitiveCNode(node, item)) {
|
|
|
|
|
return true;
|
|
|
|
@ -243,8 +243,7 @@ CNodePtr MergeNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, int64_t sw
|
|
|
|
|
return merge_op;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// construct a depend node with merge output node, merge(square_op(switch(ctrl_data)), switch(opposite_ctrl_data))
|
|
|
|
|
// control_depend(output_node, square_op)
|
|
|
|
|
// merge(square_op(switch(ctrl_data)), switch(opposite_ctrl_data))
|
|
|
|
|
AnfNodePtr GenerateSwitchDependNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &output_node,
|
|
|
|
|
int64_t switch_idx) {
|
|
|
|
|
tensor::TensorPtr const_data = GetConstData();
|
|
|
|
@ -259,54 +258,21 @@ AnfNodePtr GenerateSwitchDependNode(const FuncGraphPtr &graph, const AnfNodePtr
|
|
|
|
|
SetSquareOp(switch_idx, square_op);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto manager = graph->manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
|
|
AnfNodePtrList inputs = {NewValueNode(prim::kPrimDepend), square_op, output_node};
|
|
|
|
|
auto depend_cnode = graph->NewCNode(inputs);
|
|
|
|
|
if (!manager->Replace(square_op, depend_cnode)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << square_op->DebugString() << ", replace node failed.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CNodePtr merge_op = GetMergeOp(switch_idx);
|
|
|
|
|
if (merge_op == nullptr) {
|
|
|
|
|
merge_op = MergeNode(graph, cond, switch_idx, const_data, square_op);
|
|
|
|
|
SetMergeOp(switch_idx, merge_op);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> control_depend_nodes{NewValueNode(prim::kPrimControlDepend), output_node, square_op};
|
|
|
|
|
auto control_depend_op = graph->NewCNode(control_depend_nodes);
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> depend_nodes{NewValueNode(prim::kPrimDepend), merge_op, control_depend_op};
|
|
|
|
|
auto depend_op = graph->NewCNode(depend_nodes);
|
|
|
|
|
|
|
|
|
|
return depend_op;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// construct a merge output and add dependency with the netoutput node from control_depend
|
|
|
|
|
// we need to reserve the control_depend node, besides the generated merge node and control_depend node
|
|
|
|
|
CNodePtr GenerateSwitchControlDependNode(const FuncGraphPtr &graph, const AnfNodePtr &cond,
|
|
|
|
|
const AnfNodePtr &ctrl_dep_node, const AnfNodePtr &ctrl_depend_dst,
|
|
|
|
|
int64_t switch_idx) {
|
|
|
|
|
auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast<PrimitivePtr>();
|
|
|
|
|
auto PrimSquare = prim::GetPythonOps("square", "mindspore.ops.functional")->cast<PrimitivePtr>();
|
|
|
|
|
std::vector<int64_t> shp = {1};
|
|
|
|
|
tensor::TensorPtr const_data = std::make_shared<tensor::Tensor>(kInt64->type_id(), shp);
|
|
|
|
|
auto *val = static_cast<int64_t *>(const_data->data_c());
|
|
|
|
|
*val = 0;
|
|
|
|
|
// for the control_depend netoutput node , add two const data to merge the flow ,one for depended node with same
|
|
|
|
|
// switch the other use the opposite
|
|
|
|
|
auto ctrl_data = NewValueNode(const_data);
|
|
|
|
|
auto oppsite_ctrl_data = NewValueNode(const_data);
|
|
|
|
|
auto ctrl_node = GenerateSwitchNode(graph, cond, ctrl_data, switch_idx);
|
|
|
|
|
auto opposite_ctrl_node = GenerateSwitchNode(graph, cond, oppsite_ctrl_data, 1 - switch_idx);
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> square_nodes{NewValueNode(PrimSquare), ctrl_node};
|
|
|
|
|
auto square_op = graph->NewCNode(square_nodes);
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> merge_nodes;
|
|
|
|
|
merge_nodes.push_back(NewValueNode(PrimMerge));
|
|
|
|
|
std::vector<AnfNodePtr> make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), square_op, opposite_ctrl_node};
|
|
|
|
|
merge_nodes.push_back(graph->NewCNode(make_tuple_nodes));
|
|
|
|
|
auto merge_output = graph->NewCNode(merge_nodes);
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> control_depend_nodes{NewValueNode(prim::kPrimControlDepend), ctrl_depend_dst, square_op};
|
|
|
|
|
auto cond_dep_output = graph->NewCNode(control_depend_nodes);
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> depended_make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), ctrl_dep_node, merge_output,
|
|
|
|
|
cond_dep_output};
|
|
|
|
|
return graph->NewCNode(depended_make_tuple_nodes);
|
|
|
|
|
return merge_op;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// generate switch nodes for true graph node inputs
|
|
|
|
@ -321,26 +287,12 @@ AnfNodePtr GenerateSwitchDependFalseNode(const FuncGraphPtr &graph, const AnfNod
|
|
|
|
|
return GenerateSwitchDependNode(graph, cond, data, 0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// generate switch nodes for true graph node inputs
|
|
|
|
|
CNodePtr GenerateSwitchControlDependTrueNode(const FuncGraphPtr &graph, const AnfNodePtr &cond,
|
|
|
|
|
const AnfNodePtr &con_input, const AnfNodePtr &output) {
|
|
|
|
|
// for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch
|
|
|
|
|
return GenerateSwitchControlDependNode(graph, cond, con_input, output, 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// generate switch nodes for false graph node inputs
|
|
|
|
|
CNodePtr GenerateSwitchControlDependFalseNode(const FuncGraphPtr &graph, const AnfNodePtr &cond,
|
|
|
|
|
const AnfNodePtr &con_input, const AnfNodePtr &output) {
|
|
|
|
|
// for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch
|
|
|
|
|
return GenerateSwitchControlDependNode(graph, cond, con_input, output, 0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// to judge if the node used in ControlDepend is a net output node
|
|
|
|
|
// to judge if the node used in Depend is a net output node
|
|
|
|
|
bool IsNetOutputNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) {
|
|
|
|
|
auto uses = manager->node_users()[node];
|
|
|
|
|
bool is_output_node = true;
|
|
|
|
|
for (auto &item : uses) {
|
|
|
|
|
if (IsPrimitiveCNode(item.first, prim::kPrimControlDepend) || IsPrimitiveCNode(item.first, prim::kPrimDepend)) {
|
|
|
|
|
if (IsPrimitiveCNode(item.first, prim::kPrimDepend)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
is_output_node = false;
|
|
|
|
@ -353,8 +305,7 @@ bool IsNetOutputNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node)
|
|
|
|
|
void GenerateReplNodeForDependMakeTuple(
|
|
|
|
|
const AnfNodePtr &depended_node, const FuncGraphPtr &graph, const AnfNodePtr &cond,
|
|
|
|
|
const std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> &repl_node,
|
|
|
|
|
const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &generate_func,
|
|
|
|
|
const std::function<CNodePtr(FuncGraphPtr, AnfNodePtr, AnfNodePtr, AnfNodePtr)> &gen_ctl_depd_func) {
|
|
|
|
|
const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &generate_func) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph->manager());
|
|
|
|
|
|
|
|
|
|
auto make_tuple_inputs = depended_node->cast<CNodePtr>()->inputs();
|
|
|
|
@ -368,26 +319,6 @@ void GenerateReplNodeForDependMakeTuple(
|
|
|
|
|
new_make_tuple_nodes.push_back(depended_tuple_input_node);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (IsPrimitiveCNode(depended_tuple_input_node->cast<CNodePtr>(), prim::kPrimControlDepend)) {
|
|
|
|
|
// only when the control depend input is not square op (the op to use as merge output)
|
|
|
|
|
auto control_inputs = depended_tuple_input_node->cast<CNodePtr>()->inputs();
|
|
|
|
|
if (control_inputs.size() != 3) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "controldepend input size != 3, got " << control_inputs.size();
|
|
|
|
|
}
|
|
|
|
|
// control inputs: primitive, src, dst
|
|
|
|
|
auto dst_node = control_inputs[2];
|
|
|
|
|
if (!IsPrimitiveCNode(dst_node, prim::kPrimSquare) && IsNetOutputNode(graph->manager(), dst_node)) {
|
|
|
|
|
auto gen_node = gen_ctl_depd_func(graph, cond, make_tuple_inputs[idx], dst_node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(gen_node);
|
|
|
|
|
auto tuple_inputs = gen_node->inputs();
|
|
|
|
|
// add depended tuple inputs to new_make_tuple directly
|
|
|
|
|
for (size_t i = 1; i < tuple_inputs.size(); i++) {
|
|
|
|
|
new_make_tuple_nodes.push_back(tuple_inputs[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
replace_make_tuple = true;
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (graph->manager()->node_users()[depended_tuple_input_node].size() == 1) {
|
|
|
|
|
auto gen_node = generate_func(graph, cond, depended_tuple_input_node);
|
|
|
|
@ -408,8 +339,7 @@ void GenerateReplNodeForDependMakeTuple(
|
|
|
|
|
void GenerateRepDepend(
|
|
|
|
|
const CNodePtr &node, const FuncGraphPtr &graph, const AnfNodePtr &cond,
|
|
|
|
|
const std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> &repl_node,
|
|
|
|
|
const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &generate_func,
|
|
|
|
|
const std::function<CNodePtr(FuncGraphPtr, AnfNodePtr, AnfNodePtr, AnfNodePtr)> &gen_ctl_depd_func) {
|
|
|
|
|
const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &generate_func) {
|
|
|
|
|
auto inputs = node->inputs();
|
|
|
|
|
if (inputs.size() != 3) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Inputs should be [depend, actual_value, depended_node].";
|
|
|
|
@ -422,19 +352,7 @@ void GenerateRepDepend(
|
|
|
|
|
new_depened_inputs.push_back(inputs[1]);
|
|
|
|
|
// depended node should be make_tuple or a single depended node
|
|
|
|
|
if (IsPrimitiveCNode(depended_node, prim::kPrimMakeTuple)) {
|
|
|
|
|
GenerateReplNodeForDependMakeTuple(depended_node, graph, cond, repl_node, generate_func, gen_ctl_depd_func);
|
|
|
|
|
} else if (IsPrimitiveCNode(depended_node, prim::kPrimControlDepend)) {
|
|
|
|
|
// only when the control depend input is not square op (the op to use as merge output)
|
|
|
|
|
auto control_inputs = depended_node->cast<CNodePtr>()->inputs();
|
|
|
|
|
// control inputs: primitive, src, dst
|
|
|
|
|
if (control_inputs.size() != 3) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "controldepend input size != 3, got " << control_inputs.size();
|
|
|
|
|
}
|
|
|
|
|
auto dst_node = control_inputs[2];
|
|
|
|
|
if (!IsPrimitiveCNode(dst_node, prim::kPrimSquare) && IsNetOutputNode(graph->manager(), dst_node)) {
|
|
|
|
|
auto gen_node = gen_ctl_depd_func(graph, cond, depended_node, dst_node);
|
|
|
|
|
(*repl_node)[depended_node] = gen_node;
|
|
|
|
|
}
|
|
|
|
|
GenerateReplNodeForDependMakeTuple(depended_node, graph, cond, repl_node, generate_func);
|
|
|
|
|
} else {
|
|
|
|
|
// Check if there is only single user for depend_node.
|
|
|
|
|
if (graph->manager()->node_users()[depended_node].size() == 1) {
|
|
|
|
@ -448,11 +366,9 @@ void GenerateRepDepend(
|
|
|
|
|
|
|
|
|
|
// generate depend node for netoutput node, to resolve the stream synchronize problem of ge
|
|
|
|
|
// traverse all nodes of depend node, find the graph output node, generate a merge node of (square, const)
|
|
|
|
|
// and add control_depend of graph output node and square node.
|
|
|
|
|
FuncGraphPtr TransformGraphDependNode(
|
|
|
|
|
const FuncGraphPtr &graph, const AnfNodePtr &cond,
|
|
|
|
|
const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &gen_depend_func,
|
|
|
|
|
const std::function<CNodePtr(FuncGraphPtr, AnfNodePtr, AnfNodePtr, AnfNodePtr)> &gen_ctl_depd_func) {
|
|
|
|
|
const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &gen_depend_func) {
|
|
|
|
|
auto manager = graph->manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
|
|
|
|
|
|
@ -478,7 +394,7 @@ FuncGraphPtr TransformGraphDependNode(
|
|
|
|
|
if (IsPrimitiveCNode(depended_node, prim::kPrimDepend)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
GenerateRepDepend(cnode, graph, cond, repl_node, gen_depend_func, gen_ctl_depd_func);
|
|
|
|
|
GenerateRepDepend(cnode, graph, cond, repl_node, gen_depend_func);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
ResetSharedOp();
|
|
|
|
@ -494,12 +410,12 @@ FuncGraphPtr TransformGraphDependNode(
|
|
|
|
|
|
|
|
|
|
// Transform `graph` for the true branch of `cond`.
// Calls TransformGraphCondBranchNodes with GenerateSwitchTrueNode (its result is discarded), then returns the
// result of TransformGraphDependNode.
FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) {
|
|
|
|
|
(void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchTrueNode);
|
|
|
|
|
// NOTE(review): two return statements follow one another, so only this first one can execute. It passes both
// the depend generator and the control-depend generator.
return TransformGraphDependNode(graph, cond, GenerateSwitchDependTrueNode, GenerateSwitchControlDependTrueNode);
|
|
|
|
|
// NOTE(review): unreachable as written. This looks like the old and new sides of a diff interleaved; confirm
// which TransformGraphDependNode signature is the intended one and keep only the matching call.
return TransformGraphDependNode(graph, cond, GenerateSwitchDependTrueNode);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Transform `graph` for the false branch of `cond`.
// Calls TransformGraphCondBranchNodes with GenerateSwitchFalseNode (its result is discarded), then returns the
// result of TransformGraphDependNode.
FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) {
|
|
|
|
|
(void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchFalseNode);
|
|
|
|
|
// NOTE(review): two return statements follow one another, so only this first one can execute. It passes both
// the depend generator and the control-depend generator.
return TransformGraphDependNode(graph, cond, GenerateSwitchDependFalseNode, GenerateSwitchControlDependFalseNode);
|
|
|
|
|
// NOTE(review): unreachable as written. This looks like the old and new sides of a diff interleaved; confirm
// which TransformGraphDependNode signature is the intended one and keep only the matching call.
return TransformGraphDependNode(graph, cond, GenerateSwitchDependFalseNode);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// judge if the true and false graph output is compatible(they shall have same tuple size)
|
|
|
|
|