|
|
|
@ -564,42 +564,67 @@ AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodeP
|
|
|
|
|
return create_parameter_from_cnode(output_item_with_index.first, output_item_with_index.second);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AscendSession::SetFinalGraphOutput(const BaseRef &output) {
|
|
|
|
|
auto final_graph = GetGraph(final_graph_id_);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(final_graph);
|
|
|
|
|
if (!utils::isa<AnfNodePtr>(output)) {
|
|
|
|
|
if (!utils::isa<ValuePtr>(output)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString();
|
|
|
|
|
}
|
|
|
|
|
auto value_ptr = utils::cast<ValuePtr>(output);
|
|
|
|
|
auto value_node = NewValueNode(value_ptr);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(value_node);
|
|
|
|
|
auto kernel_info = std::make_shared<device::KernelInfo>();
|
|
|
|
|
value_node->set_kernel_info(kernel_info);
|
|
|
|
|
value_node->set_abstract(abstract::FromValue(value_ptr));
|
|
|
|
|
final_graph->set_output(final_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple), value_node}));
|
|
|
|
|
final_graph->set_executable(false);
|
|
|
|
|
MS_LOG(INFO) << "Not anf output[" << output.ToString() << "]";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
void AscendSession::SetFinalGraphOutput(const AnfNodePtr &node) {
|
|
|
|
|
// get the backend anf node related to the output node of front
|
|
|
|
|
auto output_anf_node = utils::cast<AnfNodePtr>(output);
|
|
|
|
|
auto output_from_graph_id = GetGraphIdByNode(output_anf_node);
|
|
|
|
|
auto output_from_graph_id = GetGraphIdByNode(node);
|
|
|
|
|
auto output_from_graph = GetGraph(output_from_graph_id);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(output_anf_node);
|
|
|
|
|
MS_LOG(INFO) << "Set the output[" << output_anf_node->DebugString() << "] of graph[" << output_from_graph_id
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
MS_LOG(INFO) << "Set the output[" << node->DebugString() << "] of graph[" << output_from_graph_id
|
|
|
|
|
<< "] to final graph";
|
|
|
|
|
MS_EXCEPTION_IF_NULL(output_from_graph);
|
|
|
|
|
auto final_graph = GetGraph(final_graph_id_);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(final_graph);
|
|
|
|
|
// if output is from final graph,it remarks no child graph exist
|
|
|
|
|
if (final_graph_id_ == output_from_graph_id) {
|
|
|
|
|
MS_LOG(INFO) << "No child graph,output is " << output_anf_node->DebugString();
|
|
|
|
|
final_graph->set_output(ConstructOutput({output_anf_node}, final_graph));
|
|
|
|
|
MS_LOG(INFO) << "No child graph,output is " << node->DebugString();
|
|
|
|
|
final_graph->set_output(ConstructOutput({node}, final_graph));
|
|
|
|
|
final_graph->set_executable(false);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
final_graph->set_output(output_from_graph->output());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AscendSession::SetFinalGraphOutput(const ValuePtr &value) {
|
|
|
|
|
auto value_node = NewValueNode(value);
|
|
|
|
|
auto kernel_info = std::make_shared<device::KernelInfo>();
|
|
|
|
|
value_node->set_kernel_info(kernel_info);
|
|
|
|
|
value_node->set_abstract(abstract::FromValue(value));
|
|
|
|
|
auto final_graph = GetGraph(final_graph_id_);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(final_graph);
|
|
|
|
|
final_graph->set_output(final_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple), value_node}));
|
|
|
|
|
final_graph->set_executable(false);
|
|
|
|
|
MS_LOG(INFO) << "Not anf output[" << value->ToString() << "]";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AscendSession::SetFinalGraphOutput(const VectorRef &vec_output) {
|
|
|
|
|
for (auto &output : vec_output) {
|
|
|
|
|
if (utils::isa<AnfNodePtr>(output)) {
|
|
|
|
|
auto output_anf_node = utils::cast<AnfNodePtr>(output);
|
|
|
|
|
SetFinalGraphOutput(output_anf_node);
|
|
|
|
|
} else if (utils::isa<ValuePtr>(output)) {
|
|
|
|
|
auto value = utils::cast<ValuePtr>(output);
|
|
|
|
|
SetFinalGraphOutput(value);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AscendSession::SetFinalGraphOutput(const BaseRef &output) {
|
|
|
|
|
if (utils::isa<AnfNodePtr>(output)) {
|
|
|
|
|
auto output_anf_node = utils::cast<AnfNodePtr>(output);
|
|
|
|
|
SetFinalGraphOutput(output_anf_node);
|
|
|
|
|
} else if (utils::isa<ValuePtr>(output)) {
|
|
|
|
|
auto value = utils::cast<ValuePtr>(output);
|
|
|
|
|
SetFinalGraphOutput(value);
|
|
|
|
|
} else if (utils::isa<VectorRef>(output)) {
|
|
|
|
|
auto vec_output = utils::cast<VectorRef>(output);
|
|
|
|
|
SetFinalGraphOutput(vec_output);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
KernelGraphPtr AscendSession::GetGraph(mindspore::GraphId graph_id) {
|
|
|
|
|
auto it = graphs_.find(graph_id);
|
|
|
|
|
if (it == graphs_.end()) {
|
|
|
|
|