|
|
|
@ -38,6 +38,7 @@ namespace mindspore {
|
|
|
|
|
namespace ad {
|
|
|
|
|
std::unordered_map<FuncGraphPtr, DFunctorPtr> DFunctor::func_graph_to_functor_;
|
|
|
|
|
std::unordered_map<AnfNodePtr, AdjointPtr> DFunctor::anfnode_to_adjoin_definition_;
|
|
|
|
|
FuncGraphSet DFunctor::scope_;
|
|
|
|
|
|
|
|
|
|
DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources)
|
|
|
|
|
: primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) {
|
|
|
|
@ -55,11 +56,15 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
|
|
|
|
|
void DFunctor::Init(const DFunctorPtr &functor, bool is_top) {
|
|
|
|
|
func_graph_to_functor_[primal_graph_] = functor;
|
|
|
|
|
is_top_ = is_top;
|
|
|
|
|
if (is_top) {
|
|
|
|
|
scope_ = primal_graph_->scope();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DFunctor::Clear() {
|
|
|
|
|
func_graph_to_functor_.clear();
|
|
|
|
|
anfnode_to_adjoin_definition_.clear();
|
|
|
|
|
scope_.clear();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
|
|
|
|
@ -95,11 +100,48 @@ void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
|
|
|
|
|
fv_adjoint->second->AccumulateDout(dfv);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env) {
|
|
|
|
|
// Take switch_layer as a set of candidate functions.
|
|
|
|
|
auto input = cnode_morph->input(2);
|
|
|
|
|
if (!IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << input->ToString() << ".";
|
|
|
|
|
}
|
|
|
|
|
auto tuple_graphs = input->cast<CNodePtr>();
|
|
|
|
|
for (size_t i = 1; i < tuple_graphs->size(); ++i) {
|
|
|
|
|
auto graph = tuple_graphs->input(i);
|
|
|
|
|
if (!IsValueNode<FuncGraph>(graph)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << graph->ToString()
|
|
|
|
|
<< " as the " << i << "th element.";
|
|
|
|
|
}
|
|
|
|
|
auto func_graph = GetValueNode<FuncGraphPtr>(graph);
|
|
|
|
|
auto functor = func_graph_to_functor_.find(func_graph);
|
|
|
|
|
if (functor == func_graph_to_functor_.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "BackPropagateSwitchLayer failed functor for subgraph does not exist input[" << i << "] "
|
|
|
|
|
<< func_graph->ToString() << ".";
|
|
|
|
|
}
|
|
|
|
|
// Consider direct and indirect fvs.
|
|
|
|
|
for (auto fv : func_graph->free_variables_nodes()) {
|
|
|
|
|
BackPropagateFv(fv, env);
|
|
|
|
|
}
|
|
|
|
|
for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) {
|
|
|
|
|
MS_LOG(DEBUG) << "BackPropagateSwitchLayer backprop indirect fv " << func_graph->ToString() << " "
|
|
|
|
|
<< indirect_fv.first->ToString() << ".";
|
|
|
|
|
BackPropagateFv(indirect_fv.first, env);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint) {
|
|
|
|
|
auto bprop = k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(1)});
|
|
|
|
|
// Call with delimited continuation dout.
|
|
|
|
|
auto bprop_app = tape_->NewCNode({bprop, node_adjoint->dout()});
|
|
|
|
|
node_adjoint->RegisterDoutUser(bprop_app, 1);
|
|
|
|
|
// Special case for switch_layer
|
|
|
|
|
if (IsPrimitiveCNode(cnode_morph, prim::kPrimSwitchLayer)) {
|
|
|
|
|
auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(0)});
|
|
|
|
|
BackPropagateSwitchLayer(cnode_morph, din);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < cnode_morph->size(); i++) {
|
|
|
|
|
auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToInt(i))});
|
|
|
|
|
auto input = cnode_morph->input(i);
|
|
|
|
@ -402,6 +444,11 @@ AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
|
|
|
|
|
return primal;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool DFunctor::IsInScope(const AnfNodePtr &node) {
|
|
|
|
|
return std::any_of(scope_.begin(), scope_.end(),
|
|
|
|
|
[&](const FuncGraphPtr &graph) { return node->func_graph() == graph; });
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DFunctor::MapFvObject() {
|
|
|
|
|
// Map free variable.
|
|
|
|
|
const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
|
|
|
|
@ -414,8 +461,8 @@ void DFunctor::MapFvObject() {
|
|
|
|
|
if (parent_adjoint != nullptr) {
|
|
|
|
|
adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
|
|
|
|
|
} else {
|
|
|
|
|
if (is_top_) {
|
|
|
|
|
// Top graph for ad, add adjoint for free variables.
|
|
|
|
|
if (is_top_ || node->isa<Parameter>() || !IsInScope(node)) {
|
|
|
|
|
// Out of ad scope, add adjoint for free variables.
|
|
|
|
|
adjoint = std::make_shared<Adjoint>(node, node, tape_);
|
|
|
|
|
UpdateAdjoint(adjoint);
|
|
|
|
|
} else {
|
|
|
|
|