From 93cd019855a73e49891ae295f06bb70b20803712 Mon Sep 17 00:00:00 2001 From: es chow Date: Mon, 12 Oct 2020 10:52:55 +0800 Subject: [PATCH] add env_item cache for fv_node --- .../ccsrc/frontend/optimizer/ad/dfunctor.cc | 23 +++++++++++++------ .../ccsrc/frontend/optimizer/ad/dfunctor.h | 4 ++++ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc index b38728805f..811252c55b 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc @@ -16,9 +16,9 @@ #include "frontend/optimizer/ad/dfunctor.h" +#include #include #include -#include #include "ir/anf.h" #include "utils/info.h" @@ -99,14 +99,23 @@ void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) { fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv); } } - auto node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint->second->k()}); - fv_adjoint->second->RegisterKUser(node, 1); - auto default_val = tape_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), fv_adjoint->second->k()}); - fv_adjoint->second->RegisterKUser(default_val, 1); - auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvGetItem), din, node, default_val}); + auto fv_node = fv_adjoint->second->k(); + auto cached_envitem_iter = anfnode_to_envitem_.find(fv_node); + CNodePtr embed_node, default_val_node; + if (cached_envitem_iter != anfnode_to_envitem_.end()) { + embed_node = cached_envitem_iter->second.first; + default_val_node = cached_envitem_iter->second.second; + } else { + embed_node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_node}); + default_val_node = tape_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), fv_node}); + fv_adjoint->second->RegisterKUser(embed_node, 1); + fv_adjoint->second->RegisterKUser(default_val_node, 1); + anfnode_to_envitem_[fv_node] = std::make_pair(embed_node, default_val_node); + } + auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvGetItem), din, embed_node, default_val_node}); MS_LOG(DEBUG) << "BackPropagateFv find adjoint in anfnode_to_adjoin_ or anfnode_to_adjoin_indirect_fv_ fv " << fv->func_graph()->ToString() << " " << fv->ToString() << "."; - MS_LOG(DEBUG) << "BackPropagateFv get item from " << din->ToString() << " key " << node->ToString() << "."; + MS_LOG(DEBUG) << "BackPropagateFv get item from " << din->ToString() << " key " << embed_node->ToString() << "."; fv_adjoint->second->AccumulateDout(dfv); } diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h index 79c2ea8f38..70be856a29 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h @@ -24,6 +24,7 @@ #include #include #include +#include #include "ir/anf.h" #include "ir/meta_func_graph.h" @@ -100,6 +101,9 @@ class DFunctor : public std::enable_shared_from_this { std::unordered_map anfnode_to_adjoin_; // Cache for indirect fv backpropagation, K o K can only do backprop layer by layer. std::unordered_map anfnode_to_adjoin_indirect_fv_; + // Cache for fv node -> pair, zeros_like>, so EnvGetItemTransform in optimizer + // can hit its cache if fv_node is same. + std::unordered_map> anfnode_to_envitem_; FuncGraphPtr primal_graph_; // K object for primal_graph_; FuncGraphPtr k_graph_;