optimize internaloutput for const-input-to-tensor pass

pull/8136/head
kswang 4 years ago
parent f020543f02
commit a98f871fe4

@ -92,6 +92,9 @@ AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePt
new_cnode->set_abstract(cnode->abstract()); new_cnode->set_abstract(cnode->abstract());
new_cnode->set_scope(cnode->scope()); new_cnode->set_scope(cnode->scope());
AnfAlgo::CopyNodeAttrs(cnode, new_cnode); AnfAlgo::CopyNodeAttrs(cnode, new_cnode);
if (kernel_graph != nullptr) {
kernel_graph->FrontBackendlMapUpdate(cnode, new_cnode);
}
return new_cnode; return new_cnode;
} }
return nullptr; return nullptr;

@ -474,9 +474,6 @@ CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
FrontBackendlMapUpdate(cnode, new_cnode); FrontBackendlMapUpdate(cnode, new_cnode);
} }
AnfAlgo::SetGraphId(graph_id_, cnode.get()); AnfAlgo::SetGraphId(graph_id_, cnode.get());
if (IsInternalOutput(cnode)) {
ReplaceInternalOutput(cnode, new_cnode);
}
return new_cnode; return new_cnode;
} }
@ -656,6 +653,9 @@ void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, cons
if (front_backend_anf_map_.find(backend_front_anf_map_[old_backend_anf]) == front_backend_anf_map_.end()) { if (front_backend_anf_map_.find(backend_front_anf_map_[old_backend_anf]) == front_backend_anf_map_.end()) {
MS_LOG(EXCEPTION) << "Anf is not exist in the map ,old " << old_backend_anf->DebugString(); MS_LOG(EXCEPTION) << "Anf is not exist in the map ,old " << old_backend_anf->DebugString();
} }
if (IsInternalOutput(old_backend_anf)) {
ReplaceInternalOutput(old_backend_anf, new_backend_anf);
}
front_backend_anf_map_[backend_front_anf_map_[old_backend_anf]] = new_backend_anf; front_backend_anf_map_[backend_front_anf_map_[old_backend_anf]] = new_backend_anf;
backend_front_anf_map_[new_backend_anf] = backend_front_anf_map_[old_backend_anf]; backend_front_anf_map_[new_backend_anf] = backend_front_anf_map_[old_backend_anf];
// delete old kernel // delete old kernel

Loading…
Cancel
Save