|
|
|
@ -91,21 +91,32 @@ void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
|
|
|
|
|
if (fv_adjoint == anfnode_to_adjoin_.end()) {
|
|
|
|
|
MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString()
|
|
|
|
|
<< " " << fv->ToString() << ".";
|
|
|
|
|
fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
|
|
|
|
|
if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) {
|
|
|
|
|
MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv "
|
|
|
|
|
<< fv->func_graph()->ToString() << " " << fv->ToString() << ".";
|
|
|
|
|
auto parent_adjoint = FindAdjoint(fv);
|
|
|
|
|
AdjointPtr adjoint = nullptr;
|
|
|
|
|
if (parent_adjoint != nullptr) {
|
|
|
|
|
adjoint = std::make_shared<Adjoint>(fv, parent_adjoint->k(), tape_);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(DEBUG) << "BackPropagateFv failed can not find adjoint definition fv, add a k hole "
|
|
|
|
|
<< fv->func_graph()->ToString() << " " << fv->ToString() << ".";
|
|
|
|
|
adjoint = std::make_shared<Adjoint>(fv, nullptr, tape_);
|
|
|
|
|
|
|
|
|
|
if (fv->func_graph() == primal_graph_) {
|
|
|
|
|
// If this fv is not mapped by MapMorphism because of cnode order, then map it now.
|
|
|
|
|
(void)MapMorphism(fv);
|
|
|
|
|
fv_adjoint = anfnode_to_adjoin_.find(fv);
|
|
|
|
|
if (fv_adjoint == anfnode_to_adjoin_.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString() << " "
|
|
|
|
|
<< fv->ToString() << ".";
|
|
|
|
|
}
|
|
|
|
|
anfnode_to_adjoin_indirect_fv_[fv] = adjoint;
|
|
|
|
|
} else {
|
|
|
|
|
fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
|
|
|
|
|
if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) {
|
|
|
|
|
MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv "
|
|
|
|
|
<< fv->func_graph()->ToString() << " " << fv->ToString() << ".";
|
|
|
|
|
auto parent_adjoint = FindAdjoint(fv);
|
|
|
|
|
AdjointPtr adjoint = nullptr;
|
|
|
|
|
if (parent_adjoint != nullptr) {
|
|
|
|
|
adjoint = std::make_shared<Adjoint>(fv, parent_adjoint->k(), tape_);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(DEBUG) << "BackPropagateFv failed can not find adjoint definition fv, add a k hole "
|
|
|
|
|
<< fv->func_graph()->ToString() << " " << fv->ToString() << ".";
|
|
|
|
|
adjoint = std::make_shared<Adjoint>(fv, nullptr, tape_);
|
|
|
|
|
}
|
|
|
|
|
anfnode_to_adjoin_indirect_fv_[fv] = adjoint;
|
|
|
|
|
fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto fv_node = fv_adjoint->second->k();
|
|
|
|
|