|
|
|
@ -494,54 +494,52 @@ AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, Kern
|
|
|
|
|
return make_tuple;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
|
|
|
|
|
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
|
|
|
|
|
void SessionBasic::GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *cnode_inputs) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(other_graph_cnode);
|
|
|
|
|
// get primitive of old node
|
|
|
|
|
std::vector<AnfNodePtr> cnode_inputs;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode_inputs);
|
|
|
|
|
auto prim = AnfAlgo::GetCNodePrimitive(cnode);
|
|
|
|
|
if (prim != nullptr) {
|
|
|
|
|
// push attr to inputs[0] of new cnode
|
|
|
|
|
cnode_inputs.push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim)));
|
|
|
|
|
cnode_inputs->push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim)));
|
|
|
|
|
} else {
|
|
|
|
|
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(fg);
|
|
|
|
|
auto new_fg = BasicClone(fg);
|
|
|
|
|
cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg));
|
|
|
|
|
cnode_inputs->push_back(std::make_shared<ValueNode>(new_fg));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs,
|
|
|
|
|
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(other_graph_cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode_inputs);
|
|
|
|
|
auto origin_inputs = cnode->inputs();
|
|
|
|
|
bool optimize_depend = false;
|
|
|
|
|
bool optimize_control_depend = false;
|
|
|
|
|
if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 &&
|
|
|
|
|
origin_inputs[kRealInputIndexInDepend]->isa<ValueNode>()) {
|
|
|
|
|
optimize_depend = true;
|
|
|
|
|
}
|
|
|
|
|
if (IsPrimitiveCNode(cnode, prim::kPrimControlDepend) && origin_inputs.size() == 3) {
|
|
|
|
|
optimize_control_depend = true;
|
|
|
|
|
}
|
|
|
|
|
bool optimize_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 &&
|
|
|
|
|
origin_inputs[kRealInputIndexInDepend]->isa<ValueNode>();
|
|
|
|
|
bool optimize_control_depend = IsPrimitiveCNode(cnode, prim::kPrimControlDepend) && origin_inputs.size() == 3;
|
|
|
|
|
// if has multiple depends,only select first depend as parameter
|
|
|
|
|
for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
|
|
|
|
|
auto anf = origin_inputs[input_idx];
|
|
|
|
|
MS_EXCEPTION_IF_NULL(anf);
|
|
|
|
|
// anf has been created before
|
|
|
|
|
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
|
|
|
|
|
cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
|
|
|
|
|
cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf));
|
|
|
|
|
continue;
|
|
|
|
|
} else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
|
|
|
|
|
cnode_inputs.push_back((*other_graph_cnode)[anf]);
|
|
|
|
|
cnode_inputs->push_back((*other_graph_cnode)[anf]);
|
|
|
|
|
continue;
|
|
|
|
|
} else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) {
|
|
|
|
|
// if input is a value node,
|
|
|
|
|
auto new_value_node = CreateNewValueNode(anf, graph);
|
|
|
|
|
if (new_value_node != nullptr) {
|
|
|
|
|
cnode_inputs.emplace_back(new_value_node);
|
|
|
|
|
cnode_inputs->emplace_back(new_value_node);
|
|
|
|
|
}
|
|
|
|
|
continue;
|
|
|
|
|
} else if (anf->isa<Parameter>()) {
|
|
|
|
|
auto new_parameter = CreateNewParameterFromParameter(anf, graph);
|
|
|
|
|
cnode_inputs.push_back(new_parameter);
|
|
|
|
|
cnode_inputs->push_back(new_parameter);
|
|
|
|
|
if (GetGraphIdByNode(anf) == kInvalidGraphId) {
|
|
|
|
|
graph->FrontBackendlMapAdd(anf, new_parameter);
|
|
|
|
|
} else {
|
|
|
|
@ -549,20 +547,31 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
|
|
|
|
|
}
|
|
|
|
|
continue;
|
|
|
|
|
} else if (optimize_depend && input_idx == kDependAttachNodeIndex) {
|
|
|
|
|
cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]);
|
|
|
|
|
cnode_inputs->push_back(origin_inputs[kRealInputIndexInDepend]);
|
|
|
|
|
continue;
|
|
|
|
|
} else if (optimize_control_depend) {
|
|
|
|
|
cnode_inputs.push_back(NewValueNode(MakeValue(SizeToInt(input_idx))));
|
|
|
|
|
cnode_inputs->push_back(NewValueNode(MakeValue(SizeToInt(input_idx))));
|
|
|
|
|
} else {
|
|
|
|
|
// the input node is a cnode from other graph
|
|
|
|
|
auto parameter_from_cnode = CreateNewParameterFromCNode(anf, graph);
|
|
|
|
|
if (parameter_from_cnode == nullptr) {
|
|
|
|
|
parameter_from_cnode = NewValueNode(MakeValue(SizeToInt(input_idx)));
|
|
|
|
|
}
|
|
|
|
|
cnode_inputs.push_back(parameter_from_cnode);
|
|
|
|
|
cnode_inputs->push_back(parameter_from_cnode);
|
|
|
|
|
(*other_graph_cnode)[anf] = parameter_from_cnode;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
|
|
|
|
|
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(other_graph_cnode);
|
|
|
|
|
// get primitive of old node
|
|
|
|
|
std::vector<AnfNodePtr> cnode_inputs;
|
|
|
|
|
GetCNodeInfo(cnode, &cnode_inputs);
|
|
|
|
|
GetNewCNodeInputs(cnode, graph, &cnode_inputs, other_graph_cnode);
|
|
|
|
|
TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
|
|
|
|
|
auto new_cnode = graph->NewCNode(cnode_inputs);
|
|
|
|
|
TraceManager::EndTrace();
|
|
|
|
@ -593,6 +602,42 @@ CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGra
|
|
|
|
|
return partial_node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cnode, KernelGraph *graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
std::vector<AnfNodePtr> cnode_inputs = {
|
|
|
|
|
graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
|
|
|
|
|
auto attr_input = cnode->input(kAnfPrimitiveIndex);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(attr_input);
|
|
|
|
|
auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
|
|
|
|
|
auto switch_cnode = cnode_input->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(switch_cnode);
|
|
|
|
|
if (cnode->inputs().size() < 2) {
|
|
|
|
|
cnode_inputs = switch_cnode->inputs();
|
|
|
|
|
return cnode_inputs;
|
|
|
|
|
}
|
|
|
|
|
std::vector<AnfNodePtr> switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex),
|
|
|
|
|
switch_cnode->input(kFirstDataInputIndex)};
|
|
|
|
|
for (size_t index = kFirstBranchInSwitch; index < switch_cnode->inputs().size(); index++) {
|
|
|
|
|
auto node = switch_cnode->input(index);
|
|
|
|
|
// there is real input in call, should put it to true and false branch in switch
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
|
|
|
|
|
auto partial_node = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(partial_node);
|
|
|
|
|
std::vector<AnfNodePtr> partial_inputs = partial_node->inputs();
|
|
|
|
|
partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex)));
|
|
|
|
|
auto new_partial = graph->NewCNode(partial_inputs);
|
|
|
|
|
switch_inputs.emplace_back(new_partial);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (switch_inputs.size() < kSwitchInputSize) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Switch inputs size: " << switch_inputs.size() << "less than " << kSwitchInputSize;
|
|
|
|
|
}
|
|
|
|
|
auto switch_node = graph->NewCNode(switch_inputs);
|
|
|
|
|
cnode_inputs.emplace_back(switch_node);
|
|
|
|
|
return cnode_inputs;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
@ -618,32 +663,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &
|
|
|
|
|
});
|
|
|
|
|
return cnode_inputs;
|
|
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
|
|
|
|
|
auto switch_cnode = cnode_input->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(switch_cnode);
|
|
|
|
|
if (cnode->inputs().size() < 2) {
|
|
|
|
|
cnode_inputs = switch_cnode->inputs();
|
|
|
|
|
return cnode_inputs;
|
|
|
|
|
}
|
|
|
|
|
std::vector<AnfNodePtr> switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex),
|
|
|
|
|
switch_cnode->input(kFirstDataInputIndex)};
|
|
|
|
|
for (size_t index = kFirstBranchInSwitch; index < switch_cnode->inputs().size(); index++) {
|
|
|
|
|
auto node = switch_cnode->input(index);
|
|
|
|
|
// there is real input in call, should put it to true and false branch in switch
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
|
|
|
|
|
auto partial_node = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(partial_node);
|
|
|
|
|
std::vector<AnfNodePtr> partial_inputs = partial_node->inputs();
|
|
|
|
|
partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex)));
|
|
|
|
|
auto new_partial = graph->NewCNode(partial_inputs);
|
|
|
|
|
switch_inputs.emplace_back(new_partial);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (switch_inputs.size() < kSwitchInputSize) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Switch inputs size: " << switch_inputs.size() << "less than " << kSwitchInputSize;
|
|
|
|
|
}
|
|
|
|
|
auto switch_node = graph->NewCNode(switch_inputs);
|
|
|
|
|
cnode_inputs.emplace_back(switch_node);
|
|
|
|
|
return cnode_inputs;
|
|
|
|
|
return CreateCallSwitchInputs(cnode, graph);
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch.";
|
|
|
|
|
}
|
|
|
|
|