!10330 Fix pynative parameters second derivative

From: @zjun3021
Reviewed-by: @kisnwang,@chujinjin
Signed-off-by: @chujinjin
pull/10330/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 695cdbbe69

@ -1439,6 +1439,12 @@ std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &
return cell_id;
}
// Dump `graph` through DumpIR under the name `filename`, but only when the
// MS_CTX_SAVE_GRAPHS_FLAG context parameter is set; otherwise do nothing.
void PynativeExecutor::DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph) {
  const bool save_graphs = MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  if (!save_graphs) {
    return;
  }
  DumpIR(filename, graph);
}
bool PynativeExecutor::IsNotNestedGrad() const {
MS_LOG(DEBUG) << "Grad nested count is " << grad_order_;
return grad_order_ <= 1;
@ -1851,6 +1857,7 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string
curr_g_->set_output(output_node);
MS_LOG(DEBUG) << "Current graph " << curr_g_->output()->DebugString();
if (EndBpropGraph(cell_id)) {
MS_LOG(DEBUG) << "Get bprop function cell";
return;
}
auto resource = GetResource(cell_id);
@ -1875,13 +1882,9 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string
SetTupleArgsToGraphInfoMap(graph_prev, out, out_cnode);
SetNodeMapInGraphInfoMap(graph_prev, GetId(out), out_cnode);
} else {
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
DumpIR("before_resolve.ir", newfg);
}
DumpGraphIR("before_resolve.ir", newfg);
parse::ResolveFuncGraph(newfg, resource);
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
DumpIR("after_resolve.ir", newfg);
}
DumpGraphIR("after_resolve.ir", newfg);
resource->set_func_graph(newfg);
PopGraphStack();
}
@ -1907,10 +1910,12 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt
if (it != cell_graph_list_.end()) {
it->is_grad = is_grad;
it->fg = g;
MS_LOG(DEBUG) << "Update bprop bg";
MS_LOG(DEBUG) << "Update bprop bg cell id " << cell_id;
} else {
py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME);
auto cell_info = CellInfo(false, true, g, cell_id, GetId(bprop_func));
auto bprop_func_cell_id = GetId(bprop_func);
MS_LOG(DEBUG) << "Add new bprop cell_id " << cell_id << " bprop func cell id " << bprop_func_cell_id;
auto cell_info = CellInfo(false, true, g, cell_id, bprop_func_cell_id);
cell_graph_list_.insert(cell_graph_list_.begin(), cell_info);
}
return;
@ -1959,13 +1964,11 @@ FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const FuncG
(void)bprop_graph->transforms().emplace(std::make_pair("primal", FuncGraphTransform(g)));
}
}
// Obtain grad graph
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
DumpIR("fg.ir", g);
}
DumpGraphIR("fg.ir", g);
auto is_top = IsTopGraph(cell_id);
MS_LOG(DEBUG) << "Grad top cell " << is_top;
set_need_replace_forward(IsNotNestedGrad());
// Obtain grad graph
auto newfg = ad::Grad(g, r, is_top);
if (is_custom_bprop) {
@ -2039,11 +2042,9 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
auto args_spec = GetArgsSpec(args, df_builder);
resource->set_args_spec(args_spec);
// Get real grad graph
DumpGraphIR("before_grad.ir", resource->func_graph());
GradGraph(resource->func_graph(), grad, w_args, size, cell_id);
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
DumpIR("before_grad.ir", resource->func_graph());
DumpIR("after_grad.ir", df_builder);
}
DumpGraphIR("after_grad.ir", df_builder);
resource->set_func_graph(df_builder);
resource->manager()->KeepRoots({df_builder});
resource->results()[pipeline::kBackend] = compile::CreateBackend();
@ -2127,30 +2128,35 @@ void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args
}
MS_EXCEPTION_IF_NULL(forward_graph);
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
DumpIR("nested_bprop.ir", forward_graph);
}
DumpGraphIR("nested_bprop.ir", forward_graph);
// Custom bprop get backward graph(before opt), which use like other forward graph
curr_g_ = forward_graph;
resource->set_func_graph(forward_graph);
return;
}
// Copy weights
std::vector<AnfNodePtr> weights_params{};
// Copy weights parameters
resource->manager()->AddFuncGraph(forward_graph);
auto manager = Manage({forward_graph}, false);
for (const auto &it : graph_info_map_.at(forward_graph).params) {
if (it.second->has_default()) {
weights_params.emplace_back(it.second);
graph_info_map_.at(df_builder).params.emplace(it.first, it.second);
SetNodeMapInGraphInfoMap(df_builder, it.first, it.second);
if (!it.second->has_default()) {
continue;
}
}
MS_LOG(DEBUG) << "Get weights params size " << weights_params.size();
df_builder->set_parameters(weights_params);
resource->manager()->AddFuncGraph(forward_graph);
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
DumpIR("nested_fg.ir", forward_graph);
}
auto new_param = df_builder->add_parameter();
new_param->set_abstract(it.second->abstract());
new_param->set_name(it.second->name());
new_param->set_default_param(it.second->default_param());
ScopePtr scope = (it.second->scope() != kDefaultScope) ? it.second->scope() : kDefaultScope;
new_param->set_scope(scope);
manager->Replace(it.second, new_param);
replace_weights_map_[forward_graph].emplace_back(std::make_pair(it.second, new_param));
MS_LOG(DEBUG) << "Old param ptr " << it.second.get() << " name " << it.second->name();
graph_info_map_.at(df_builder).params[it.first] = new_param;
SetParamNodeMapInGraphInfoMap(df_builder, it.first, new_param);
SetNodeMapInGraphInfoMap(df_builder, it.first, new_param);
}
DumpGraphIR("nested_fg.ir", forward_graph);
set_need_replace_forward(false);
auto newfg = MakeGradGraph(cell, forward_graph, resource, cell_id, args);
resource->set_func_graph(newfg);
@ -2396,15 +2402,18 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg
MS_LOG(DEBUG) << "Get pre graph ptr " << graph_prev.get();
auto newfg = resource->func_graph();
MS_EXCEPTION_IF_NULL(newfg);
auto size = args.size();
auto inputs_size = args.size();
if (has_sens) {
size -= 1;
inputs_size -= 1;
}
std::vector<AnfNodePtr> inputs;
inputs.emplace_back(NewValueNode(newfg));
for (size_t i = 0; i < size; ++i) {
for (size_t i = 0; i < inputs_size; ++i) {
inputs.emplace_back(GetInput(args[i], false));
}
if (newfg->parameters().size() > inputs_size) {
SetNestedWeigthsParam(newfg, cell_id, &inputs);
}
auto out_id = GetId(out);
auto cnode = graph_prev->NewCNode(inputs);
SetTupleArgsToGraphInfoMap(graph_prev, out, cnode);
@ -2412,6 +2421,38 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg
MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString(4);
}
// Substitute the default-valued (weight) parameters of `newfg` with the parameters cached in
// replace_weights_map_ and append those cached parameters to the call inputs.
//
// The graph recorded for `cell_id` in cell_graph_list_ is used as the key into replace_weights_map_.
// For every parameter of `newfg` that has a default value, the cached pairs are searched for one whose
// second element has the same name; on the first match the parameter is replaced (via the graph manager)
// by the pair's first element, and that first element is appended to `inputs`. After all parameters are
// processed, the cache entry for the graph is erased.
//
// newfg:   graph whose default-valued parameters are substituted; must not be null.
// cell_id: id used to locate the recorded graph in cell_graph_list_.
// inputs:  output vector receiving the substituted parameters; must not be null.
//
// Raises an exception when `newfg`, `inputs`, the recorded graph or a parameter node is null, or when a
// default-valued parameter is found but replace_weights_map_ holds no entry for the recorded graph.
void PynativeExecutor::SetNestedWeigthsParam(const FuncGraphPtr &newfg, const std::string &cell_id,
                                             std::vector<AnfNodePtr> *inputs) {
  MS_EXCEPTION_IF_NULL(newfg);
  MS_EXCEPTION_IF_NULL(inputs);
  FuncGraphPtr forward_graph = nullptr;
  auto ic = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),
                         [&cell_id](const CellInfo &value) { return value.cell_id == cell_id; });
  if (ic != cell_graph_list_.end()) {
    forward_graph = ic->fg;
  }
  MS_EXCEPTION_IF_NULL(forward_graph);
  // Work on a snapshot of the parameter list taken before any replacement happens.
  // NOTE(review): presumably this keeps Replace() from affecting the sequence being iterated - confirm.
  auto params = newfg->parameters();
  auto manage = Manage({newfg}, false);
  // The cache lookup does not depend on the parameter, so do it once. Nothing inside the loop modifies
  // replace_weights_map_, hence the iterator stays valid until the erase below.
  const auto ir = replace_weights_map_.find(forward_graph);
  for (const auto &it : params) {
    MS_EXCEPTION_IF_NULL(it);
    auto param = it->cast<ParameterPtr>();
    // cast<> yields null when the node is not a Parameter; fail loudly instead of dereferencing null.
    MS_EXCEPTION_IF_NULL(param);
    if (!param->has_default()) {
      continue;
    }
    // Checked lazily: a missing cache entry is only an error once a weight parameter actually needs it.
    if (ir == replace_weights_map_.end()) {
      MS_LOG(EXCEPTION) << "Not find forward_graph in repalce weigths map";
    }
    for (const auto &ip : ir->second) {
      MS_LOG(DEBUG) << "Get param name " << param->name() << " cache name " << ip.second->name();
      if (ip.second->name() == param->name()) {
        manage->Replace(param, ip.first);
        inputs->emplace_back(ip.first);
        break;
      }
    }
  }
  replace_weights_map_.erase(forward_graph);
}
void PynativeExecutor::Clear(const std::string &cell_id) {
if (cell_id.empty()) {
Clean();
@ -2461,6 +2502,7 @@ void PynativeExecutor::ClearRes() {
graph_info_map_.clear();
cell_sw_map_.clear();
replace_weights_map_.clear();
cell_graph_list_.clear();
top_cell_list_.clear();
op_index_map_.clear();

@ -60,7 +60,7 @@ void ClearPyNativeSession();
struct GraphInfo {
std::string cell_id;
AnfNodePtr output;
std::unordered_map<std::string, ParameterPtr> params; // hold input parameters and cell weigths
OrderedMap<std::string, ParameterPtr> params; // hold input parameters and cell weigths
std::unordered_map<std::string, std::pair<AnfNodePtr, std::vector<int64_t>>> node_map;
std::vector<std::string> objects;
GraphInfo() = default;
@ -210,6 +210,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
bool need_cloned = false, bool is_grad = false);
void ClearResidualRes(const std::string &cell_id);
void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph);
void NewGraphInner(const py::object &cell, const py::args &args);
void MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g);
void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args);
@ -233,6 +234,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id);
void MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource,
const py::object &out, bool has_sens);
void SetNestedWeigthsParam(const FuncGraphPtr &newfg, const std::string &cell_id, std::vector<AnfNodePtr> *inputs);
bool MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id);
// Hold graph(forward and grad) info
@ -242,7 +244,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
bool is_param = false);
// Record `param` under `id` in the parameter map of graph `g`, creating the graph's entry in
// graph_info_map_ if it does not exist yet and overwriting any parameter already stored for `id`.
// A preceding emplace would be redundant: the assignment alone inserts-or-overwrites with one lookup.
void SetParamNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const ParameterPtr &param) {
  graph_info_map_[g].params[id] = param;
}
void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node,
int64_t index = -1) {
@ -269,15 +271,16 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
// Records forward graphs; the bottom of the stack is the top graph
std::stack<FuncGraphPtr> graph_stack_;
// Use vector for keep order
std::vector<CellInfo> cell_graph_list_;
std::vector<TopCellInfo> top_cell_list_;
std::unordered_set<std::string> cell_input_args_;
std::unordered_map<std::string, bool> cell_dynamic_map_;
// Record all info for all cells
std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
// Use vector for keep order
std::vector<CellInfo> cell_graph_list_;
std::vector<TopCellInfo> top_cell_list_;
// key: cell_id, value: (sens_id, weights_id), cache for sens and weight change
std::unordered_map<std::string, std::pair<std::string, std::string>> cell_sw_map_;
std::unordered_map<FuncGraphPtr, std::vector<std::pair<ParameterPtr, ParameterPtr>>> replace_weights_map_;
// Used for runop and replace forward result of grad graph
std::unordered_map<std::string, size_t> op_index_map_;

Loading…
Cancel
Save