!4572 resolve output twice out of memory issue

Merge pull request !4572 from wangqiuliang/resolve-output-twice-out-of-memory
pull/4572/MERGE
Committed by mindspore-ci-bot via Gitee, 5 years ago
commit 8e80f80f1c

@@ -25,6 +25,7 @@
 #include "ir/func_graph_cloner.h"
 #include "ir/manager.h"
 #include "pipeline/jit/resource.h"
+#include "pipeline/pynative/pynative_execute.h"
 #include "frontend/optimizer/ad/adjoint.h"
 #include "frontend/operator/ops.h"
 #include "utils/symbolic.h"
@@ -218,7 +219,8 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
   TraceManager::DebugTrace(std::make_shared<TraceGradFpropApp>(cnode_morph->debug_info()));
   auto k_app = k_graph_->NewCNode(inputs);
   TraceManager::EndTrace();
-  ReplaceEquivdout(k_app, cnode_morph->forward());
+  ReplaceEquivdout(k_app, cnode_morph);
+  cnode_morph->set_forward(nullptr, "");
   for (size_t i = 0; i < param_adjoints.size(); ++i) {
     param_adjoints[i]->RegisterKUser(k_app, i);
   }
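This hunk passes the whole morph CNode into ReplaceEquivdout and then resets the node's cached forward value with set_forward(nullptr, ""), so the CNode no longer pins the forward output after the grad transform has consumed it. A minimal standalone sketch of that release pattern, using std::shared_ptr and a toy Node type in place of MindSpore's ValuePtr/CNode (names here are illustrative, not the real API):

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy stand-ins: ValuePtr -> shared_ptr to a float buffer, CNode -> Node.
using ValuePtr = std::shared_ptr<std::vector<float>>;

struct Node {
  // Mirrors the (value, id) pair a CNode keeps in PyNative mode.
  std::pair<ValuePtr, std::string> forward{nullptr, ""};
  void set_forward(const ValuePtr &v, const std::string &id) { forward = {v, id}; }
};

int main() {
  auto out = std::make_shared<std::vector<float>>(1 << 20, 1.0f);  // forward output buffer
  Node morph;
  morph.set_forward(out, "op0");
  std::cout << "refs while cached on the node: " << out.use_count() << "\n";  // 2

  // After the value has been consumed (ReplaceEquivdout in the real code),
  // resetting the pair drops the node's reference, so the buffer is freed as
  // soon as the last remaining user lets go of it.
  morph.set_forward(nullptr, "");
  std::cout << "refs after set_forward(nullptr, \"\"): " << out.use_count() << "\n";  // 1
}
```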
@@ -240,7 +242,9 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
   return node_adjoint;
 }

-void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const ValuePtr &forward) {
+void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_morph) {
+  auto forward = cnode_morph->forward().first;
+  auto forward_id = cnode_morph->forward().second;
   if (forward == nullptr) {
     return;
   }
@@ -265,10 +269,44 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const ValuePtr &forward)
   auto equivdout = cnode_input->cast<CNodePtr>();
   auto func_graph = GetValueNode<FuncGraphPtr>(input_fg);
   auto manager = Manage({fg, func_graph}, false);
+  auto ref_size = manager->node_users()[equivdout].size();
+  auto forward_value = forward;
+  if (!forward_id.empty() && ref_size > 1) {
+    auto inst = pynative::PynativeExecutor::GetInstance();
+    inst->SaveOpForwardValue(forward_id, forward_value);
+  }
+  if (ref_size < 2) {
+    auto tensor = forward->cast<tensor::TensorPtr>();
+    if (tensor != nullptr) {
+      auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape());
+      forward_value = new_tensor;
+    }
+  }
   MS_LOG(DEBUG) << "Replace: " << equivdout->ToString() << " with " << forward;
-  auto value_node = NewValueNode(forward);
+  auto value_node = NewValueNode(forward_value);
   value_node->set_has_new_value(true);
   manager->Replace(equivdout, value_node);
+  auto paras = fg->parameters();
+  auto inputs_value = cnode_morph->inputs_value();
+  if (inputs_value.size() == 0) {
+    return;
+  }
+  if (inputs_value.size() != paras.size()) {
+    MS_LOG(EXCEPTION) << "Parameter size:" << paras.size() << " is not equal to inputs size:" << inputs_value.size();
+  }
+  for (size_t i = 0; i < paras.size(); i++) {
+    auto para_ref_size = manager->node_users()[paras[i]].size();
+    auto input_value = inputs_value[i];
+    if (para_ref_size > 0 && input_value.first != nullptr) {
+      MS_LOG(DEBUG) << "Replace: " << paras[i]->ToString() << " with " << input_value.first;
+      auto inst = pynative::PynativeExecutor::GetInstance();
+      inst->SaveOpForwardValue(input_value.second, input_value.first);
+      auto input_value_node = NewValueNode(input_value.first);
+      manager->Replace(paras[i], input_value_node);
+    }
+  }
+  cnode_morph->clear_inputs_value();
+  return;
 }

 bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
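The rewritten ReplaceEquivdout decides what to keep based on how many users the equivdout node has: if the value is still referenced elsewhere it is registered with the PyNative executor under its forward id, and if it has at most one user the recorded tensor is swapped for one constructed only from dtype and shape, so the original payload does not have to stay alive. A standalone sketch of that decision under simplified types (the Tensor struct and MaybeStripPayload below are illustrative stand-ins for tensor::Tensor and the in-place logic above):

```cpp
#include <cstddef>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Toy model of the ref_size < 2 branch above: if the grad graph is the only
// user of the recorded forward output, the real buffer is not needed again, so
// a shape/dtype-only tensor is substituted and the payload can be released.
struct Tensor {
  std::string dtype;
  std::vector<int> shape;
  std::shared_ptr<std::vector<float>> data;  // nullptr means "metadata only"
};

std::shared_ptr<Tensor> MaybeStripPayload(const std::shared_ptr<Tensor> &forward, std::size_t ref_size) {
  if (ref_size >= 2) {
    return forward;  // other users still read the value: keep it (and cache it by id)
  }
  // Keep only metadata; the payload is freed once all holders drop `forward`.
  return std::make_shared<Tensor>(Tensor{forward->dtype, forward->shape, nullptr});
}

int main() {
  auto fwd = std::make_shared<Tensor>(
      Tensor{"float32", {1024, 1024}, std::make_shared<std::vector<float>>(1024 * 1024)});
  auto replacement = MaybeStripPayload(fwd, /*ref_size=*/1);
  std::cout << "payload kept: " << std::boolalpha << (replacement->data != nullptr) << "\n";  // false
}
```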

@@ -95,7 +95,7 @@ class DFunctor : public std::enable_shared_from_this<DFunctor> {
   // Update k hole with adjoint_definition, only applied in recursive case.
   void UpdateAdjoint(const AdjointPtr &adjoint_definition);
   void CallDoutHoleOnTape();
-  void ReplaceEquivdout(const CNodePtr &cnode, const ValuePtr &forward);
+  void ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_morph);
   std::unordered_map<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_;
   // Cache for indirect fv backpropagation, K o K can only do backprop layer by layer.

@@ -724,18 +724,14 @@ void PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::ob
   set_pyobj(curr_g_, obj_id);
 }

-void PynativeExecutor::SaveOpForwardValue(const OpExecInfoPtr &op_exec_info, const ValuePtr &value) {
-  auto id = GetOpId(op_exec_info);
-  int graph_id = resource_->results()[pipeline::kPynativeGraphId].cast<int>();
-  auto op = std::to_string(graph_id) + id;
-  op.append(std::to_string(op_id_map_[id]));
-  auto iter = op_forward_map_.find(op);
+void PynativeExecutor::SaveOpForwardValue(const std::string &id, const ValuePtr &value) {
+  auto iter = op_forward_map_.find(id);
   if (iter != op_forward_map_.end()) {
     return;
   }
-  op_forward_map_[op] = value;
-  ++op_id_map_[id];
-  MS_LOG(DEBUG) << "Save: " << op_exec_info->op_name << "(" << op << "), " << value;
+  op_forward_map_[id] = value;
+  MS_LOG(DEBUG) << "Save op forward value: "
+                << "(" << id << "), " << value;
 }

 void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out) {
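SaveOpForwardValue no longer derives the key from OpExecInfo itself; the caller passes a fully composed id string, and the method is a plain insert-if-absent into op_forward_map_. A small self-contained model of that cache behaviour (ForwardCache and the id string are illustrative, not the real class):

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Model of the simplified cache: the first value recorded for an id wins and
// repeated saves under the same id are no-ops.
using ValuePtr = std::shared_ptr<float>;

class ForwardCache {
 public:
  void SaveOpForwardValue(const std::string &id, const ValuePtr &value) {
    if (op_forward_map_.find(id) != op_forward_map_.end()) {
      return;  // already recorded under this id
    }
    op_forward_map_[id] = value;
  }
  ValuePtr Get(const std::string &id) const {
    auto it = op_forward_map_.find(id);
    return it == op_forward_map_.end() ? nullptr : it->second;
  }

 private:
  std::unordered_map<std::string, ValuePtr> op_forward_map_;
};

int main() {
  ForwardCache cache;
  cache.SaveOpForwardValue("0Mul0", std::make_shared<float>(3.0f));
  cache.SaveOpForwardValue("0Mul0", std::make_shared<float>(7.0f));  // ignored, id already present
  std::cout << *cache.Get("0Mul0") << "\n";  // prints 3
}
```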
@@ -748,9 +744,25 @@ void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CN
   }
   auto value = PyAttrValue(out_real);
   if (cnode != nullptr) {
-    cnode->set_forward(value);
+    size_t size = op_exec_info->op_inputs.size();
+    for (size_t i = 0; i < size; i++) {
+      auto obj = op_exec_info->op_inputs[i];
+      auto obj_id = GetId(obj);
+      if (obj_to_forward_id_.find(obj_id) != obj_to_forward_id_.end()) {
+        cnode->add_input_value(PyAttrValue(obj), obj_to_forward_id_[obj_id]);
+      } else {
+        cnode->add_input_value(nullptr, "");
+      }
+    }
+    std::string id = GetOpId(op_exec_info);
+    int graph_id = resource_->results()[pipeline::kPynativeGraphId].cast<int>();
+    auto op_id = std::to_string(graph_id) + id;
+    op_id.append(std::to_string(op_id_map_[id]));
+    cnode->set_forward(value, op_id);
+    ++op_id_map_[id];
+    auto out_id = GetId(out_real);
+    obj_to_forward_id_[out_id] = op_id;
   }
-  SaveOpForwardValue(op_exec_info, value);
 }

 AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
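SaveAllResult now owns the id bookkeeping that used to live in SaveOpForwardValue: it attaches each input's value together with the forward id it was produced under, composes the op id from the graph id, the op id, and a per-op counter, and remembers which Python output object maps to that id so later ops can find it. A simplified model of the id composition (the literal "Mul" and "out#1" are placeholders for what GetOpId/GetId would return):

```cpp
#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_map>

// Simplified model of the forward-id bookkeeping in SaveAllResult.
int main() {
  std::unordered_map<std::string, std::size_t> op_id_map;          // per-op call counter
  std::unordered_map<std::string, std::string> obj_to_forward_id;  // output object -> forward id

  int graph_id = 0;
  std::string op = "Mul";         // stand-in for GetOpId(op_exec_info)
  std::string out_obj = "out#1";  // stand-in for GetId(out_real)

  // forward id = graph id + op id + how many times this op has been seen.
  std::string op_id = std::to_string(graph_id) + op + std::to_string(op_id_map[op]);
  ++op_id_map[op];
  obj_to_forward_id[out_obj] = op_id;  // later ops can attach this id to their inputs

  std::cout << op_id << "\n";  // "0Mul0"
}
```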
@@ -775,7 +787,7 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
     node_abs_map_[id] = node->abstract();
   }
   MS_LOG(DEBUG) << "GetObjNode output" << node->DebugString(6);
-  node->cast<CNodePtr>()->set_forward(PyAttrValue(obj));
+  node->cast<CNodePtr>()->set_forward(PyAttrValue(obj), "");
   return node;
 }
@@ -1131,6 +1143,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
     }
   }
   auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_);
+  graph_info_map_.erase(curr_g_);
   if (curr_g_ != top_g_) {
     Popp();
     for (size_t i = 0; i < args.size(); i++) {
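Erasing the graph_info_map_ entry for curr_g_ right after ad::Grad releases the per-graph bookkeeping immediately instead of waiting for the final Clear(). A tiny sketch of the pattern with simplified types (GraphInfo and the map shape here are illustrative):

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Per-graph bookkeeping pins node/value references; dropping the entry as soon
// as the gradient graph exists lets that memory go early.
struct GraphInfo {
  std::unordered_map<std::string, int> node_map;  // obj id -> node slot
};

int main() {
  using GraphPtr = std::shared_ptr<std::string>;
  std::unordered_map<GraphPtr, GraphInfo> graph_info_map;

  auto curr_g = std::make_shared<std::string>("curr_g");
  graph_info_map[curr_g] = GraphInfo{{{"x", 0}, {"y", 1}}};

  // ... ad::Grad(curr_g, ...) would run here in the real code ...
  graph_info_map.erase(curr_g);  // free the cached info as soon as it is no longer needed
  std::cout << "entries left: " << graph_info_map.size() << "\n";  // 0
}
```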
@@ -1300,6 +1313,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
   curr_g_ = nullptr;
   graph_info_map_.clear();
   op_id_map_.clear();
+  obj_to_forward_id_.clear();
   std::stack<FuncGraphPtr>().swap(graph_p_);
   ConfigManager::GetInstance().ResetIterNum();
 }

@@ -108,7 +108,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
                 abstract::AbstractBasePtrList *args_spec_list);
   void MakeCNode(const OpExecInfoPtr &op_exec_info, const py::object &out, const AnfNodePtr &cnode);
   ValuePtr GetForwardValue(const OpExecInfoPtr &op_exec_info);
-  void SaveOpForwardValue(const OpExecInfoPtr &op_exec_info, const ValuePtr &value);
+  void SaveOpForwardValue(const std::string &id, const ValuePtr &value);
   void SaveForwardResult(const CNodePtr &cnode, const py::object &out);
   void SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out);
@@ -138,6 +138,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
   std::unordered_map<std::string, ValuePtr> op_forward_map_;
  std::unordered_map<std::string, size_t> op_id_map_;
+  std::unordered_map<std::string, std::string> obj_to_forward_id_;
   std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
   std::stack<FuncGraphPtr> graph_p_;
   FuncGraphPtr top_g_;

@@ -31,7 +31,7 @@
 namespace mindspore {
 // namespace to support intermediate representation definition
 CNode::CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph)
-    : AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {}
+    : AnfNode(func_graph), inputs_(inputs), stop_gradient_(false), output_value_(std::make_pair(nullptr, "")) {}

 // Check if CNode is an apply with the specific Primitive.
 bool CNode::IsApply(const PrimitivePtr &value) const {

@@ -232,8 +232,15 @@ class CNode : public AnfNode {
   void set_input(size_t i, const AnfNodePtr &input);
   void set_inputs(const std::vector<AnfNodePtr> &inputs) { inputs_ = inputs; }
-  void set_forward(const ValuePtr &forward) { forward_ = forward; }
-  const ValuePtr &forward() const { return forward_; }
+  void add_input_value(const ValuePtr &input_value, const std::string &id) {
+    inputs_value_.push_back(std::make_pair(input_value, id));
+  }
+  void clear_inputs_value() { inputs_value_.clear(); }
+  void set_inputs_value(const std::vector<std::pair<ValuePtr, std::string>> &values) { inputs_value_ = values; }
+  const std::vector<std::pair<ValuePtr, std::string>> &inputs_value() const { return inputs_value_; }
+  void set_forward(const ValuePtr &forward, const std::string &id) { output_value_ = std::make_pair(forward, id); }
+  const std::pair<ValuePtr, std::string> &forward() const { return output_value_; }

   bool stop_gradient() const { return stop_gradient_; }
   void set_stop_gradient(bool stop_gradient) { stop_gradient_ = stop_gradient; }
@@ -253,7 +260,10 @@ class CNode : public AnfNode {
   VarPtr func_graph_as_var_;
   bool stop_gradient_;
   bool in_forward_flag_ = false;
-  ValuePtr forward_ = nullptr;
+  // inputs_value_ store cnode input value and id in pynative mode
+  // output_value_ store cnode value and id in pynative mode
+  std::vector<std::pair<ValuePtr, std::string>> inputs_value_;
+  std::pair<ValuePtr, std::string> output_value_;
 };

 // ANode represents the atomic node. It's derived Parameter and ValueNode.
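On the IR side, the single forward_ ValuePtr on CNode becomes (value, id) pairs: one for the node's output and one per input, which is what lets the grad pass look values up by forward id and clear them selectively. A self-contained sketch of that storage layout (CNodeSketch and the ValuePtr alias are simplified stand-ins, not the real CNode):

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for the PyNative bookkeeping now carried by CNode.
using ValuePtr = std::shared_ptr<float>;

class CNodeSketch {
 public:
  void set_forward(const ValuePtr &v, const std::string &id) { output_value_ = {v, id}; }
  const std::pair<ValuePtr, std::string> &forward() const { return output_value_; }

  void add_input_value(const ValuePtr &v, const std::string &id) { inputs_value_.emplace_back(v, id); }
  const std::vector<std::pair<ValuePtr, std::string>> &inputs_value() const { return inputs_value_; }
  void clear_inputs_value() { inputs_value_.clear(); }

 private:
  std::vector<std::pair<ValuePtr, std::string>> inputs_value_;  // per-input value and forward id
  std::pair<ValuePtr, std::string> output_value_{nullptr, ""};  // output value and forward id
};

int main() {
  CNodeSketch node;
  node.add_input_value(std::make_shared<float>(2.0f), "0Mul0");
  node.add_input_value(nullptr, "");  // input with no recorded forward id
  node.set_forward(std::make_shared<float>(6.0f), "0Mul1");

  std::cout << "inputs recorded: " << node.inputs_value().size() << "\n";  // 2
  node.clear_inputs_value();  // what the grad pass does once the values have been replayed
  std::cout << "after clear: " << node.inputs_value().size() << "\n";      // 0
}
```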

@@ -88,7 +88,8 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
   CNodePtr new_node = std::make_shared<CNode>(AnfNodePtrList{}, target);
   auto old_node = node->cast<CNodePtr>();
   new_node->set_abstract(old_node->abstract());
-  new_node->set_forward(old_node->forward());
+  new_node->set_forward(old_node->forward().first, old_node->forward().second);
+  new_node->set_inputs_value(old_node->inputs_value());
   ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
   new_node->set_scope(scope);
   new_node->set_kernel_info(old_node->kernel_info_ptr());
