|
|
|
@ -16,9 +16,9 @@
|
|
|
|
|
|
|
|
|
|
#include "frontend/optimizer/ad/dfunctor.h"
|
|
|
|
|
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <utility>
|
|
|
|
|
|
|
|
|
|
#include "ir/anf.h"
|
|
|
|
|
#include "utils/info.h"
|
|
|
|
@ -99,14 +99,23 @@ void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
|
|
|
|
|
fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint->second->k()});
|
|
|
|
|
fv_adjoint->second->RegisterKUser(node, 1);
|
|
|
|
|
auto default_val = tape_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), fv_adjoint->second->k()});
|
|
|
|
|
fv_adjoint->second->RegisterKUser(default_val, 1);
|
|
|
|
|
auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvGetItem), din, node, default_val});
|
|
|
|
|
auto fv_node = fv_adjoint->second->k();
|
|
|
|
|
auto cached_envitem_iter = anfnode_to_envitem_.find(fv_node);
|
|
|
|
|
CNodePtr embed_node, default_val_node;
|
|
|
|
|
if (cached_envitem_iter != anfnode_to_envitem_.end()) {
|
|
|
|
|
embed_node = cached_envitem_iter->second.first;
|
|
|
|
|
default_val_node = cached_envitem_iter->second.second;
|
|
|
|
|
} else {
|
|
|
|
|
embed_node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_node});
|
|
|
|
|
default_val_node = tape_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), fv_node});
|
|
|
|
|
fv_adjoint->second->RegisterKUser(embed_node, 1);
|
|
|
|
|
fv_adjoint->second->RegisterKUser(default_val_node, 1);
|
|
|
|
|
anfnode_to_envitem_[fv_node] = std::make_pair(embed_node, default_val_node);
|
|
|
|
|
}
|
|
|
|
|
auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvGetItem), din, embed_node, default_val_node});
|
|
|
|
|
MS_LOG(DEBUG) << "BackPropagateFv find adjoint in anfnode_to_adjoin_ or anfnode_to_adjoin_indirect_fv_ fv "
|
|
|
|
|
<< fv->func_graph()->ToString() << " " << fv->ToString() << ".";
|
|
|
|
|
MS_LOG(DEBUG) << "BackPropagateFv get item from " << din->ToString() << " key " << node->ToString() << ".";
|
|
|
|
|
MS_LOG(DEBUG) << "BackPropagateFv get item from " << din->ToString() << " key " << embed_node->ToString() << ".";
|
|
|
|
|
fv_adjoint->second->AccumulateDout(dfv);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|