|
|
|
@ -1439,6 +1439,12 @@ std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &
|
|
|
|
|
return cell_id;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph) {
|
|
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
|
|
|
|
DumpIR(filename, graph);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool PynativeExecutor::IsNotNestedGrad() const {
|
|
|
|
|
MS_LOG(DEBUG) << "Grad nested count is " << grad_order_;
|
|
|
|
|
return grad_order_ <= 1;
|
|
|
|
@ -1851,6 +1857,7 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string
|
|
|
|
|
curr_g_->set_output(output_node);
|
|
|
|
|
MS_LOG(DEBUG) << "Current graph " << curr_g_->output()->DebugString();
|
|
|
|
|
if (EndBpropGraph(cell_id)) {
|
|
|
|
|
MS_LOG(DEBUG) << "Get bprop function cell";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
auto resource = GetResource(cell_id);
|
|
|
|
@ -1875,13 +1882,9 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string
|
|
|
|
|
SetTupleArgsToGraphInfoMap(graph_prev, out, out_cnode);
|
|
|
|
|
SetNodeMapInGraphInfoMap(graph_prev, GetId(out), out_cnode);
|
|
|
|
|
} else {
|
|
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
|
|
|
|
DumpIR("before_resolve.ir", newfg);
|
|
|
|
|
}
|
|
|
|
|
DumpGraphIR("before_resolve.ir", newfg);
|
|
|
|
|
parse::ResolveFuncGraph(newfg, resource);
|
|
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
|
|
|
|
DumpIR("after_resolve.ir", newfg);
|
|
|
|
|
}
|
|
|
|
|
DumpGraphIR("after_resolve.ir", newfg);
|
|
|
|
|
resource->set_func_graph(newfg);
|
|
|
|
|
PopGraphStack();
|
|
|
|
|
}
|
|
|
|
@ -1907,10 +1910,12 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt
|
|
|
|
|
if (it != cell_graph_list_.end()) {
|
|
|
|
|
it->is_grad = is_grad;
|
|
|
|
|
it->fg = g;
|
|
|
|
|
MS_LOG(DEBUG) << "Update bprop bg";
|
|
|
|
|
MS_LOG(DEBUG) << "Update bprop bg cell id " << cell_id;
|
|
|
|
|
} else {
|
|
|
|
|
py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME);
|
|
|
|
|
auto cell_info = CellInfo(false, true, g, cell_id, GetId(bprop_func));
|
|
|
|
|
auto bprop_func_cell_id = GetId(bprop_func);
|
|
|
|
|
MS_LOG(DEBUG) << "Add new bprop cell_id " << cell_id << " bprop func cell id " << bprop_func_cell_id;
|
|
|
|
|
auto cell_info = CellInfo(false, true, g, cell_id, bprop_func_cell_id);
|
|
|
|
|
cell_graph_list_.insert(cell_graph_list_.begin(), cell_info);
|
|
|
|
|
}
|
|
|
|
|
return;
|
|
|
|
@ -1959,13 +1964,11 @@ FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const FuncG
|
|
|
|
|
(void)bprop_graph->transforms().emplace(std::make_pair("primal", FuncGraphTransform(g)));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Obtain grad graph
|
|
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
|
|
|
|
DumpIR("fg.ir", g);
|
|
|
|
|
}
|
|
|
|
|
DumpGraphIR("fg.ir", g);
|
|
|
|
|
auto is_top = IsTopGraph(cell_id);
|
|
|
|
|
MS_LOG(DEBUG) << "Grad top cell " << is_top;
|
|
|
|
|
set_need_replace_forward(IsNotNestedGrad());
|
|
|
|
|
// Obtain grad graph
|
|
|
|
|
auto newfg = ad::Grad(g, r, is_top);
|
|
|
|
|
|
|
|
|
|
if (is_custom_bprop) {
|
|
|
|
@ -2039,11 +2042,9 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
|
|
|
|
|
auto args_spec = GetArgsSpec(args, df_builder);
|
|
|
|
|
resource->set_args_spec(args_spec);
|
|
|
|
|
// Get real grad graph
|
|
|
|
|
DumpGraphIR("before_grad.ir", resource->func_graph());
|
|
|
|
|
GradGraph(resource->func_graph(), grad, w_args, size, cell_id);
|
|
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
|
|
|
|
DumpIR("before_grad.ir", resource->func_graph());
|
|
|
|
|
DumpIR("after_grad.ir", df_builder);
|
|
|
|
|
}
|
|
|
|
|
DumpGraphIR("after_grad.ir", df_builder);
|
|
|
|
|
resource->set_func_graph(df_builder);
|
|
|
|
|
resource->manager()->KeepRoots({df_builder});
|
|
|
|
|
resource->results()[pipeline::kBackend] = compile::CreateBackend();
|
|
|
|
@ -2127,30 +2128,35 @@ void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args
|
|
|
|
|
}
|
|
|
|
|
MS_EXCEPTION_IF_NULL(forward_graph);
|
|
|
|
|
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
|
|
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
|
|
|
|
DumpIR("nested_bprop.ir", forward_graph);
|
|
|
|
|
}
|
|
|
|
|
DumpGraphIR("nested_bprop.ir", forward_graph);
|
|
|
|
|
// Custom bprop get backward graph(before opt), which use like other forward graph
|
|
|
|
|
curr_g_ = forward_graph;
|
|
|
|
|
resource->set_func_graph(forward_graph);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Copy weights
|
|
|
|
|
std::vector<AnfNodePtr> weights_params{};
|
|
|
|
|
// Copy weights parameters
|
|
|
|
|
resource->manager()->AddFuncGraph(forward_graph);
|
|
|
|
|
auto manager = Manage({forward_graph}, false);
|
|
|
|
|
for (const auto &it : graph_info_map_.at(forward_graph).params) {
|
|
|
|
|
if (it.second->has_default()) {
|
|
|
|
|
weights_params.emplace_back(it.second);
|
|
|
|
|
graph_info_map_.at(df_builder).params.emplace(it.first, it.second);
|
|
|
|
|
SetNodeMapInGraphInfoMap(df_builder, it.first, it.second);
|
|
|
|
|
if (!it.second->has_default()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "Get weights params size " << weights_params.size();
|
|
|
|
|
df_builder->set_parameters(weights_params);
|
|
|
|
|
resource->manager()->AddFuncGraph(forward_graph);
|
|
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
|
|
|
|
DumpIR("nested_fg.ir", forward_graph);
|
|
|
|
|
}
|
|
|
|
|
auto new_param = df_builder->add_parameter();
|
|
|
|
|
new_param->set_abstract(it.second->abstract());
|
|
|
|
|
new_param->set_name(it.second->name());
|
|
|
|
|
new_param->set_default_param(it.second->default_param());
|
|
|
|
|
ScopePtr scope = (it.second->scope() != kDefaultScope) ? it.second->scope() : kDefaultScope;
|
|
|
|
|
new_param->set_scope(scope);
|
|
|
|
|
manager->Replace(it.second, new_param);
|
|
|
|
|
replace_weights_map_[forward_graph].emplace_back(std::make_pair(it.second, new_param));
|
|
|
|
|
MS_LOG(DEBUG) << "Old param ptr " << it.second.get() << " name " << it.second->name();
|
|
|
|
|
|
|
|
|
|
graph_info_map_.at(df_builder).params[it.first] = new_param;
|
|
|
|
|
SetParamNodeMapInGraphInfoMap(df_builder, it.first, new_param);
|
|
|
|
|
SetNodeMapInGraphInfoMap(df_builder, it.first, new_param);
|
|
|
|
|
}
|
|
|
|
|
DumpGraphIR("nested_fg.ir", forward_graph);
|
|
|
|
|
set_need_replace_forward(false);
|
|
|
|
|
auto newfg = MakeGradGraph(cell, forward_graph, resource, cell_id, args);
|
|
|
|
|
resource->set_func_graph(newfg);
|
|
|
|
@ -2396,15 +2402,18 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg
|
|
|
|
|
MS_LOG(DEBUG) << "Get pre graph ptr " << graph_prev.get();
|
|
|
|
|
auto newfg = resource->func_graph();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(newfg);
|
|
|
|
|
auto size = args.size();
|
|
|
|
|
auto inputs_size = args.size();
|
|
|
|
|
if (has_sens) {
|
|
|
|
|
size -= 1;
|
|
|
|
|
inputs_size -= 1;
|
|
|
|
|
}
|
|
|
|
|
std::vector<AnfNodePtr> inputs;
|
|
|
|
|
inputs.emplace_back(NewValueNode(newfg));
|
|
|
|
|
for (size_t i = 0; i < size; ++i) {
|
|
|
|
|
for (size_t i = 0; i < inputs_size; ++i) {
|
|
|
|
|
inputs.emplace_back(GetInput(args[i], false));
|
|
|
|
|
}
|
|
|
|
|
if (newfg->parameters().size() > inputs_size) {
|
|
|
|
|
SetNestedWeigthsParam(newfg, cell_id, &inputs);
|
|
|
|
|
}
|
|
|
|
|
auto out_id = GetId(out);
|
|
|
|
|
auto cnode = graph_prev->NewCNode(inputs);
|
|
|
|
|
SetTupleArgsToGraphInfoMap(graph_prev, out, cnode);
|
|
|
|
@ -2412,6 +2421,38 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg
|
|
|
|
|
MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString(4);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::SetNestedWeigthsParam(const FuncGraphPtr &newfg, const std::string &cell_id,
|
|
|
|
|
std::vector<AnfNodePtr> *inputs) {
|
|
|
|
|
FuncGraphPtr forward_graph = nullptr;
|
|
|
|
|
auto ic = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),
|
|
|
|
|
[&cell_id](const CellInfo &value) { return value.cell_id == cell_id; });
|
|
|
|
|
if (ic != cell_graph_list_.end()) {
|
|
|
|
|
forward_graph = ic->fg;
|
|
|
|
|
}
|
|
|
|
|
MS_EXCEPTION_IF_NULL(forward_graph);
|
|
|
|
|
auto params = newfg->parameters();
|
|
|
|
|
auto manage = Manage({newfg}, false);
|
|
|
|
|
for (const auto &it : params) {
|
|
|
|
|
auto param = it->cast<ParameterPtr>();
|
|
|
|
|
if (!param->has_default()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto ir = replace_weights_map_.find(forward_graph);
|
|
|
|
|
if (ir == replace_weights_map_.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Not find forward_graph in repalce weigths map";
|
|
|
|
|
}
|
|
|
|
|
for (const auto &ip : ir->second) {
|
|
|
|
|
MS_LOG(DEBUG) << "Get param name " << param->name() << " cache name " << ip.second->name();
|
|
|
|
|
if (ip.second->name() == param->name()) {
|
|
|
|
|
manage->Replace(param, ip.first);
|
|
|
|
|
inputs->emplace_back(ip.first);
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
replace_weights_map_.erase(forward_graph);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::Clear(const std::string &cell_id) {
|
|
|
|
|
if (cell_id.empty()) {
|
|
|
|
|
Clean();
|
|
|
|
@ -2461,6 +2502,7 @@ void PynativeExecutor::ClearRes() {
|
|
|
|
|
|
|
|
|
|
graph_info_map_.clear();
|
|
|
|
|
cell_sw_map_.clear();
|
|
|
|
|
replace_weights_map_.clear();
|
|
|
|
|
cell_graph_list_.clear();
|
|
|
|
|
top_cell_list_.clear();
|
|
|
|
|
op_index_map_.clear();
|
|
|
|
|