!10160 Opitimize pynative bprop second derivate

From: @zjun3021
Reviewed-by: @chujinjin,@zhoufeng54
Signed-off-by: @chujinjin
pull/10160/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 09dcefffaa

@ -46,7 +46,7 @@ FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
auto bprop_graph = std::make_shared<FuncGraph>();
std::vector<AnfNodePtr> outputs;
auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut", obj);
auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut", py::object());
fake_bprop->set_hook(bprop_func);
(void)fake_bprop->AddAttr(CUSTOM_BPROP_NAME, MakeValue(true));
outputs.push_back(NewValueNode(fake_bprop));

@ -621,13 +621,14 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) {
auto op_name = py::cast<std::string>(args[PY_NAME]);
op_exec_info->op_name = op_name;
if (grad_flag()) {
auto resource = GetResource();
MS_EXCEPTION_IF_NULL(resource);
MS_LOG(DEBUG) << "Get resource ptr " << resource.get();
int64_t graph_id = graph_id_;
auto it = resource->results().find(pipeline::kPynativeGraphId);
if (it != resource->results().end()) {
graph_id = it->second.cast<int64_t>();
auto resource = GetResource();
if (resource != nullptr) {
MS_LOG(DEBUG) << "Get resource ptr " << resource.get();
auto it = resource->results().find(pipeline::kPynativeGraphId);
if (it != resource->results().end()) {
graph_id = it->second.cast<int64_t>();
}
}
op_exec_info->op_index = std::to_string(graph_id) + op_name + std::to_string(op_index_map_[op_name]);
op_index_map_[op_name]++;
@ -686,7 +687,7 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
if (need_construct_graph()) {
AnfNodePtr input_node = nullptr;
if (!graph_info_map_.empty()) {
if (!graph_info_map_.empty() && !top_cell_list_.empty()) {
input_node = GetInput(obj, op_mask);
}
// update abstract
@ -1450,7 +1451,7 @@ bool PynativeExecutor::IsTopGraph(const std::string &cell_id) {
bool PynativeExecutor::IsBpropGraph(const std::string &cell_id) {
return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfo &value) {
return value.is_custom_bprop && cell_id.find(value.cell_id) != std::string::npos;
return !value.bprop_cell_id.empty() && cell_id.find(value.bprop_cell_id) != std::string::npos;
});
}
@ -1466,6 +1467,16 @@ bool PynativeExecutor::CheckCellGraph(const std::string &cell_id, bool is_grad)
});
}
void PynativeExecutor::ClearResidualRes() {
if (top_cell_list_.empty() && !graph_stack_.empty()) {
graph_id_ = 0;
graph_info_map_.clear();
cell_sw_map_.clear();
cell_graph_list_.clear();
top_cell_list_.clear();
}
}
FuncGraphPtr PynativeExecutor::GetDfbuilder(const std::string &cell_id) {
// Cell is empty, get nearest dfbuilder
if (cell_id.empty() && !top_cell_list_.empty()) {
@ -1490,6 +1501,10 @@ FuncGraphPtr PynativeExecutor::GetDfbuilder(const std::string &cell_id) {
return it.df_builder;
}
}
// Current cell is not top graph, get first top cell
if (!top_cell_list_.empty()) {
return top_cell_list_.front().df_builder;
}
return nullptr;
}
@ -1703,7 +1718,8 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
// init resource for constructing forward graph and grad graph
auto g = std::make_shared<FuncGraph>();
curr_g_ = g;
if (graph_stack_.empty()) {
ClearResidualRes();
if (graph_stack_.empty() && !IsBpropGraph(cell_id)) {
MakeNewTopGraph(cell_id, args, g);
}
PushCurrentGraphToStack();
@ -1724,16 +1740,6 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
dynamic_cell_ = IsDynamicCell(cell);
MS_LOG(DEBUG) << "cell id: " << cell_id << ", is dynamic cell: " << dynamic_cell_;
}
// Make bprop graph
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfo &value) {
return value.is_custom_bprop && cell_id.find(value.cell_id) != std::string::npos;
});
if (it != cell_graph_list_.end()) {
MS_LOG(INFO) << "Make bprop graph";
it->custom_bprop_graph = true;
}
}
}
void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g) {
@ -1807,7 +1813,7 @@ void PynativeExecutor::SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, con
}
void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) {
auto cell_id = GetCellId(cell, args);
const auto &cell_id = GetCellId(cell, args);
MS_LOG(DEBUG) << "EndGraphInner start " << args.size() << " " << cell_id;
if (!dynamic_cell_ && graph_stack_.empty() && CheckCellGraph(cell_id)) {
MS_LOG(INFO) << "Endgraph already compiled";
@ -1841,35 +1847,30 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string
AnfNodePtr output_node = GetObjNode(out, out_id);
curr_g_->set_output(output_node);
MS_LOG(DEBUG) << "Current graph " << curr_g_->output()->DebugString();
if (EndBpropGraph(cell_id)) {
return;
}
auto resource = GetResource(cell_id);
MS_EXCEPTION_IF_NULL(resource);
auto is_bprop_graph = IsBpropGraph(cell_id);
auto is_bprop_cell = py::hasattr(cell, parse::CUSTOM_BPROP_NAME);
if (!is_bprop_cell || !is_bprop_graph) {
resource->manager()->AddFuncGraph(curr_g_);
}
if (!is_bprop_cell) {
UpdateCellGraph(cell, curr_g_, cell_id, true, false);
}
resource->manager()->AddFuncGraph(curr_g_);
UpdateCellGraph(cell, curr_g_, cell_id, true, false);
auto newfg = MakeGradGraph(cell, curr_g_, resource, cell_id, args);
if (graph_stack_.size() > 1) {
if (!is_bprop_cell || !is_bprop_graph) {
std::vector<AnfNodePtr> inputs;
inputs.emplace_back(NewValueNode(curr_g_));
std::vector<AnfNodePtr> inputs;
inputs.emplace_back(NewValueNode(curr_g_));
PopGraphStack();
// connect the previous graph to the inside graph
auto graph_prev = graph_stack_.top();
for (size_t i = 0; i < args.size(); i++) {
auto input = GetInput(args[i], false);
inputs.emplace_back(input);
}
auto out_cnode = graph_prev->NewCNode(inputs);
SetPyObjInGraphInfoMap(graph_prev, GetCellId(cell, args));
SetTupleArgsToGraphInfoMap(graph_prev, out, out_cnode);
SetNodeMapInGraphInfoMap(graph_prev, GetId(out), out_cnode);
PopGraphStack();
// connect the previous graph to the inside graph
auto graph_prev = graph_stack_.top();
for (size_t i = 0; i < args.size(); i++) {
auto input = GetInput(args[i], false);
inputs.emplace_back(input);
}
auto out_cnode = graph_prev->NewCNode(inputs);
SetPyObjInGraphInfoMap(graph_prev, GetCellId(cell, args));
SetTupleArgsToGraphInfoMap(graph_prev, out, out_cnode);
SetNodeMapInGraphInfoMap(graph_prev, GetId(out), out_cnode);
} else {
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
DumpIR("before_resolve.ir", newfg);
@ -1883,6 +1884,17 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string
}
}
bool PynativeExecutor::EndBpropGraph(const string &cell_id) {
auto is_bprop_graph = IsBpropGraph(cell_id);
if (is_bprop_graph) {
if (IsNotNestedGrad()) {
PopGraphStack();
}
return true;
}
return false;
}
void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
bool need_cloned, bool is_grad) {
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
@ -1894,7 +1906,8 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt
it->fg = g;
MS_LOG(DEBUG) << "Update bprop bg";
} else {
auto cell_info = CellInfo(false, true, false, g, cell_id);
py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME);
auto cell_info = CellInfo(false, true, g, cell_id, GetId(bprop_func));
cell_graph_list_.insert(cell_graph_list_.begin(), cell_info);
}
return;
@ -1923,13 +1936,13 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt
return;
}
MS_LOG(DEBUG) << "Add new cell graph " << cell_id;
auto cell_info = CellInfo(false, true, false, tmp, cell_id);
auto cell_info = CellInfo(false, true, tmp, cell_id, "");
cell_graph_list_.insert(cell_graph_list_.begin(), cell_info);
}
FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const FuncGraphPtr &g, const ResourcePtr &r,
const string &cell_id, const py::args &args) {
bool is_custom_bprop = py::hasattr(cell, parse::CUSTOM_BPROP_NAME) && !IsBpropGraph(cell_id);
const std::string &cell_id, const py::args &args) {
bool is_custom_bprop = py::hasattr(cell, parse::CUSTOM_BPROP_NAME);
if (is_custom_bprop) {
size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size();
if (par_number > 0) {
@ -1943,17 +1956,15 @@ FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const FuncG
(void)bprop_graph->transforms().emplace(std::make_pair("primal", FuncGraphTransform(g)));
}
}
FuncGraphPtr newfg = nullptr;
if (!py::hasattr(cell, parse::CUSTOM_BPROP_NAME) || is_custom_bprop) {
// Obtain grad graph
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
DumpIR("fg.ir", g);
}
auto is_top = IsTopGraph(cell_id);
MS_LOG(DEBUG) << "Grad top cell " << is_top;
set_need_replace_forward(IsNotNestedGrad());
newfg = ad::Grad(g, r, is_top);
// Obtain grad graph
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
DumpIR("fg.ir", g);
}
auto is_top = IsTopGraph(cell_id);
MS_LOG(DEBUG) << "Grad top cell " << is_top;
set_need_replace_forward(IsNotNestedGrad());
auto newfg = ad::Grad(g, r, is_top);
if (is_custom_bprop) {
auto params = newfg->parameters();
auto manager = Manage({newfg}, false);
@ -2135,7 +2146,7 @@ void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args
SetNodeMapInGraphInfoMap(df_builder, it.first, it.second);
}
}
MS_LOG(DEBUG) << "Get wights params size " << weights_params.size();
MS_LOG(DEBUG) << "Get weights params size " << weights_params.size();
df_builder->set_parameters(weights_params);
resource->manager()->AddFuncGraph(forward_graph);
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
@ -2314,7 +2325,7 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &
MS_LOG(DEBUG) << "Grad not running yet";
return BaseRefToPyData(ret);
}
auto cell_id = GetCellId(cell, args);
const auto &cell_id = GetCellId(cell, args);
string key = cell_id.substr(0, std::min(PTR_LEN, cell_id.size()));
MS_LOG(DEBUG) << "Key is " << key;
for (auto it = cell_graph_list_.begin(); it != cell_graph_list_.end(); ++it) {
@ -2379,12 +2390,6 @@ bool PynativeExecutor::MakeBpropNestedCnode(const py::object &cell, const py::ob
MS_LOG(DEBUG) << "No nested bprop grad find";
return false;
}
auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfo &value) {
return value.custom_bprop_graph && value.is_custom_bprop && cell_id.find(value.cell_id) != std::string::npos;
});
if (it != cell_graph_list_.end()) {
MS_LOG(DEBUG) << "Make bprop graph end";
}
auto out_id = GetId(out);
std::vector<AnfNodePtr> inputs;
inputs.emplace_back(NewValueNode(curr_g_));
@ -2489,6 +2494,7 @@ void PynativeExecutor::Clean() {
void PynativeExecutor::ClearRes() {
MS_LOG(DEBUG) << "Clear all res";
Clean();
graph_id_ = 0;
grad_order_ = 0;
grad_flag_ = false;
dynamic_cell_ = false;

@ -68,18 +68,18 @@ struct GraphInfo {
};
struct CellInfo {
bool is_grad{false}; // Derivative is calculated
bool is_custom_bprop{false}; // Custom bprop
bool custom_bprop_graph{false}; // Custom bprop make forward graph
FuncGraphPtr fg; // Forward graph
bool is_grad{false}; // Derivative is calculated
bool is_custom_bprop{false}; // Custom bprop
FuncGraphPtr fg; // Forward graph
std::string cell_id;
std::string bprop_cell_id;
CellInfo() = default;
CellInfo(bool isgrad, bool custom_bprop, bool bprop_graph, FuncGraphPtr foward_graph, std::string cellid)
CellInfo(bool isgrad, bool custom_bprop, FuncGraphPtr foward_graph, std::string cellid, std::string bprop_id)
: is_grad(isgrad),
is_custom_bprop(custom_bprop),
custom_bprop_graph(bprop_graph),
fg(std::move(foward_graph)),
cell_id(std::move(cellid)) {}
cell_id(std::move(cellid)),
bprop_cell_id(std::move(bprop_id)) {}
};
struct TopCellInfo {
@ -187,13 +187,15 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
bool CheckCellGraph(const std::string &cell_id, bool is_grad = false);
void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
bool need_cloned = false, bool is_grad = false);
void ClearResidualRes();
void NewGraphInner(const py::object &cell, const py::args &args);
void MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g);
void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args);
void EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out,
const std::string &out_id, const py::args &args);
FuncGraphPtr MakeGradGraph(const py::object &cell, const FuncGraphPtr &g, const ResourcePtr &r, const string &cell_id,
const py::args &args);
bool EndBpropGraph(const string &cell_id);
FuncGraphPtr MakeGradGraph(const py::object &cell, const FuncGraphPtr &g, const ResourcePtr &r,
const std::string &cell_id, const py::args &args);
std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args, py::object *forward_args,
py::object *sens = nullptr);
void GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,

@ -182,21 +182,16 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
auto inst = pynative::PynativeExecutor::GetInstance();
MS_EXCEPTION_IF_NULL(inst);
try {
inst->NewGraph(GetPyObj(), input_args.cast<py::args>());
MS_LOG(DEBUG) << "Run bprop function start";
inst->NewGraph(hook_, input_args.cast<py::args>());
py::object grads_obj = hook_(*convert_args);
py::tuple grads = check_bprop_out(grads_obj, py_args);
inst->EndGraph(GetPyObj(), grads_obj, input_args.cast<py::args>());
inst->EndGraph(hook_, grads_obj, input_args.cast<py::args>());
MS_LOG(DEBUG) << "Run bprop function end";
return std::make_shared<PyObjectRef>(grads);
} catch (const py::type_error &ex) {
} catch (std::exception &bt) {
inst->ClearRes();
throw py::type_error(ex);
} catch (const py::value_error &ex) {
inst->ClearRes();
throw py::value_error(ex);
} catch (...) {
inst->ClearRes();
std::string exName(abi::__cxa_current_exception_type()->name());
MS_LOG(EXCEPTION) << "Error occurred in run bprop. Exception name: " << exName;
std::rethrow_exception(std::current_exception());
}
}
SyncData(py_args[2]);

Loading…
Cancel
Save