|
|
|
@ -185,19 +185,32 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
|
|
|
|
|
return node_adjoint;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
|
|
|
|
|
// Do not care about non-CNode
|
|
|
|
|
if (!node->isa<CNode>()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
// Do not care about kPrimReturn
|
|
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto &users = primal_graph_->manager()->node_users()[node];
|
|
|
|
|
// Do not care about isolated morphisms
|
|
|
|
|
if (users.empty()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
// Not free if it's used by some node in primal_graph
|
|
|
|
|
bool nonfree = std::any_of(std::begin(users), std::end(users), [&](const auto &kv) {
|
|
|
|
|
auto &user = kv.first;
|
|
|
|
|
return user->func_graph() == primal_graph_;
|
|
|
|
|
});
|
|
|
|
|
return !nonfree;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DFunctor::MapFreeMorphism() {
|
|
|
|
|
// Handle cnode not attached to output, that might be refered in other functions.
|
|
|
|
|
for (auto &node : primal_graph_->nodes()) {
|
|
|
|
|
auto adjoint = FindAdjoint(node);
|
|
|
|
|
if (adjoint != nullptr) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (!node->isa<CNode>()) {
|
|
|
|
|
MS_LOG(DEBUG) << "MapFreeMorphism noncnode not mapped after MapMorphism " << node->ToString() << " "
|
|
|
|
|
<< node->type_name() << ".";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
|
|
|
|
|
if (!IsFreeMorphism(node)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "MapFreeMorphism map nonoutput cnode after MapMorphism " << node->ToString() << ".";
|
|
|
|
@ -256,9 +269,10 @@ void DFunctor::MapMorphism() {
|
|
|
|
|
// Set stop_gradient before MapMorphism.
|
|
|
|
|
BroadCastStopFlag();
|
|
|
|
|
|
|
|
|
|
// Handle free morphism before output, because in some case, free morphism might depend on output's fv tangent
|
|
|
|
|
MapFreeMorphism();
|
|
|
|
|
// Handle morphism from output.
|
|
|
|
|
(void)MapMorphism(primal_graph_->output());
|
|
|
|
|
MapFreeMorphism();
|
|
|
|
|
|
|
|
|
|
// Construct K for primal_graph_
|
|
|
|
|
auto output_adjoint = anfnode_to_adjoin_.find(primal_graph_->output());
|
|
|
|
@ -298,9 +312,10 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
|
|
|
|
|
const size_t param_diff = 1;
|
|
|
|
|
if (bprop_graph->output()->isa<CNode>() &&
|
|
|
|
|
bprop_graph->output()->cast<CNodePtr>()->size() + param_diff != bprop_graph->parameters().size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "User defined Cell bprop " << primal->ToString() << " in scope "
|
|
|
|
|
<< primal->output()->scope()->name()
|
|
|
|
|
<< " output must be a tuple and output number should be the same with inputs.";
|
|
|
|
|
// It does not matter with the final tangents, just a tip for debugging
|
|
|
|
|
MS_LOG(DEBUG) << "User defined Cell bprop " << primal->ToString() << " in scope "
|
|
|
|
|
<< primal->output()->scope()->name()
|
|
|
|
|
<< " output must be a tuple and output number should be the same with inputs.";
|
|
|
|
|
}
|
|
|
|
|
resources_->manager()->AddFuncGraph(bprop_graph);
|
|
|
|
|
|
|
|
|
|