|
|
|
@ -232,22 +232,15 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
|
|
|
|
|
for (size_t i = 0; i < cnode_morph->size(); i++) {
|
|
|
|
|
auto node = cnode_morph->input(i);
|
|
|
|
|
AdjointPtr node_adjoint = nullptr;
|
|
|
|
|
AnfNodePtr k = nullptr;
|
|
|
|
|
if (IsValueNode<Primitive>(node)) {
|
|
|
|
|
k = MapToK(cnode_morph, i);
|
|
|
|
|
node_adjoint = std::make_shared<Adjoint>(node, k, tape_);
|
|
|
|
|
anfnode_to_adjoin_[node] = node_adjoint;
|
|
|
|
|
auto node_adjoint_iter = anfnode_to_adjoin_.find(node);
|
|
|
|
|
if (node_adjoint_iter != anfnode_to_adjoin_.end()) {
|
|
|
|
|
node_adjoint = node_adjoint_iter->second;
|
|
|
|
|
} else {
|
|
|
|
|
auto node_adjoint_iter = anfnode_to_adjoin_.find(node);
|
|
|
|
|
if (node_adjoint_iter != anfnode_to_adjoin_.end()) {
|
|
|
|
|
node_adjoint = node_adjoint_iter->second;
|
|
|
|
|
} else {
|
|
|
|
|
// Input might be a CNode that needs to be handled previously.
|
|
|
|
|
node_adjoint = MapMorphism(node);
|
|
|
|
|
}
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node_adjoint);
|
|
|
|
|
k = node_adjoint->k();
|
|
|
|
|
// Input might be a CNode that needs to be handled previously.
|
|
|
|
|
node_adjoint = MapMorphism(node);
|
|
|
|
|
}
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node_adjoint);
|
|
|
|
|
AnfNodePtr k = node_adjoint->k();
|
|
|
|
|
if (k == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "MapMorphism adjoint node does not exist, input[" << i << "] " << node->ToString() << ".";
|
|
|
|
|
}
|
|
|
|
@ -537,93 +530,69 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Map func graph to K
|
|
|
|
|
AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) {
|
|
|
|
|
auto f = func_graph_to_functor_.find(primal);
|
|
|
|
|
// Construct representation graph for {CNode, Index} of Primitive.
|
|
|
|
|
AnfNodePtr DFunctor::MapPrimitiveToK(const CNodePtr &primitive_user, size_t index) {
|
|
|
|
|
auto primal = primitive_user->input(index);
|
|
|
|
|
if (!IsValueNode<Primitive>(primal)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of Primitive.";
|
|
|
|
|
}
|
|
|
|
|
ScopeGuard scope_guard(primal->scope());
|
|
|
|
|
// Map Primitive to K
|
|
|
|
|
auto value_node = primal->cast<ValueNodePtr>();
|
|
|
|
|
auto prim = GetValueNode<PrimitivePtr>(value_node);
|
|
|
|
|
if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) {
|
|
|
|
|
MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << ".";
|
|
|
|
|
need_cut_ = true;
|
|
|
|
|
}
|
|
|
|
|
auto k_prim = g_k_prims.KPrimitive(primitive_user, value_node, resources_);
|
|
|
|
|
if (k_prim != nullptr) {
|
|
|
|
|
return NewValueNode(k_prim);
|
|
|
|
|
}
|
|
|
|
|
// When failed to find k_prim, try k_meta.
|
|
|
|
|
auto k_meta = g_k_prims.KMetaFuncGraph(prim);
|
|
|
|
|
if (k_meta != nullptr) {
|
|
|
|
|
return NewValueNode(k_meta);
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(EXCEPTION) << "Fail to map Primitive of \"" << primal->ToString() << "\" to K.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Construct representation graph for ValueNode of FuncGraph.
|
|
|
|
|
AnfNodePtr DFunctor::MapFuncGraphToK(const AnfNodePtr &primal) {
|
|
|
|
|
if (!IsValueNode<FuncGraph>(primal)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of FuncGraph.";
|
|
|
|
|
}
|
|
|
|
|
ScopeGuard scope_guard(primal->scope());
|
|
|
|
|
// Map func graph to K
|
|
|
|
|
auto func_graph = GetValueNode<FuncGraphPtr>(primal);
|
|
|
|
|
auto f = func_graph_to_functor_.find(func_graph);
|
|
|
|
|
if (f != func_graph_to_functor_.end()) {
|
|
|
|
|
MS_LOG(DEBUG) << "K graph functor already exist " << primal->ToString() << ".";
|
|
|
|
|
MS_LOG(DEBUG) << "K graph functor already exist " << func_graph->ToString() << ".";
|
|
|
|
|
return NewValueNode(f->second->k_graph_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto k_user_defined = KUserDefined(primal);
|
|
|
|
|
auto k_user_defined = KUserDefined(func_graph);
|
|
|
|
|
if (k_user_defined != nullptr) {
|
|
|
|
|
MS_LOG(DEBUG) << "K graph functor user defined bprop " << primal->ToString() << ".";
|
|
|
|
|
MS_LOG(DEBUG) << "K graph functor user defined bprop " << func_graph->ToString() << ".";
|
|
|
|
|
return NewValueNode(k_user_defined);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto functor = std::make_shared<DFunctor>(primal, resources_);
|
|
|
|
|
auto functor = std::make_shared<DFunctor>(func_graph, resources_);
|
|
|
|
|
functor->Init();
|
|
|
|
|
functor->MapObject();
|
|
|
|
|
functor->MapMorphism();
|
|
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "K graph K function graph " << primal->ToString() << " " << functor->k_graph_->ToString() << ".";
|
|
|
|
|
MS_LOG(DEBUG) << "Map \"" << func_graph->ToString() << "\" to \"" << functor->k_graph_->ToString() << "\"";
|
|
|
|
|
return NewValueNode(functor->k_graph_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Construct representation graph for primitive CNode.
|
|
|
|
|
AnfNodePtr DFunctor::MapToK(const CNodePtr &primal_user, size_t index) {
|
|
|
|
|
auto primal = primal_user->input(index);
|
|
|
|
|
ScopeGuard scope_guard(primal->scope());
|
|
|
|
|
// Map primitive to K
|
|
|
|
|
if (IsValueNode<Primitive>(primal)) {
|
|
|
|
|
auto value_node = primal->cast<ValueNodePtr>();
|
|
|
|
|
auto prim = GetValueNode<PrimitivePtr>(value_node);
|
|
|
|
|
if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) {
|
|
|
|
|
MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << ".";
|
|
|
|
|
need_cut_ = true;
|
|
|
|
|
}
|
|
|
|
|
auto k_prim = g_k_prims.KPrimitive(primal_user, value_node, resources_);
|
|
|
|
|
if (k_prim != nullptr) {
|
|
|
|
|
return NewValueNode(k_prim);
|
|
|
|
|
}
|
|
|
|
|
// When failed to find k_prim, try k_meta.
|
|
|
|
|
auto k_meta = g_k_prims.KMetaFuncGraph(prim);
|
|
|
|
|
if (k_meta != nullptr) {
|
|
|
|
|
return NewValueNode(k_meta);
|
|
|
|
|
}
|
|
|
|
|
// Construct for ValueNode of Parameter.
|
|
|
|
|
AnfNodePtr DFunctor::MapParameterToK(const AnfNodePtr &primal) {
|
|
|
|
|
if (!primal->isa<Parameter>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of Parameter.";
|
|
|
|
|
}
|
|
|
|
|
return MapToK(primal);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Construct representation graph for given node.
|
|
|
|
|
AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
|
|
|
|
|
ScopeGuard scope_guard(primal->scope());
|
|
|
|
|
// Map primitive to K
|
|
|
|
|
if (IsValueNode<Primitive>(primal)) {
|
|
|
|
|
auto value_node = primal->cast<ValueNodePtr>();
|
|
|
|
|
auto prim = GetValueNode<PrimitivePtr>(value_node);
|
|
|
|
|
if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) {
|
|
|
|
|
MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << ".";
|
|
|
|
|
need_cut_ = true;
|
|
|
|
|
}
|
|
|
|
|
auto k_prim = g_k_prims.KPrimitive(nullptr, value_node, resources_);
|
|
|
|
|
if (k_prim != nullptr) {
|
|
|
|
|
return NewValueNode(k_prim);
|
|
|
|
|
}
|
|
|
|
|
// When failed to find k_prim, try k_meta.
|
|
|
|
|
auto k_meta = g_k_prims.KMetaFuncGraph(prim);
|
|
|
|
|
if (k_meta != nullptr) {
|
|
|
|
|
return NewValueNode(k_meta);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Map func graph to K
|
|
|
|
|
if (IsValueNode<FuncGraph>(primal)) {
|
|
|
|
|
auto func_graph = GetValueNode<FuncGraphPtr>(primal);
|
|
|
|
|
auto k_func = MapToK(func_graph);
|
|
|
|
|
return k_func;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (primal->isa<Parameter>()) {
|
|
|
|
|
TraceGuard trace_guard(std::make_shared<TraceGradFprop>(primal->debug_info()));
|
|
|
|
|
auto ret = k_graph_->add_parameter();
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!primal->isa<ValueNode>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "K node keeped node from primal_graph_ " << primal->ToString() << " that is not a ValueNode.";
|
|
|
|
|
}
|
|
|
|
|
return primal;
|
|
|
|
|
// Map Parameter to K
|
|
|
|
|
TraceGuard trace_guard(std::make_shared<TraceGradFprop>(primal->debug_info()));
|
|
|
|
|
auto ret = k_graph_->add_parameter();
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool DFunctor::IsInScope(const AnfNodePtr &node) {
|
|
|
|
@ -664,7 +633,7 @@ void DFunctor::MapParamObject() {
|
|
|
|
|
for (auto &p : primal_graph_->parameters()) {
|
|
|
|
|
ScopeGuard scope_guard(p->scope());
|
|
|
|
|
MS_LOG(DEBUG) << "MapParamObject parameter " << p->ToString() << ".";
|
|
|
|
|
auto adjoint = std::make_shared<Adjoint>(p, MapToK(p), tape_);
|
|
|
|
|
auto adjoint = std::make_shared<Adjoint>(p, MapParameterToK(p), tape_);
|
|
|
|
|
UpdateAdjoint(adjoint);
|
|
|
|
|
anfnode_to_adjoin_[p] = adjoint;
|
|
|
|
|
}
|
|
|
|
@ -682,12 +651,32 @@ void DFunctor::MapValueObject() {
|
|
|
|
|
anfnode_to_adjoin_[node] = adjoint;
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
// Skip Primitive.
|
|
|
|
|
if (IsValueNode<Primitive>(node) && GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn) {
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
|
|
AdjointPtr adjoint = nullptr;
|
|
|
|
|
if (IsValueNode<Primitive>(node)) { // Primitive.
|
|
|
|
|
if (GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "Map Primitive node " << node->DebugString() << ".";
|
|
|
|
|
auto &users = manager->node_users()[node];
|
|
|
|
|
if (users.size() == 0) {
|
|
|
|
|
MS_LOG(ERROR) << "\"" << node->DebugString() << "\" has no user.";
|
|
|
|
|
continue;
|
|
|
|
|
} else if (users.size() > 1) {
|
|
|
|
|
MS_LOG(DEBUG) << "\"" << node->DebugString() << "\" supposed to be used once, but users size: " << users.size();
|
|
|
|
|
}
|
|
|
|
|
auto cnode = users.begin()->first->cast<CNodePtr>(); // We just use the first user.
|
|
|
|
|
auto index = users.begin()->second;
|
|
|
|
|
adjoint = std::make_shared<Adjoint>(node, MapPrimitiveToK(cnode, index), tape_);
|
|
|
|
|
} else if (IsValueNode<FuncGraph>(node)) { // FuncGraph
|
|
|
|
|
MS_LOG(DEBUG) << "Map FuncGraph node " << node->DebugString() << ".";
|
|
|
|
|
adjoint = std::make_shared<Adjoint>(node, MapFuncGraphToK(node), tape_);
|
|
|
|
|
} else if (node->isa<Parameter>()) { // Parameter, hardly reach here.
|
|
|
|
|
MS_LOG(DEBUG) << "Map Parameter node " << node->DebugString() << ".";
|
|
|
|
|
adjoint = std::make_shared<Adjoint>(node, MapParameterToK(node), tape_);
|
|
|
|
|
} else {
|
|
|
|
|
adjoint = std::make_shared<Adjoint>(node, node, tape_);
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "MapValueObject node " << node->ToString() << ".";
|
|
|
|
|
auto adjoint = std::make_shared<Adjoint>(node, MapToK(node), tape_);
|
|
|
|
|
UpdateAdjoint(adjoint);
|
|
|
|
|
anfnode_to_adjoin_[node] = adjoint;
|
|
|
|
|
}
|
|
|
|
|