!5585 Fix output device address setting for inputs of depend node

Merge pull request !5585 from YuJianfeng/depend_valid_input
pull/5585/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit e17eea3485

@ -35,7 +35,7 @@
namespace mindspore { namespace mindspore {
namespace session { namespace session {
ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf); MS_EXCEPTION_IF_NULL(anf);
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
if (!anf->isa<Parameter>()) { if (!anf->isa<Parameter>()) {
@ -49,7 +49,7 @@ ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf,
ParameterPtr new_parameter = graph->NewParameter(anf->cast<ParameterPtr>()); ParameterPtr new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
TraceManager::EndTrace(); TraceManager::EndTrace();
graph_inputs->push_back(new_parameter); graph_inputs->push_back(new_parameter);
valid_inputs->push_back(valid_input); valid_inputs->push_back(true);
return new_parameter; return new_parameter;
} }

@ -37,7 +37,7 @@ class CPUSession : public SessionBasic {
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) override; std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) override;
protected: protected:
ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) override; ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) override;
void Optimize(const std::shared_ptr<KernelGraph> &kernel_graph); void Optimize(const std::shared_ptr<KernelGraph> &kernel_graph);
private: private:

@ -395,8 +395,7 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
} }
} }
std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) {
KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> parameters; std::vector<AnfNodePtr> parameters;
@ -418,7 +417,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr
parameter->set_abstract(abstract); parameter->set_abstract(abstract);
auto new_parameter = graph->NewParameter(parameter); auto new_parameter = graph->NewParameter(parameter);
parameters.push_back(new_parameter); parameters.push_back(new_parameter);
valid_inputs->push_back(valid_input); valid_inputs->push_back(true);
graph_inputs->push_back(new_parameter); graph_inputs->push_back(new_parameter);
}; };
for (const auto &out_node : pre_graph_out) { for (const auto &out_node : pre_graph_out) {
@ -442,8 +441,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr
return parameters; return parameters;
} }
ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) {
KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf); MS_EXCEPTION_IF_NULL(anf);
if (!anf->isa<Parameter>()) { if (!anf->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter"; MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
@ -471,15 +469,15 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
TraceManager::EndTrace(); TraceManager::EndTrace();
} }
graph_inputs->push_back(new_parameter); graph_inputs->push_back(new_parameter);
valid_inputs->push_back(valid_input); valid_inputs->push_back(true);
return new_parameter; return new_parameter;
} }
AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf); MS_EXCEPTION_IF_NULL(anf);
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]"; MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
auto parameters = CreateParameterFromTuple(anf, valid_input, graph); auto parameters = CreateParameterFromTuple(anf, graph);
if (parameters.empty()) { if (parameters.empty()) {
MS_LOG(INFO) << "Empty parameter from cnode"; MS_LOG(INFO) << "Empty parameter from cnode";
return nullptr; return nullptr;
@ -495,14 +493,11 @@ AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool
return make_tuple; return make_tuple;
} }
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
bool *from_other_graph,
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) { std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(from_other_graph);
MS_EXCEPTION_IF_NULL(other_graph_cnode); MS_EXCEPTION_IF_NULL(other_graph_cnode);
*from_other_graph = false;
// get primitive of old node // get primitive of old node
std::vector<AnfNodePtr> cnode_inputs; std::vector<AnfNodePtr> cnode_inputs;
auto prim = AnfAlgo::GetCNodePrimitive(cnode); auto prim = AnfAlgo::GetCNodePrimitive(cnode);
@ -544,7 +539,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
} }
continue; continue;
} else if (anf->isa<Parameter>()) { } else if (anf->isa<Parameter>()) {
auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph); auto new_parameter = CreateNewParameterFromParameter(anf, graph);
cnode_inputs.push_back(new_parameter); cnode_inputs.push_back(new_parameter);
if (GetGraphIdByNode(anf) == kInvalidGraphId) { if (GetGraphIdByNode(anf) == kInvalidGraphId) {
graph->FrontBackendlMapAdd(anf, new_parameter); graph->FrontBackendlMapAdd(anf, new_parameter);
@ -558,9 +553,8 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
} else if (optimize_control_depend) { } else if (optimize_control_depend) {
cnode_inputs.push_back(NewValueNode(MakeValue(SizeToInt(input_idx)))); cnode_inputs.push_back(NewValueNode(MakeValue(SizeToInt(input_idx))));
} else { } else {
*from_other_graph = true;
// the input node is a cnode from other graph // the input node is a cnode from other graph
auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph); auto parameter_from_cnode = CreateNewParameterFromCNode(anf, graph);
if (parameter_from_cnode == nullptr) { if (parameter_from_cnode == nullptr) {
parameter_from_cnode = NewValueNode(MakeValue(SizeToInt(input_idx))); parameter_from_cnode = NewValueNode(MakeValue(SizeToInt(input_idx)));
} }
@ -587,7 +581,7 @@ CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGra
} else { } else {
KernelGraphPtr kernel_graph = NewKernelGraph(); KernelGraphPtr kernel_graph = NewKernelGraph();
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
auto parameter = CreateNewParameterFromCNode(graph->GetBackendAnfByFrontAnf(node_input), true, kernel_graph.get()); auto parameter = CreateNewParameterFromCNode(graph->GetBackendAnfByFrontAnf(node_input), kernel_graph.get());
auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name())); auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()));
auto return_node = kernel_graph->NewCNode({primitive, parameter}); auto return_node = kernel_graph->NewCNode({primitive, parameter});
kernel_graph->set_return(return_node); kernel_graph->set_return(return_node);
@ -806,7 +800,6 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
auto graph = NewKernelGraph(); auto graph = NewKernelGraph();
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Create graph: " << graph->graph_id(); MS_LOG(INFO) << "Create graph: " << graph->graph_id();
size_t from_other_graph_depend_num = 0;
for (const auto &node : lst) { for (const auto &node : lst) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString(); MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
@ -816,16 +809,7 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
// create a new cnode object // create a new cnode object
bool from_other_graph = false; auto new_cnode = CreateNewCNode(cnode, graph.get(), &other_graph_cnode);
// only first depend from other graph can create
bool valid_input = true;
if (from_other_graph_depend_num != 0 && AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
valid_input = false;
}
auto new_cnode = CreateNewCNode(cnode, valid_input, graph.get(), &from_other_graph, &other_graph_cnode);
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && from_other_graph) {
from_other_graph_depend_num++;
}
MS_EXCEPTION_IF_NULL(new_cnode); MS_EXCEPTION_IF_NULL(new_cnode);
new_cnode->set_abstract(cnode->abstract()); new_cnode->set_abstract(cnode->abstract());
new_cnode->set_scope(cnode->scope()); new_cnode->set_scope(cnode->scope());

@ -100,7 +100,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph, std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph,
std::vector<KernelGraphPtr> *all_out_graph); std::vector<KernelGraphPtr> *all_out_graph);
CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph, CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode); std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
CNodePtr CreateNewCNode(CNodePtr cnode, KernelGraph *graph); CNodePtr CreateNewCNode(CNodePtr cnode, KernelGraph *graph);
@ -153,11 +153,11 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
const std::vector<int> &tensors_mask); const std::vector<int> &tensors_mask);
// create a new kernel graph and update the graph sum // create a new kernel graph and update the graph sum
KernelGraphPtr NewKernelGraph(); KernelGraphPtr NewKernelGraph();
std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph); std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph);
virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph); virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph);
ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph); ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph);
ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph); ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph);
AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph); AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph);
void AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph); void AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph);
void InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter); void InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter);
AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list); AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list);

Loading…
Cancel
Save