|
|
|
@ -288,6 +288,22 @@ bool ExistSummaryNode(const KernelGraph *graph) {
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IgnoreCreateParameterForMakeTuple(const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
const auto &node_inputs = cnode->inputs();
|
|
|
|
|
for (size_t i = 1; i < node_inputs.size(); ++i) {
|
|
|
|
|
if (!AnfAlgo::CheckPrimitiveType(node_inputs[i], prim::kPrimControlDepend)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
GraphId SessionBasic::graph_sum_ = 0;
|
|
|
|
@ -354,8 +370,11 @@ std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
std::vector<AnfNodePtr> parameters;
|
|
|
|
|
std::vector<AnfNodePtr> pre_graph_out = {node};
|
|
|
|
|
if (IgnoreCreateParameterForMakeTuple(node)) {
|
|
|
|
|
pre_graph_out.clear();
|
|
|
|
|
}
|
|
|
|
|
// If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive
|
|
|
|
|
if (!AnfAlgo::IsRealKernel(node)) {
|
|
|
|
|
if (!pre_graph_out.empty() && !AnfAlgo::IsRealKernel(node)) {
|
|
|
|
|
pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem});
|
|
|
|
|
}
|
|
|
|
|
auto valid_inputs = graph->MutableValidInputs();
|
|
|
|
@ -431,7 +450,8 @@ AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool
|
|
|
|
|
MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
|
|
|
|
|
auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
|
|
|
|
|
if (parameters.empty()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "No parameter exist!!";
|
|
|
|
|
MS_LOG(INFO) << "Empty parameter from cnode";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
if (parameters.size() == 1) {
|
|
|
|
|
return parameters[0];
|
|
|
|
@ -505,11 +525,14 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
|
|
|
|
|
cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]);
|
|
|
|
|
continue;
|
|
|
|
|
} else if (optimize_control_depend) {
|
|
|
|
|
cnode_inputs.push_back(NewValueNode(MakeValue(input_idx)));
|
|
|
|
|
cnode_inputs.push_back(NewValueNode(MakeValue(SizeToInt(input_idx))));
|
|
|
|
|
} else {
|
|
|
|
|
*from_other_graph = true;
|
|
|
|
|
// the input node is a cnode from other graph
|
|
|
|
|
auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph);
|
|
|
|
|
if (parameter_from_cnode == nullptr) {
|
|
|
|
|
parameter_from_cnode = NewValueNode(MakeValue(SizeToInt(input_idx)));
|
|
|
|
|
}
|
|
|
|
|
cnode_inputs.push_back(parameter_from_cnode);
|
|
|
|
|
(*other_graph_cnode)[anf] = parameter_from_cnode;
|
|
|
|
|
}
|
|
|
|
@ -878,7 +901,8 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
|
|
|
|
|
auto tensor = inputs[i];
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tensor);
|
|
|
|
|
auto input_node = input_nodes[i];
|
|
|
|
|
if (TensorNeedSync(input_node, tensor) && input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_node);
|
|
|
|
|
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
|
|
|
|
|
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
|
|
|
|
|
if (ms_context->execution_mode() == kPynativeMode ||
|
|
|
|
|
AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
|
|
|
|
|