!7200 Cache embed and zeros_like nodes for free variables in AD

Merge pull request !7200 from xychow/use-cached-envitem-in-ad
pull/7200/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 2818e0a7e6

@ -16,9 +16,9 @@
#include "frontend/optimizer/ad/dfunctor.h"
#include <map>
#include <memory>
#include <string>
#include <utility>
#include "ir/anf.h"
#include "utils/info.h"
@ -99,14 +99,23 @@ void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
}
}
auto node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint->second->k()});
fv_adjoint->second->RegisterKUser(node, 1);
auto default_val = tape_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), fv_adjoint->second->k()});
fv_adjoint->second->RegisterKUser(default_val, 1);
auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvGetItem), din, node, default_val});
auto fv_node = fv_adjoint->second->k();
auto cached_envitem_iter = anfnode_to_envitem_.find(fv_node);
CNodePtr embed_node, default_val_node;
if (cached_envitem_iter != anfnode_to_envitem_.end()) {
embed_node = cached_envitem_iter->second.first;
default_val_node = cached_envitem_iter->second.second;
} else {
embed_node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_node});
default_val_node = tape_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), fv_node});
fv_adjoint->second->RegisterKUser(embed_node, 1);
fv_adjoint->second->RegisterKUser(default_val_node, 1);
anfnode_to_envitem_[fv_node] = std::make_pair(embed_node, default_val_node);
}
auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvGetItem), din, embed_node, default_val_node});
MS_LOG(DEBUG) << "BackPropagateFv find adjoint in anfnode_to_adjoin_ or anfnode_to_adjoin_indirect_fv_ fv "
<< fv->func_graph()->ToString() << " " << fv->ToString() << ".";
MS_LOG(DEBUG) << "BackPropagateFv get item from " << din->ToString() << " key " << node->ToString() << ".";
MS_LOG(DEBUG) << "BackPropagateFv get item from " << din->ToString() << " key " << embed_node->ToString() << ".";
fv_adjoint->second->AccumulateDout(dfv);
}

@ -24,6 +24,7 @@
#include <unordered_map>
#include <vector>
#include <iostream>
#include <utility>
#include "ir/anf.h"
#include "ir/meta_func_graph.h"
@ -100,6 +101,9 @@ class DFunctor : public std::enable_shared_from_this<DFunctor> {
std::unordered_map<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_;
// Cache for indirect fv backpropagation, K o K can only do backprop layer by layer.
std::unordered_map<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_indirect_fv_;
// Cache for fv node -> pair<embed<fv_node>, zeros_like<fv_node>>, so that EnvGetItemTransform in the
// optimizer can hit its cache when fv_node is the same.
std::unordered_map<AnfNodePtr, std::pair<CNodePtr, CNodePtr>> anfnode_to_envitem_;
FuncGraphPtr primal_graph_;
// K object for primal_graph_;
FuncGraphPtr k_graph_;

Loading…
Cancel
Save