Optimize the performance after Grad due to mapping Primitive to K.

pull/9300/head
Zhang Qinghua 4 years ago
parent 8453b0d243
commit 829ab4492b

@ -232,22 +232,15 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
for (size_t i = 0; i < cnode_morph->size(); i++) { for (size_t i = 0; i < cnode_morph->size(); i++) {
auto node = cnode_morph->input(i); auto node = cnode_morph->input(i);
AdjointPtr node_adjoint = nullptr; AdjointPtr node_adjoint = nullptr;
AnfNodePtr k = nullptr; auto node_adjoint_iter = anfnode_to_adjoin_.find(node);
if (IsValueNode<Primitive>(node)) { if (node_adjoint_iter != anfnode_to_adjoin_.end()) {
k = MapToK(cnode_morph, i); node_adjoint = node_adjoint_iter->second;
node_adjoint = std::make_shared<Adjoint>(node, k, tape_);
anfnode_to_adjoin_[node] = node_adjoint;
} else { } else {
auto node_adjoint_iter = anfnode_to_adjoin_.find(node); // Input might be a CNode that needs to be handled previously.
if (node_adjoint_iter != anfnode_to_adjoin_.end()) { node_adjoint = MapMorphism(node);
node_adjoint = node_adjoint_iter->second;
} else {
// Input might be a CNode that needs to be handled previously.
node_adjoint = MapMorphism(node);
}
MS_EXCEPTION_IF_NULL(node_adjoint);
k = node_adjoint->k();
} }
MS_EXCEPTION_IF_NULL(node_adjoint);
AnfNodePtr k = node_adjoint->k();
if (k == nullptr) { if (k == nullptr) {
MS_LOG(EXCEPTION) << "MapMorphism adjoint node does not exist, input[" << i << "] " << node->ToString() << "."; MS_LOG(EXCEPTION) << "MapMorphism adjoint node does not exist, input[" << i << "] " << node->ToString() << ".";
} }
@ -537,93 +530,69 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
return nullptr; return nullptr;
} }
// Map func graph to K // Construct representation graph for {CNode, Index} of Primitive.
AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) { AnfNodePtr DFunctor::MapPrimitiveToK(const CNodePtr &primitive_user, size_t index) {
auto f = func_graph_to_functor_.find(primal); auto primal = primitive_user->input(index);
if (!IsValueNode<Primitive>(primal)) {
MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of Primitive.";
}
ScopeGuard scope_guard(primal->scope());
// Map Primitive to K
auto value_node = primal->cast<ValueNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(value_node);
if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) {
MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << ".";
need_cut_ = true;
}
auto k_prim = g_k_prims.KPrimitive(primitive_user, value_node, resources_);
if (k_prim != nullptr) {
return NewValueNode(k_prim);
}
// When failed to find k_prim, try k_meta.
auto k_meta = g_k_prims.KMetaFuncGraph(prim);
if (k_meta != nullptr) {
return NewValueNode(k_meta);
}
MS_LOG(EXCEPTION) << "Fail to map Primitive of \"" << primal->ToString() << "\" to K.";
}
// Construct representation graph for ValueNode of FuncGraph.
AnfNodePtr DFunctor::MapFuncGraphToK(const AnfNodePtr &primal) {
if (!IsValueNode<FuncGraph>(primal)) {
MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of FuncGraph.";
}
ScopeGuard scope_guard(primal->scope());
// Map func graph to K
auto func_graph = GetValueNode<FuncGraphPtr>(primal);
auto f = func_graph_to_functor_.find(func_graph);
if (f != func_graph_to_functor_.end()) { if (f != func_graph_to_functor_.end()) {
MS_LOG(DEBUG) << "K graph functor already exist " << primal->ToString() << "."; MS_LOG(DEBUG) << "K graph functor already exist " << func_graph->ToString() << ".";
return NewValueNode(f->second->k_graph_); return NewValueNode(f->second->k_graph_);
} }
auto k_user_defined = KUserDefined(func_graph);
auto k_user_defined = KUserDefined(primal);
if (k_user_defined != nullptr) { if (k_user_defined != nullptr) {
MS_LOG(DEBUG) << "K graph functor user defined bprop " << primal->ToString() << "."; MS_LOG(DEBUG) << "K graph functor user defined bprop " << func_graph->ToString() << ".";
return NewValueNode(k_user_defined); return NewValueNode(k_user_defined);
} }
auto functor = std::make_shared<DFunctor>(func_graph, resources_);
auto functor = std::make_shared<DFunctor>(primal, resources_);
functor->Init(); functor->Init();
functor->MapObject(); functor->MapObject();
functor->MapMorphism(); functor->MapMorphism();
MS_LOG(DEBUG) << "K graph K function graph " << primal->ToString() << " " << functor->k_graph_->ToString() << "."; MS_LOG(DEBUG) << "Map \"" << func_graph->ToString() << "\" to \"" << functor->k_graph_->ToString() << "\"";
return NewValueNode(functor->k_graph_); return NewValueNode(functor->k_graph_);
} }
// Construct representation graph for primitive CNode. // Construct for ValueNode of Parameter.
AnfNodePtr DFunctor::MapToK(const CNodePtr &primal_user, size_t index) { AnfNodePtr DFunctor::MapParameterToK(const AnfNodePtr &primal) {
auto primal = primal_user->input(index); if (!primal->isa<Parameter>()) {
ScopeGuard scope_guard(primal->scope()); MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of Parameter.";
// Map primitive to K
if (IsValueNode<Primitive>(primal)) {
auto value_node = primal->cast<ValueNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(value_node);
if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) {
MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << ".";
need_cut_ = true;
}
auto k_prim = g_k_prims.KPrimitive(primal_user, value_node, resources_);
if (k_prim != nullptr) {
return NewValueNode(k_prim);
}
// When failed to find k_prim, try k_meta.
auto k_meta = g_k_prims.KMetaFuncGraph(prim);
if (k_meta != nullptr) {
return NewValueNode(k_meta);
}
} }
return MapToK(primal);
}
// Construct representation graph for given node.
AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
ScopeGuard scope_guard(primal->scope()); ScopeGuard scope_guard(primal->scope());
// Map primitive to K // Map Parameter to K
if (IsValueNode<Primitive>(primal)) { TraceGuard trace_guard(std::make_shared<TraceGradFprop>(primal->debug_info()));
auto value_node = primal->cast<ValueNodePtr>(); auto ret = k_graph_->add_parameter();
auto prim = GetValueNode<PrimitivePtr>(value_node); return ret;
if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) {
MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << ".";
need_cut_ = true;
}
auto k_prim = g_k_prims.KPrimitive(nullptr, value_node, resources_);
if (k_prim != nullptr) {
return NewValueNode(k_prim);
}
// When failed to find k_prim, try k_meta.
auto k_meta = g_k_prims.KMetaFuncGraph(prim);
if (k_meta != nullptr) {
return NewValueNode(k_meta);
}
}
// Map func graph to K
if (IsValueNode<FuncGraph>(primal)) {
auto func_graph = GetValueNode<FuncGraphPtr>(primal);
auto k_func = MapToK(func_graph);
return k_func;
}
if (primal->isa<Parameter>()) {
TraceGuard trace_guard(std::make_shared<TraceGradFprop>(primal->debug_info()));
auto ret = k_graph_->add_parameter();
return ret;
}
if (!primal->isa<ValueNode>()) {
MS_LOG(EXCEPTION) << "K node keeped node from primal_graph_ " << primal->ToString() << " that is not a ValueNode.";
}
return primal;
} }
bool DFunctor::IsInScope(const AnfNodePtr &node) { bool DFunctor::IsInScope(const AnfNodePtr &node) {
@ -664,7 +633,7 @@ void DFunctor::MapParamObject() {
for (auto &p : primal_graph_->parameters()) { for (auto &p : primal_graph_->parameters()) {
ScopeGuard scope_guard(p->scope()); ScopeGuard scope_guard(p->scope());
MS_LOG(DEBUG) << "MapParamObject parameter " << p->ToString() << "."; MS_LOG(DEBUG) << "MapParamObject parameter " << p->ToString() << ".";
auto adjoint = std::make_shared<Adjoint>(p, MapToK(p), tape_); auto adjoint = std::make_shared<Adjoint>(p, MapParameterToK(p), tape_);
UpdateAdjoint(adjoint); UpdateAdjoint(adjoint);
anfnode_to_adjoin_[p] = adjoint; anfnode_to_adjoin_[p] = adjoint;
} }
@ -682,12 +651,32 @@ void DFunctor::MapValueObject() {
anfnode_to_adjoin_[node] = adjoint; anfnode_to_adjoin_[node] = adjoint;
continue; continue;
} }
// Skip Primitive.
if (IsValueNode<Primitive>(node) && GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn) { AdjointPtr adjoint = nullptr;
continue; if (IsValueNode<Primitive>(node)) { // Primitive.
if (GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn) {
continue;
}
MS_LOG(DEBUG) << "Map Primitive node " << node->DebugString() << ".";
auto &users = manager->node_users()[node];
if (users.size() == 0) {
MS_LOG(ERROR) << "\"" << node->DebugString() << "\" has no user.";
continue;
} else if (users.size() > 1) {
MS_LOG(DEBUG) << "\"" << node->DebugString() << "\" supposed to be used once, but users size: " << users.size();
}
auto cnode = users.begin()->first->cast<CNodePtr>(); // We just use the first user.
auto index = users.begin()->second;
adjoint = std::make_shared<Adjoint>(node, MapPrimitiveToK(cnode, index), tape_);
} else if (IsValueNode<FuncGraph>(node)) { // FuncGraph
MS_LOG(DEBUG) << "Map FuncGraph node " << node->DebugString() << ".";
adjoint = std::make_shared<Adjoint>(node, MapFuncGraphToK(node), tape_);
} else if (node->isa<Parameter>()) { // Parameter, hardly reach here.
MS_LOG(DEBUG) << "Map Parameter node " << node->DebugString() << ".";
adjoint = std::make_shared<Adjoint>(node, MapParameterToK(node), tape_);
} else {
adjoint = std::make_shared<Adjoint>(node, node, tape_);
} }
MS_LOG(DEBUG) << "MapValueObject node " << node->ToString() << ".";
auto adjoint = std::make_shared<Adjoint>(node, MapToK(node), tape_);
UpdateAdjoint(adjoint); UpdateAdjoint(adjoint);
anfnode_to_adjoin_[node] = adjoint; anfnode_to_adjoin_[node] = adjoint;
} }

@ -81,12 +81,12 @@ class DFunctor : public std::enable_shared_from_this<DFunctor> {
void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint); void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint);
AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv); AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv);
AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv); AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv);
// Map AnfNode object from D category to K category. // Map CNode/Index of Primitive to K.
AnfNodePtr MapToK(const AnfNodePtr &primal); AnfNodePtr MapPrimitiveToK(const CNodePtr &primitive_user, size_t index);
// Map CNode object from D category to K category. // Map ValueNode of FuncGraph to K.
AnfNodePtr MapToK(const CNodePtr &primal_user, size_t index); AnfNodePtr MapFuncGraphToK(const AnfNodePtr &primal);
// Map FuncGraph object from D category to K category. // Map ValueNode of Parameter to K.
AnfNodePtr MapToK(const FuncGraphPtr &primal); AnfNodePtr MapParameterToK(const AnfNodePtr &primal);
// MapObject impls. // MapObject impls.
void MapFvObject(); void MapFvObject();
void MapValueObject(); void MapValueObject();

Loading…
Cancel
Save