Optimize the performance after Grad due to mapping Primitive to K.

pull/9300/head
Zhang Qinghua 4 years ago
parent 8453b0d243
commit 829ab4492b

@ -232,22 +232,15 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
for (size_t i = 0; i < cnode_morph->size(); i++) {
auto node = cnode_morph->input(i);
AdjointPtr node_adjoint = nullptr;
AnfNodePtr k = nullptr;
if (IsValueNode<Primitive>(node)) {
k = MapToK(cnode_morph, i);
node_adjoint = std::make_shared<Adjoint>(node, k, tape_);
anfnode_to_adjoin_[node] = node_adjoint;
auto node_adjoint_iter = anfnode_to_adjoin_.find(node);
if (node_adjoint_iter != anfnode_to_adjoin_.end()) {
node_adjoint = node_adjoint_iter->second;
} else {
auto node_adjoint_iter = anfnode_to_adjoin_.find(node);
if (node_adjoint_iter != anfnode_to_adjoin_.end()) {
node_adjoint = node_adjoint_iter->second;
} else {
// Input might be a CNode that needs to be handled previously.
node_adjoint = MapMorphism(node);
}
MS_EXCEPTION_IF_NULL(node_adjoint);
k = node_adjoint->k();
// Input might be a CNode that needs to be handled previously.
node_adjoint = MapMorphism(node);
}
MS_EXCEPTION_IF_NULL(node_adjoint);
AnfNodePtr k = node_adjoint->k();
if (k == nullptr) {
MS_LOG(EXCEPTION) << "MapMorphism adjoint node does not exist, input[" << i << "] " << node->ToString() << ".";
}
@ -537,93 +530,69 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
return nullptr;
}
// Map func graph to K
AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) {
auto f = func_graph_to_functor_.find(primal);
// Construct representation graph for {CNode, Index} of Primitive.
AnfNodePtr DFunctor::MapPrimitiveToK(const CNodePtr &primitive_user, size_t index) {
auto primal = primitive_user->input(index);
if (!IsValueNode<Primitive>(primal)) {
MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of Primitive.";
}
ScopeGuard scope_guard(primal->scope());
// Map Primitive to K
auto value_node = primal->cast<ValueNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(value_node);
if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) {
MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << ".";
need_cut_ = true;
}
auto k_prim = g_k_prims.KPrimitive(primitive_user, value_node, resources_);
if (k_prim != nullptr) {
return NewValueNode(k_prim);
}
// When failed to find k_prim, try k_meta.
auto k_meta = g_k_prims.KMetaFuncGraph(prim);
if (k_meta != nullptr) {
return NewValueNode(k_meta);
}
MS_LOG(EXCEPTION) << "Fail to map Primitive of \"" << primal->ToString() << "\" to K.";
}
// Construct representation graph for ValueNode of FuncGraph.
AnfNodePtr DFunctor::MapFuncGraphToK(const AnfNodePtr &primal) {
if (!IsValueNode<FuncGraph>(primal)) {
MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of FuncGraph.";
}
ScopeGuard scope_guard(primal->scope());
// Map func graph to K
auto func_graph = GetValueNode<FuncGraphPtr>(primal);
auto f = func_graph_to_functor_.find(func_graph);
if (f != func_graph_to_functor_.end()) {
MS_LOG(DEBUG) << "K graph functor already exist " << primal->ToString() << ".";
MS_LOG(DEBUG) << "K graph functor already exist " << func_graph->ToString() << ".";
return NewValueNode(f->second->k_graph_);
}
auto k_user_defined = KUserDefined(primal);
auto k_user_defined = KUserDefined(func_graph);
if (k_user_defined != nullptr) {
MS_LOG(DEBUG) << "K graph functor user defined bprop " << primal->ToString() << ".";
MS_LOG(DEBUG) << "K graph functor user defined bprop " << func_graph->ToString() << ".";
return NewValueNode(k_user_defined);
}
auto functor = std::make_shared<DFunctor>(primal, resources_);
auto functor = std::make_shared<DFunctor>(func_graph, resources_);
functor->Init();
functor->MapObject();
functor->MapMorphism();
MS_LOG(DEBUG) << "K graph K function graph " << primal->ToString() << " " << functor->k_graph_->ToString() << ".";
MS_LOG(DEBUG) << "Map \"" << func_graph->ToString() << "\" to \"" << functor->k_graph_->ToString() << "\"";
return NewValueNode(functor->k_graph_);
}
// Construct representation graph for primitive CNode.
AnfNodePtr DFunctor::MapToK(const CNodePtr &primal_user, size_t index) {
auto primal = primal_user->input(index);
ScopeGuard scope_guard(primal->scope());
// Map primitive to K
if (IsValueNode<Primitive>(primal)) {
auto value_node = primal->cast<ValueNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(value_node);
if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) {
MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << ".";
need_cut_ = true;
}
auto k_prim = g_k_prims.KPrimitive(primal_user, value_node, resources_);
if (k_prim != nullptr) {
return NewValueNode(k_prim);
}
// When failed to find k_prim, try k_meta.
auto k_meta = g_k_prims.KMetaFuncGraph(prim);
if (k_meta != nullptr) {
return NewValueNode(k_meta);
}
// Construct for ValueNode of Parameter.
AnfNodePtr DFunctor::MapParameterToK(const AnfNodePtr &primal) {
if (!primal->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of Parameter.";
}
return MapToK(primal);
}
// Construct representation graph for given node.
AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
ScopeGuard scope_guard(primal->scope());
// Map primitive to K
if (IsValueNode<Primitive>(primal)) {
auto value_node = primal->cast<ValueNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(value_node);
if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) {
MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << ".";
need_cut_ = true;
}
auto k_prim = g_k_prims.KPrimitive(nullptr, value_node, resources_);
if (k_prim != nullptr) {
return NewValueNode(k_prim);
}
// When failed to find k_prim, try k_meta.
auto k_meta = g_k_prims.KMetaFuncGraph(prim);
if (k_meta != nullptr) {
return NewValueNode(k_meta);
}
}
// Map func graph to K
if (IsValueNode<FuncGraph>(primal)) {
auto func_graph = GetValueNode<FuncGraphPtr>(primal);
auto k_func = MapToK(func_graph);
return k_func;
}
if (primal->isa<Parameter>()) {
TraceGuard trace_guard(std::make_shared<TraceGradFprop>(primal->debug_info()));
auto ret = k_graph_->add_parameter();
return ret;
}
if (!primal->isa<ValueNode>()) {
MS_LOG(EXCEPTION) << "K node keeped node from primal_graph_ " << primal->ToString() << " that is not a ValueNode.";
}
return primal;
// Map Parameter to K
TraceGuard trace_guard(std::make_shared<TraceGradFprop>(primal->debug_info()));
auto ret = k_graph_->add_parameter();
return ret;
}
bool DFunctor::IsInScope(const AnfNodePtr &node) {
@ -664,7 +633,7 @@ void DFunctor::MapParamObject() {
for (auto &p : primal_graph_->parameters()) {
ScopeGuard scope_guard(p->scope());
MS_LOG(DEBUG) << "MapParamObject parameter " << p->ToString() << ".";
auto adjoint = std::make_shared<Adjoint>(p, MapToK(p), tape_);
auto adjoint = std::make_shared<Adjoint>(p, MapParameterToK(p), tape_);
UpdateAdjoint(adjoint);
anfnode_to_adjoin_[p] = adjoint;
}
@ -682,12 +651,32 @@ void DFunctor::MapValueObject() {
anfnode_to_adjoin_[node] = adjoint;
continue;
}
// Skip Primitive.
if (IsValueNode<Primitive>(node) && GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn) {
continue;
AdjointPtr adjoint = nullptr;
if (IsValueNode<Primitive>(node)) { // Primitive.
if (GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn) {
continue;
}
MS_LOG(DEBUG) << "Map Primitive node " << node->DebugString() << ".";
auto &users = manager->node_users()[node];
if (users.size() == 0) {
MS_LOG(ERROR) << "\"" << node->DebugString() << "\" has no user.";
continue;
} else if (users.size() > 1) {
MS_LOG(DEBUG) << "\"" << node->DebugString() << "\" supposed to be used once, but users size: " << users.size();
}
auto cnode = users.begin()->first->cast<CNodePtr>(); // We just use the first user.
auto index = users.begin()->second;
adjoint = std::make_shared<Adjoint>(node, MapPrimitiveToK(cnode, index), tape_);
} else if (IsValueNode<FuncGraph>(node)) { // FuncGraph
MS_LOG(DEBUG) << "Map FuncGraph node " << node->DebugString() << ".";
adjoint = std::make_shared<Adjoint>(node, MapFuncGraphToK(node), tape_);
} else if (node->isa<Parameter>()) { // Parameter, hardly reach here.
MS_LOG(DEBUG) << "Map Parameter node " << node->DebugString() << ".";
adjoint = std::make_shared<Adjoint>(node, MapParameterToK(node), tape_);
} else {
adjoint = std::make_shared<Adjoint>(node, node, tape_);
}
MS_LOG(DEBUG) << "MapValueObject node " << node->ToString() << ".";
auto adjoint = std::make_shared<Adjoint>(node, MapToK(node), tape_);
UpdateAdjoint(adjoint);
anfnode_to_adjoin_[node] = adjoint;
}

@ -81,12 +81,12 @@ class DFunctor : public std::enable_shared_from_this<DFunctor> {
void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint);
AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv);
AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv);
// Map AnfNode object from D category to K category.
AnfNodePtr MapToK(const AnfNodePtr &primal);
// Map CNode object from D category to K category.
AnfNodePtr MapToK(const CNodePtr &primal_user, size_t index);
// Map FuncGraph object from D category to K category.
AnfNodePtr MapToK(const FuncGraphPtr &primal);
// Map CNode/Index of Primitive to K.
AnfNodePtr MapPrimitiveToK(const CNodePtr &primitive_user, size_t index);
// Map ValueNode of FuncGraph to K.
AnfNodePtr MapFuncGraphToK(const AnfNodePtr &primal);
// Map ValueNode of Parameter to K.
AnfNodePtr MapParameterToK(const AnfNodePtr &primal);
// MapObject impls.
void MapFvObject();
void MapValueObject();

Loading…
Cancel
Save