|
|
|
@ -621,13 +621,14 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) {
|
|
|
|
|
auto op_name = py::cast<std::string>(args[PY_NAME]);
|
|
|
|
|
op_exec_info->op_name = op_name;
|
|
|
|
|
if (grad_flag()) {
|
|
|
|
|
auto resource = GetResource();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(resource);
|
|
|
|
|
MS_LOG(DEBUG) << "Get resource ptr " << resource.get();
|
|
|
|
|
int64_t graph_id = graph_id_;
|
|
|
|
|
auto it = resource->results().find(pipeline::kPynativeGraphId);
|
|
|
|
|
if (it != resource->results().end()) {
|
|
|
|
|
graph_id = it->second.cast<int64_t>();
|
|
|
|
|
auto resource = GetResource();
|
|
|
|
|
if (resource != nullptr) {
|
|
|
|
|
MS_LOG(DEBUG) << "Get resource ptr " << resource.get();
|
|
|
|
|
auto it = resource->results().find(pipeline::kPynativeGraphId);
|
|
|
|
|
if (it != resource->results().end()) {
|
|
|
|
|
graph_id = it->second.cast<int64_t>();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
op_exec_info->op_index = std::to_string(graph_id) + op_name + std::to_string(op_index_map_[op_name]);
|
|
|
|
|
op_index_map_[op_name]++;
|
|
|
|
@ -686,7 +687,7 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
|
|
|
|
|
|
|
|
|
|
if (need_construct_graph()) {
|
|
|
|
|
AnfNodePtr input_node = nullptr;
|
|
|
|
|
if (!graph_info_map_.empty()) {
|
|
|
|
|
if (!graph_info_map_.empty() && !top_cell_list_.empty()) {
|
|
|
|
|
input_node = GetInput(obj, op_mask);
|
|
|
|
|
}
|
|
|
|
|
// update abstract
|
|
|
|
@ -1450,7 +1451,7 @@ bool PynativeExecutor::IsTopGraph(const std::string &cell_id) {
|
|
|
|
|
|
|
|
|
|
bool PynativeExecutor::IsBpropGraph(const std::string &cell_id) {
|
|
|
|
|
return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfo &value) {
|
|
|
|
|
return value.is_custom_bprop && cell_id.find(value.cell_id) != std::string::npos;
|
|
|
|
|
return !value.bprop_cell_id.empty() && cell_id.find(value.bprop_cell_id) != std::string::npos;
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1466,6 +1467,16 @@ bool PynativeExecutor::CheckCellGraph(const std::string &cell_id, bool is_grad)
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::ClearResidualRes() {
|
|
|
|
|
if (top_cell_list_.empty() && !graph_stack_.empty()) {
|
|
|
|
|
graph_id_ = 0;
|
|
|
|
|
graph_info_map_.clear();
|
|
|
|
|
cell_sw_map_.clear();
|
|
|
|
|
cell_graph_list_.clear();
|
|
|
|
|
top_cell_list_.clear();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
FuncGraphPtr PynativeExecutor::GetDfbuilder(const std::string &cell_id) {
|
|
|
|
|
// Cell is empty, get nearest dfbuilder
|
|
|
|
|
if (cell_id.empty() && !top_cell_list_.empty()) {
|
|
|
|
@ -1490,6 +1501,10 @@ FuncGraphPtr PynativeExecutor::GetDfbuilder(const std::string &cell_id) {
|
|
|
|
|
return it.df_builder;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Current cell is not top graph, get first top cell
|
|
|
|
|
if (!top_cell_list_.empty()) {
|
|
|
|
|
return top_cell_list_.front().df_builder;
|
|
|
|
|
}
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1703,7 +1718,8 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
|
|
|
|
|
// init resource for constructing forward graph and grad graph
|
|
|
|
|
auto g = std::make_shared<FuncGraph>();
|
|
|
|
|
curr_g_ = g;
|
|
|
|
|
if (graph_stack_.empty()) {
|
|
|
|
|
ClearResidualRes();
|
|
|
|
|
if (graph_stack_.empty() && !IsBpropGraph(cell_id)) {
|
|
|
|
|
MakeNewTopGraph(cell_id, args, g);
|
|
|
|
|
}
|
|
|
|
|
PushCurrentGraphToStack();
|
|
|
|
@ -1724,16 +1740,6 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
|
|
|
|
|
dynamic_cell_ = IsDynamicCell(cell);
|
|
|
|
|
MS_LOG(DEBUG) << "cell id: " << cell_id << ", is dynamic cell: " << dynamic_cell_;
|
|
|
|
|
}
|
|
|
|
|
// Make bprop graph
|
|
|
|
|
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
|
|
|
|
|
auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfo &value) {
|
|
|
|
|
return value.is_custom_bprop && cell_id.find(value.cell_id) != std::string::npos;
|
|
|
|
|
});
|
|
|
|
|
if (it != cell_graph_list_.end()) {
|
|
|
|
|
MS_LOG(INFO) << "Make bprop graph";
|
|
|
|
|
it->custom_bprop_graph = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g) {
|
|
|
|
@ -1807,7 +1813,7 @@ void PynativeExecutor::SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, con
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) {
|
|
|
|
|
auto cell_id = GetCellId(cell, args);
|
|
|
|
|
const auto &cell_id = GetCellId(cell, args);
|
|
|
|
|
MS_LOG(DEBUG) << "EndGraphInner start " << args.size() << " " << cell_id;
|
|
|
|
|
if (!dynamic_cell_ && graph_stack_.empty() && CheckCellGraph(cell_id)) {
|
|
|
|
|
MS_LOG(INFO) << "Endgraph already compiled";
|
|
|
|
@ -1841,35 +1847,30 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string
|
|
|
|
|
AnfNodePtr output_node = GetObjNode(out, out_id);
|
|
|
|
|
curr_g_->set_output(output_node);
|
|
|
|
|
MS_LOG(DEBUG) << "Current graph " << curr_g_->output()->DebugString();
|
|
|
|
|
if (EndBpropGraph(cell_id)) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
auto resource = GetResource(cell_id);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(resource);
|
|
|
|
|
auto is_bprop_graph = IsBpropGraph(cell_id);
|
|
|
|
|
auto is_bprop_cell = py::hasattr(cell, parse::CUSTOM_BPROP_NAME);
|
|
|
|
|
if (!is_bprop_cell || !is_bprop_graph) {
|
|
|
|
|
resource->manager()->AddFuncGraph(curr_g_);
|
|
|
|
|
}
|
|
|
|
|
if (!is_bprop_cell) {
|
|
|
|
|
UpdateCellGraph(cell, curr_g_, cell_id, true, false);
|
|
|
|
|
}
|
|
|
|
|
resource->manager()->AddFuncGraph(curr_g_);
|
|
|
|
|
UpdateCellGraph(cell, curr_g_, cell_id, true, false);
|
|
|
|
|
auto newfg = MakeGradGraph(cell, curr_g_, resource, cell_id, args);
|
|
|
|
|
|
|
|
|
|
if (graph_stack_.size() > 1) {
|
|
|
|
|
if (!is_bprop_cell || !is_bprop_graph) {
|
|
|
|
|
std::vector<AnfNodePtr> inputs;
|
|
|
|
|
inputs.emplace_back(NewValueNode(curr_g_));
|
|
|
|
|
std::vector<AnfNodePtr> inputs;
|
|
|
|
|
inputs.emplace_back(NewValueNode(curr_g_));
|
|
|
|
|
|
|
|
|
|
PopGraphStack();
|
|
|
|
|
// connect the previous graph to the inside graph
|
|
|
|
|
auto graph_prev = graph_stack_.top();
|
|
|
|
|
for (size_t i = 0; i < args.size(); i++) {
|
|
|
|
|
auto input = GetInput(args[i], false);
|
|
|
|
|
inputs.emplace_back(input);
|
|
|
|
|
}
|
|
|
|
|
auto out_cnode = graph_prev->NewCNode(inputs);
|
|
|
|
|
SetPyObjInGraphInfoMap(graph_prev, GetCellId(cell, args));
|
|
|
|
|
SetTupleArgsToGraphInfoMap(graph_prev, out, out_cnode);
|
|
|
|
|
SetNodeMapInGraphInfoMap(graph_prev, GetId(out), out_cnode);
|
|
|
|
|
PopGraphStack();
|
|
|
|
|
// connect the previous graph to the inside graph
|
|
|
|
|
auto graph_prev = graph_stack_.top();
|
|
|
|
|
for (size_t i = 0; i < args.size(); i++) {
|
|
|
|
|
auto input = GetInput(args[i], false);
|
|
|
|
|
inputs.emplace_back(input);
|
|
|
|
|
}
|
|
|
|
|
auto out_cnode = graph_prev->NewCNode(inputs);
|
|
|
|
|
SetPyObjInGraphInfoMap(graph_prev, GetCellId(cell, args));
|
|
|
|
|
SetTupleArgsToGraphInfoMap(graph_prev, out, out_cnode);
|
|
|
|
|
SetNodeMapInGraphInfoMap(graph_prev, GetId(out), out_cnode);
|
|
|
|
|
} else {
|
|
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
|
|
|
|
DumpIR("before_resolve.ir", newfg);
|
|
|
|
@ -1883,6 +1884,17 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool PynativeExecutor::EndBpropGraph(const string &cell_id) {
|
|
|
|
|
auto is_bprop_graph = IsBpropGraph(cell_id);
|
|
|
|
|
if (is_bprop_graph) {
|
|
|
|
|
if (IsNotNestedGrad()) {
|
|
|
|
|
PopGraphStack();
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
|
|
|
|
|
bool need_cloned, bool is_grad) {
|
|
|
|
|
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
|
|
|
|
@ -1894,7 +1906,8 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt
|
|
|
|
|
it->fg = g;
|
|
|
|
|
MS_LOG(DEBUG) << "Update bprop bg";
|
|
|
|
|
} else {
|
|
|
|
|
auto cell_info = CellInfo(false, true, false, g, cell_id);
|
|
|
|
|
py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME);
|
|
|
|
|
auto cell_info = CellInfo(false, true, g, cell_id, GetId(bprop_func));
|
|
|
|
|
cell_graph_list_.insert(cell_graph_list_.begin(), cell_info);
|
|
|
|
|
}
|
|
|
|
|
return;
|
|
|
|
@ -1923,13 +1936,13 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "Add new cell graph " << cell_id;
|
|
|
|
|
auto cell_info = CellInfo(false, true, false, tmp, cell_id);
|
|
|
|
|
auto cell_info = CellInfo(false, true, tmp, cell_id, "");
|
|
|
|
|
cell_graph_list_.insert(cell_graph_list_.begin(), cell_info);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const FuncGraphPtr &g, const ResourcePtr &r,
|
|
|
|
|
const string &cell_id, const py::args &args) {
|
|
|
|
|
bool is_custom_bprop = py::hasattr(cell, parse::CUSTOM_BPROP_NAME) && !IsBpropGraph(cell_id);
|
|
|
|
|
const std::string &cell_id, const py::args &args) {
|
|
|
|
|
bool is_custom_bprop = py::hasattr(cell, parse::CUSTOM_BPROP_NAME);
|
|
|
|
|
if (is_custom_bprop) {
|
|
|
|
|
size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size();
|
|
|
|
|
if (par_number > 0) {
|
|
|
|
@ -1943,17 +1956,15 @@ FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const FuncG
|
|
|
|
|
(void)bprop_graph->transforms().emplace(std::make_pair("primal", FuncGraphTransform(g)));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
FuncGraphPtr newfg = nullptr;
|
|
|
|
|
if (!py::hasattr(cell, parse::CUSTOM_BPROP_NAME) || is_custom_bprop) {
|
|
|
|
|
// Obtain grad graph
|
|
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
|
|
|
|
DumpIR("fg.ir", g);
|
|
|
|
|
}
|
|
|
|
|
auto is_top = IsTopGraph(cell_id);
|
|
|
|
|
MS_LOG(DEBUG) << "Grad top cell " << is_top;
|
|
|
|
|
set_need_replace_forward(IsNotNestedGrad());
|
|
|
|
|
newfg = ad::Grad(g, r, is_top);
|
|
|
|
|
// Obtain grad graph
|
|
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
|
|
|
|
DumpIR("fg.ir", g);
|
|
|
|
|
}
|
|
|
|
|
auto is_top = IsTopGraph(cell_id);
|
|
|
|
|
MS_LOG(DEBUG) << "Grad top cell " << is_top;
|
|
|
|
|
set_need_replace_forward(IsNotNestedGrad());
|
|
|
|
|
auto newfg = ad::Grad(g, r, is_top);
|
|
|
|
|
|
|
|
|
|
if (is_custom_bprop) {
|
|
|
|
|
auto params = newfg->parameters();
|
|
|
|
|
auto manager = Manage({newfg}, false);
|
|
|
|
@ -2135,7 +2146,7 @@ void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args
|
|
|
|
|
SetNodeMapInGraphInfoMap(df_builder, it.first, it.second);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "Get wights params size " << weights_params.size();
|
|
|
|
|
MS_LOG(DEBUG) << "Get weights params size " << weights_params.size();
|
|
|
|
|
df_builder->set_parameters(weights_params);
|
|
|
|
|
resource->manager()->AddFuncGraph(forward_graph);
|
|
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
|
|
|
@ -2314,7 +2325,7 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &
|
|
|
|
|
MS_LOG(DEBUG) << "Grad not running yet";
|
|
|
|
|
return BaseRefToPyData(ret);
|
|
|
|
|
}
|
|
|
|
|
auto cell_id = GetCellId(cell, args);
|
|
|
|
|
const auto &cell_id = GetCellId(cell, args);
|
|
|
|
|
string key = cell_id.substr(0, std::min(PTR_LEN, cell_id.size()));
|
|
|
|
|
MS_LOG(DEBUG) << "Key is " << key;
|
|
|
|
|
for (auto it = cell_graph_list_.begin(); it != cell_graph_list_.end(); ++it) {
|
|
|
|
@ -2379,12 +2390,6 @@ bool PynativeExecutor::MakeBpropNestedCnode(const py::object &cell, const py::ob
|
|
|
|
|
MS_LOG(DEBUG) << "No nested bprop grad find";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfo &value) {
|
|
|
|
|
return value.custom_bprop_graph && value.is_custom_bprop && cell_id.find(value.cell_id) != std::string::npos;
|
|
|
|
|
});
|
|
|
|
|
if (it != cell_graph_list_.end()) {
|
|
|
|
|
MS_LOG(DEBUG) << "Make bprop graph end";
|
|
|
|
|
}
|
|
|
|
|
auto out_id = GetId(out);
|
|
|
|
|
std::vector<AnfNodePtr> inputs;
|
|
|
|
|
inputs.emplace_back(NewValueNode(curr_g_));
|
|
|
|
@ -2489,6 +2494,7 @@ void PynativeExecutor::Clean() {
|
|
|
|
|
void PynativeExecutor::ClearRes() {
|
|
|
|
|
MS_LOG(DEBUG) << "Clear all res";
|
|
|
|
|
Clean();
|
|
|
|
|
graph_id_ = 0;
|
|
|
|
|
grad_order_ = 0;
|
|
|
|
|
grad_flag_ = false;
|
|
|
|
|
dynamic_cell_ = false;
|
|
|
|
|