|
|
|
@ -71,6 +71,11 @@ void DFunctor::Init(bool is_top) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DFunctor::Finish() {
|
|
|
|
|
CallDoutHoleOnTape();
|
|
|
|
|
EliminatePrimalGraph();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DFunctor::Clear() {
|
|
|
|
|
func_graph_to_functor_.clear();
|
|
|
|
|
anfnode_to_adjoin_definition_.clear();
|
|
|
|
@ -728,10 +733,7 @@ void DFunctor::CallDoutHoleOnTape() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
FuncGraphPtr DFunctor::k_graph() {
|
|
|
|
|
CallDoutHoleOnTape();
|
|
|
|
|
return k_graph_;
|
|
|
|
|
}
|
|
|
|
|
FuncGraphPtr DFunctor::k_graph() { return k_graph_; }
|
|
|
|
|
|
|
|
|
|
void DFunctor::BroadCastStopFlag() {
|
|
|
|
|
// As stop set expanding, all directly or indirectly stopped CNode will be cut off
|
|
|
|
@ -768,5 +770,28 @@ bool DFunctor::AllReferencesStopped(const CNodePtr &node) {
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// To replace the primal graph with k graph
|
|
|
|
|
void DFunctor::EliminatePrimalGraph() {
|
|
|
|
|
auto k_vnode = NewValueNode(k_graph_);
|
|
|
|
|
auto idx0 = NewValueNode(SizeToInt(0));
|
|
|
|
|
auto imm0 = std::make_shared<Int32Imm>(0);
|
|
|
|
|
idx0->set_abstract(std::make_shared<abstract::AbstractScalar>(imm0));
|
|
|
|
|
auto manager = primal_graph_->manager();
|
|
|
|
|
auto users = primal_graph_->func_graph_cnodes_index();
|
|
|
|
|
for (auto &it : users) {
|
|
|
|
|
auto cnode = it.first->first->cast<CNodePtr>();
|
|
|
|
|
auto index = it.first->second;
|
|
|
|
|
auto vnode = cnode->inputs()[index];
|
|
|
|
|
if (index != 0) {
|
|
|
|
|
MS_LOG(INFO) << "Primal is used but not called, at {" << cnode->DebugString(3) << "/" << index << "}";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
cnode->set_input(0, k_vnode); // Replace primal graph with k graph
|
|
|
|
|
auto construct_wrapper = cnode->func_graph();
|
|
|
|
|
auto getitem0 = construct_wrapper->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx0});
|
|
|
|
|
manager->Replace(cnode, getitem0);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} // namespace ad
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|