From ae3a9d28830a12d1d61ac566db72284fb51b42db Mon Sep 17 00:00:00 2001
From: zjun
Date: Fri, 22 Jan 2021 12:20:58 +0800
Subject: [PATCH] Refactor pynative

Signed-off-by: zjun

Modify real dynamic

Signed-off-by: zjun
---
 .../pipeline/pynative/pynative_execute.cc | 1331 +++++++++--------
 .../pipeline/pynative/pynative_execute.h | 513 ++++---
 mindspore/common/api.py | 7 +-
 mindspore/nn/cell.py | 2 +-
 mindspore/ops/composite/base.py | 2 +-
 .../ut/cpp/pynative/pynative_execute_test.cc | 2 +-
 6 files changed, 1030 insertions(+), 827 deletions(-)

diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
index c156c9d67e..981f300f51 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
@@ -74,14 +74,17 @@ static const char kMSDtypeModelName[] = "mindspore.common.dtype";
 namespace mindspore::pynative {
 static std::shared_ptr<session::SessionBasic> session = nullptr;
 PynativeExecutorPtr PynativeExecutor::executor_ = nullptr;
+ForwardExecutorPtr PynativeExecutor::forward_executor_ = nullptr;
+GradExecutorPtr PynativeExecutor::grad_executor_ = nullptr;
 std::mutex PynativeExecutor::instance_lock_;
-int64_t PynativeExecutor::graph_id_ = 0;
-template <typename... Args>
-void PynativeExecutorTry(PynativeExecutor *const executor, void (PynativeExecutor::*method)(Args...), Args &&... args) {
-  MS_EXCEPTION_IF_NULL(executor);
+template <typename T, typename... Args>
+void PynativeExecutorTry(std::function<void(T *, const Args &...)> method, T *ret, const Args &... args) {
+  const auto inst = PynativeExecutor::GetInstance();
+  MS_EXCEPTION_IF_NULL(inst);
+  MS_EXCEPTION_IF_NULL(method);
   try {
-    (executor->*method)(args...);
+    method(ret, args...);
   } catch (const py::error_already_set &ex) {
     // print function call stack info before release
     std::ostringstream oss;
@@ -91,24 +94,24 @@ void PynativeExecutorTry(PynativeExecuto
     // these info from screen, no need to open log file to find these info
     py::print(oss.str());
     MS_LOG(ERROR) << oss.str();
-    PynativeExecutor::GetInstance()->ClearRes();
+    inst->ClearRes();
     // re-throw this exception to Python interpreter to handle it
     throw(py::error_already_set(ex));
   } catch (const py::type_error &ex) {
-    PynativeExecutor::GetInstance()->ClearRes();
+    inst->ClearRes();
     throw py::type_error(ex);
   } catch (const py::value_error &ex) {
-    PynativeExecutor::GetInstance()->ClearRes();
+    inst->ClearRes();
     throw py::value_error(ex);
   } catch (const py::index_error &ex) {
-    PynativeExecutor::GetInstance()->ClearRes();
+    inst->ClearRes();
     throw py::index_error(ex);
   } catch (const std::exception &ex) {
-    PynativeExecutor::GetInstance()->ClearRes();
+    inst->ClearRes();
     // re-throw this exception to Python interpreter to handle it
     throw(std::runtime_error(ex.what()));
   } catch (...) {
-    PynativeExecutor::GetInstance()->ClearRes();
+    inst->ClearRes();
    std::string exName(abi::__cxa_current_exception_type()->name());
    MS_LOG(EXCEPTION) << "Error occurred when compile graph. 
Exception name: " << exName; } @@ -553,11 +556,6 @@ py::tuple ConvertArgs(const py::tuple &args) { void ClearPyNativeSession() { session = nullptr; } -PynativeExecutor::~PynativeExecutor() { - MS_LOG(DEBUG) << "PynativeExecutor destructor"; - ClearRes(); -} - void CheckPyNativeContext() { auto parallel_context = parallel::ParallelContext::GetInstance(); MS_EXCEPTION_IF_NULL(parallel_context); @@ -574,43 +572,31 @@ py::object RunOp(const py::args &args) { CheckPyNativeContext(); auto executor = PynativeExecutor::GetInstance(); MS_EXCEPTION_IF_NULL(executor); - OpExecInfoPtr op_exec_info = executor->GenerateOpExecInfo(args); + OpExecInfoPtr op_exec_info = executor->forward_executor()->GenerateOpExecInfo(args); MS_EXCEPTION_IF_NULL(op_exec_info); MS_LOG(DEBUG) << "RunOp name: " << op_exec_info->op_name << " start, args: " << args.size(); - try { - return executor->RunOpInner(op_exec_info); - } catch (const py::error_already_set &ex) { - executor->ClearRes(); - // re-throw this exception to Python interpreter to handle it - throw(py::error_already_set(ex)); - } catch (const py::type_error &ex) { - executor->ClearRes(); - throw py::type_error(ex); - } catch (const py::value_error &ex) { - executor->ClearRes(); - throw py::value_error(ex); - } catch (const py::index_error &ex) { - executor->ClearRes(); - throw py::index_error(ex); - } catch (const std::exception &ex) { - executor->ClearRes(); - // re-throw this exception to Python interpreter to handle it - throw(std::runtime_error(ex.what())); - } catch (...) { - executor->ClearRes(); - std::string exName(abi::__cxa_current_exception_type()->name()); - MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName; - } + py::object ret = py::none(); + PynativeExecutorTry(executor->forward_executor()->RunOpS, &ret, op_exec_info); + return ret; +} + +GradExecutorPtr ForwardExecutor::grad() const { + auto grad_executor = grad_executor_.lock(); + MS_EXCEPTION_IF_NULL(grad_executor); + return grad_executor; } -py::object PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { +void ForwardExecutor::RunOpInner(py::object *ret, const OpExecInfoPtr &op_exec_info) { + MS_EXCEPTION_IF_NULL(ret); MS_EXCEPTION_IF_NULL(op_exec_info); if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) { - py::tuple ret = RunOpWithInitBackendPolicy(op_exec_info); - if (ret.size() == 1) { - return ret[0]; + py::tuple res = RunOpWithInitBackendPolicy(op_exec_info); + if (res.size() == 1) { + *ret = res[0]; + return; } - return std::move(ret); + *ret = std::move(res); + return; } // make cnode for building grad graph if grad flag is set. 
abstract::AbstractBasePtrList args_spec_list; @@ -626,10 +612,12 @@ py::object PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { MS_EXCEPTION_IF_NULL(prim); py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract); if (!output["value"].is_none()) { - return output["value"]; + *ret = output["value"]; + return; } if (prim->is_const_prim()) { - return py::cast(""); + *ret = py::cast(""); + return; } // add output abstract info into cache if (!is_find && !op_exec_info->is_dynamic_shape) { @@ -654,15 +642,17 @@ py::object PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { } std::string obj_id = GetId(out_real); node_abs_map_[obj_id] = op_exec_info->abstract; - // save info for building grad graph - SaveOutputNodeMap(obj_id, out_real, cnode); - SaveAllResult(op_exec_info, cnode, out_real); - // Update the abstract and device address of value node with tensor in grad graph - UpdateAbstractAndDeviceAddress(op_exec_info, out_real); - return out_real; + // Save info for building grad graph + if (grad()->grad_flag() && grad()->in_grad_process()) { + grad()->SaveOutputNodeMap(obj_id, out_real, cnode); + grad()->SaveAllResult(op_exec_info, cnode, out_real); + // Update the abstract and device address of value node with tensor in grad graph + UpdateAbstractAndDeviceAddress(op_exec_info, out_real); + } + *ret = out_real; } -OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) { +OpExecInfoPtr ForwardExecutor::GenerateOpExecInfo(const py::args &args) { if (args.size() != PY_ARGS_NUM) { MS_LOG(ERROR) << "Three args are needed by RunOp"; return nullptr; @@ -670,13 +660,15 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) { auto op_exec_info = std::make_shared(); auto op_name = py::cast(args[PY_NAME]); op_exec_info->op_name = op_name; - if (grad_flag()) { - op_exec_info->op_index = op_name + "_" + std::to_string(op_index_map_[op_name]); - if (!cell_op_info_stack_.empty()) { - std::string &cell_op_info = cell_op_info_stack_.top(); + // Need const grad graph + if (grad()->grad_flag()) { + // Get forward op index + op_exec_info->op_index = op_name + "_" + std::to_string(grad()->op_index_map()[op_name]); + if (!grad()->cell_op_info_stack().empty()) { + std::string &cell_op_info = grad()->cell_op_info_stack().top(); cell_op_info += op_exec_info->op_index; } - op_index_map_[op_name]++; + grad()->op_index_map()[op_name]++; } auto prim = py::cast(args[PY_PRIM]); MS_EXCEPTION_IF_NULL(prim); @@ -689,8 +681,8 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) { return op_exec_info; } -void PynativeExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, - std::vector *inputs, abstract::AbstractBasePtrList *args_spec_list) { +void ForwardExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, + std::vector *inputs, abstract::AbstractBasePtrList *args_spec_list) { auto prim = op_exec_info->py_primitive; for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) { abstract::AbstractBasePtr abs = nullptr; @@ -710,10 +702,11 @@ void PynativeExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vecto MS_LOG(DEBUG) << "Gen args i " << i << " op_mask " << op_mask; (*op_masks).emplace_back(op_mask); - if (need_construct_graph()) { + // Construct grad graph + if (grad()->need_construct_graph()) { AnfNodePtr input_node = nullptr; - if (!graph_info_map_.empty() && !top_cell_list_.empty()) { - input_node = GetInput(obj, op_mask); + if 
(!grad()->top_cell_list().empty()) { + input_node = grad()->GetInput(obj, op_mask); } // update abstract if (input_node != nullptr) { @@ -727,8 +720,8 @@ void PynativeExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vecto } } -AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, - abstract::AbstractBasePtrList *args_spec_list) { +AnfNodePtr ForwardExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, + abstract::AbstractBasePtrList *args_spec_list) { MS_EXCEPTION_IF_NULL(op_masks); MS_EXCEPTION_IF_NULL(args_spec_list); MS_EXCEPTION_IF_NULL(op_exec_info); @@ -754,21 +747,20 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v if (op_exec_info->op_name != prim::kPrimCast->name()) { RunParameterAutoMixPrecisionCast(op_exec_info); } - MS_LOG(DEBUG) << "Get op " << op_exec_info->op_name << " grad_flag_ " << grad_flag(); + MS_LOG(DEBUG) << "Get op " << op_exec_info->op_name << " grad_flag_ " << grad()->grad_flag(); GetArgsSpec(op_exec_info, op_masks, &inputs, args_spec_list); CNodePtr cnode = nullptr; - if (need_construct_graph()) { - MS_EXCEPTION_IF_NULL(curr_g_); - cnode = curr_g_->NewCNodeInOrder(inputs); + if (grad()->need_construct_graph()) { + cnode = grad()->curr_g()->NewCNodeInOrder(inputs); MS_LOG(DEBUG) << "Make CNode for " << op_exec_info->op_name << " new cnode is " << cnode->DebugString(4); } return cnode; } -abstract::AbstractBasePtr PynativeExecutor::CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj, - const abstract::AbstractBasePtr &abs, const std::string &id, - size_t index) { +abstract::AbstractBasePtr ForwardExecutor::CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj, + const abstract::AbstractBasePtr &abs, const std::string &id, + size_t index) { MS_EXCEPTION_IF_NULL(prim); auto const_input_index = prim->get_const_input_indexes(); bool have_const_input = !const_input_index.empty(); @@ -794,8 +786,8 @@ abstract::AbstractBasePtr PynativeExecutor::CheckConstValue(const PrimitivePyPtr return new_abs; } -void PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, - const abstract::AbstractBasePtrList &args_spec_list, bool *is_find) { +void ForwardExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, + const abstract::AbstractBasePtrList &args_spec_list, bool *is_find) { MS_EXCEPTION_IF_NULL(is_find); MS_EXCEPTION_IF_NULL(op_exec_info); *is_find = false; @@ -829,8 +821,8 @@ void PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, } } -py::object PynativeExecutor::DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name, - size_t index) { +py::object ForwardExecutor::DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name, + size_t index) { py::tuple cast_args(3); cast_args[PY_PRIM] = parse::python_adapter::GetPyFn(kOpsFunctionModelName, "cast"); cast_args[PY_NAME] = prim::kPrimCast->name(); @@ -844,11 +836,13 @@ py::object PynativeExecutor::DoAutoCast(const py::object &arg, const TypeId &typ op_exec->is_mixed_precision_cast = true; op_exec->next_op_name = op_name; op_exec->next_input_index = index; - return RunOpInner(op_exec); + py::object ret = py::none(); + RunOpInner(&ret, op_exec); + return ret; } -py::object PynativeExecutor::DoParamMixPrecisionCast(bool *is_cast, const py::object obj, const std::string &op_name, - size_t index) { +py::object ForwardExecutor::DoParamMixPrecisionCast(bool *is_cast, const py::object &obj, const 
std::string &op_name, + size_t index) { MS_EXCEPTION_IF_NULL(is_cast); auto tensor = py::cast(obj); auto cast_type = tensor->cast_dtype(); @@ -864,8 +858,8 @@ py::object PynativeExecutor::DoParamMixPrecisionCast(bool *is_cast, const py::ob return cast_output; } -py::object PynativeExecutor::DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple, - const std::string &op_name, size_t index) { +py::object ForwardExecutor::DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple &tuple, + const std::string &op_name, size_t index) { MS_EXCEPTION_IF_NULL(is_cast); auto tuple_size = static_cast(tuple.size()); py::tuple result(tuple_size); @@ -883,9 +877,9 @@ py::object PynativeExecutor::DoParamMixPrecisionCastTuple(bool *is_cast, const p return std::move(result); } -void PynativeExecutor::DoSignatrueCast(const PrimitivePyPtr &prim, const std::map &dst_type, - const std::vector &dtypes, - const OpExecInfoPtr &op_exec_info) { +void ForwardExecutor::DoSignatrueCast(const PrimitivePyPtr &prim, const std::map &dst_type, + const std::vector &dtypes, + const OpExecInfoPtr &op_exec_info) { const auto &signature = prim->signatures(); auto &out_args = op_exec_info->op_inputs; for (size_t i = 0; i < out_args.size(); ++i) { @@ -944,7 +938,7 @@ void PynativeExecutor::DoSignatrueCast(const PrimitivePyPtr &prim, const std::ma } } -void PynativeExecutor::RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_exec_info) { +void ForwardExecutor::RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_exec_info) { size_t size = op_exec_info->op_inputs.size(); auto prim = op_exec_info->py_primitive; MS_EXCEPTION_IF_NULL(prim); @@ -990,7 +984,7 @@ void PynativeExecutor::RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_ DoSignatrueCast(prim, dst_types, dtypes, op_exec_info); } -AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) { +AnfNodePtr GradExecutor::GetInput(const py::object &obj, bool op_mask) { AnfNodePtr node = nullptr; std::string obj_id = GetId(obj); @@ -1002,9 +996,9 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) { MS_LOG(EXCEPTION) << "Parameter object should have name attribute"; } auto param_name = py::cast(name_attr); - auto df_builder = GetDfbuilder(top_cell_id_); + auto df_builder = GetDfbuilder(top_cell_id()); MS_EXCEPTION_IF_NULL(df_builder); - auto graph_info = graph_info_map_.at(df_builder); + auto graph_info = top_cell()->graph_info_map().at(df_builder); MS_EXCEPTION_IF_NULL(graph_info); if (graph_info->params.find(obj_id) == graph_info->params.end()) { auto free_param = df_builder->add_parameter(); @@ -1025,7 +1019,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) { return node; } - auto graph_info = graph_info_map_.at(curr_g_); + auto graph_info = top_cell()->graph_info_map().at(curr_g_); MS_EXCEPTION_IF_NULL(graph_info); if (graph_info->node_map.find(obj_id) != graph_info->node_map.end()) { // op(x, y) @@ -1058,31 +1052,30 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) { return node; } -void PynativeExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real) { +void ForwardExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real) { MS_EXCEPTION_IF_NULL(op_exec_info); - if (!grad_flag()) { - return; - } auto op_index = op_exec_info->op_index; auto output_value = PyAttrValue(out_real); MS_EXCEPTION_IF_NULL(output_value); std::vector output_tensors; 
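// Hedged sketch of the two caches the branches below rely on (their template
// arguments are implied by usage elsewhere in the patch, not spelled out in this
// hunk; the exact container types are an assumption):
//
//   // top cell id -> (forward op index -> ids of tensors that op produced on step 1)
//   std::unordered_map<std::string,
//                      std::unordered_map<std::string, std::vector<std::string>>>
//       cell_op_index_with_tensor_id_;
//   // top cell id -> (tensor id -> tensors cached in grad-graph value nodes whose
//   // device addresses get refreshed on later steps)
//   std::unordered_map<std::string,
//                      std::unordered_map<std::string, std::vector<tensor::TensorPtr>>>
//       cell_tensor_id_with_tensor_;
//
// On the first step the code below only records tensor ids per op index; on later
// steps it looks those ids up and points the cached value-node tensors at the newly
// produced device addresses, so the bprop graph replays with fresh forward outputs.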
TensorValueToTensor(output_value, &output_tensors); - if (cell_op_index_with_tensor_id_[top_cell_id_].find(op_index) == cell_op_index_with_tensor_id_[top_cell_id_].end()) { + if (cell_op_index_with_tensor_id()[grad()->top_cell_id()].find(op_index) == + cell_op_index_with_tensor_id()[grad()->top_cell_id()].end()) { // first step std::for_each(output_tensors.begin(), output_tensors.end(), [&](const tensor::TensorPtr &tensor) { - cell_op_index_with_tensor_id_[top_cell_id_][op_index].emplace_back(tensor->id()); + cell_op_index_with_tensor_id()[grad()->top_cell_id()][op_index].emplace_back(tensor->id()); }); return; } auto ms_context = MsContext::GetInstance(); auto target = ms_context->get_param(MS_CTX_DEVICE_TARGET); - const auto &tensor_id_list = cell_op_index_with_tensor_id_[top_cell_id_][op_index]; + const auto &tensor_id_list = cell_op_index_with_tensor_id()[grad()->top_cell_id()][op_index]; for (size_t i = 0; i < tensor_id_list.size(); ++i) { auto tensor_id = tensor_id_list[i]; - if (cell_tensor_id_with_tensor_[top_cell_id_].find(tensor_id) != cell_tensor_id_with_tensor_[top_cell_id_].end()) { + if (cell_tensor_id_with_tensor()[grad()->top_cell_id()].find(tensor_id) != + cell_tensor_id_with_tensor()[grad()->top_cell_id()].end()) { auto &new_tensor = output_tensors[i]; - auto &tensors_in_value_node = cell_tensor_id_with_tensor_[top_cell_id_][tensor_id]; + auto &tensors_in_value_node = cell_tensor_id_with_tensor()[grad()->top_cell_id()][tensor_id]; std::for_each(tensors_in_value_node.begin(), tensors_in_value_node.end(), [&](tensor::TensorPtr &tensor) { MS_LOG(DEBUG) << "Debug address: Replace forward old tensor obj " << tensor.get() << ", tensor id " << tensor->id() << ", device address " << tensor->device_address().get() @@ -1109,17 +1102,20 @@ void PynativeExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_ex } } -void PynativeExecutor::SaveTensorsInValueNode(const ResourcePtr &resource) { +void GradExecutor::SaveTensorsInValueNode(const ResourcePtr &resource) { MS_EXCEPTION_IF_NULL(resource); std::set forward_op_tensor_id; - for (const auto &elem : cell_op_index_with_tensor_id_[top_cell_id_]) { - const auto &tensor_id_list = elem.second; - for (const auto &tensor_id : tensor_id_list) { - forward_op_tensor_id.emplace(tensor_id); + auto it = forward()->cell_op_index_with_tensor_id().find(top_cell_id()); + if (it != forward()->cell_op_index_with_tensor_id().end()) { + for (const auto &elem : it->second) { + const auto &tensor_id_list = elem.second; + for (const auto &tensor_id : tensor_id_list) { + forward_op_tensor_id.emplace(tensor_id); + } } } - cell_tensor_id_with_tensor_[top_cell_id_].clear(); + forward()->cell_tensor_id_with_tensor()[top_cell_id()].clear(); const auto &func_graph = resource->func_graph(); const auto &value_node_list = func_graph->value_nodes(); for (const auto &elem : value_node_list) { @@ -1130,7 +1126,7 @@ void PynativeExecutor::SaveTensorsInValueNode(const ResourcePtr &resource) { for (const auto &tensor : tensors) { if (tensor->device_address() != nullptr && forward_op_tensor_id.find(tensor->id()) != forward_op_tensor_id.end()) { - cell_tensor_id_with_tensor_[top_cell_id_][tensor->id()].emplace_back(tensor); + forward()->cell_tensor_id_with_tensor()[top_cell_id()][tensor->id()].emplace_back(tensor); MS_LOG(DEBUG) << "Debug address: Save forward tensor obj " << tensor.get() << ", tensor id " << tensor->id() << ", device address " << tensor->device_address().get(); } @@ -1138,15 +1134,15 @@ void PynativeExecutor::SaveTensorsInValueNode(const 
ResourcePtr &resource) { } } -void PynativeExecutor::CleanPreMemoryInValueNode() { +void GradExecutor::CleanPreMemoryInValueNode() { auto ms_context = MsContext::GetInstance(); std::string device_target = ms_context->get_param(MS_CTX_DEVICE_TARGET); - if (device_target == "CPU") { + if (device_target == "CPU" || pre_top_cell_ == nullptr) { return; } - if (has_dynamic_cell_) { + if (pre_top_cell_->has_dynamic_cell()) { std::set forward_op_tensor_id; - for (const auto &elem : cell_op_index_with_tensor_id_[top_cell_id_]) { + for (const auto &elem : forward()->cell_op_index_with_tensor_id().at(pre_top_cell_->cell_id())) { const auto &tensor_id_list = elem.second; for (const auto &tensor_id : tensor_id_list) { forward_op_tensor_id.emplace(tensor_id); @@ -1161,7 +1157,12 @@ void PynativeExecutor::CleanPreMemoryInValueNode() { } all_value_node_tensors_.clear(); } - const auto &tensor_id_with_tensor = cell_tensor_id_with_tensor_[top_cell_id_]; + auto it = forward()->cell_tensor_id_with_tensor().find(pre_top_cell_->cell_id()); + if (it == forward()->cell_tensor_id_with_tensor().end()) { + pre_top_cell_ = nullptr; + return; + } + const auto &tensor_id_with_tensor = it->second; for (const auto &elem : tensor_id_with_tensor) { const auto &tensors_in_value_node = elem.second; for (const auto &tensor : tensors_in_value_node) { @@ -1169,10 +1170,11 @@ void PynativeExecutor::CleanPreMemoryInValueNode() { tensor->set_device_address(nullptr); } } + pre_top_cell_ = nullptr; } -AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj, const std::string &obj_id) { - auto graph_info = graph_info_map_.at(curr_g_); +AnfNodePtr GradExecutor::GetObjNode(const py::object &obj, const std::string &obj_id) { + auto graph_info = top_cell()->graph_info_map().at(curr_g_); MS_EXCEPTION_IF_NULL(graph_info); auto &out = graph_info->node_map.at(obj_id); if (out.second.size() == 1 && out.second[0] == -1) { @@ -1216,13 +1218,13 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj, const std::string } } if (node->abstract() != nullptr) { - node_abs_map_[obj_id] = node->abstract(); + forward()->node_abs_map()[obj_id] = node->abstract(); } MS_LOG(DEBUG) << "GetObjNode output " << node->DebugString(6); return node; } -AnfNodePtr PynativeExecutor::MakeValueNode(const py::object &obj, const std::string &obj_id) { +AnfNodePtr GradExecutor::MakeValueNode(const py::object &obj, const std::string &obj_id) { ValuePtr converted_ret = nullptr; parse::ConvertData(obj, &converted_ret); auto node = NewValueNode(converted_ret); @@ -1230,14 +1232,13 @@ AnfNodePtr PynativeExecutor::MakeValueNode(const py::object &obj, const std::str return node; } -void PynativeExecutor::SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, - const AnfNodePtr &cnode) { - if (!need_construct_graph()) { +void GradExecutor::SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const AnfNodePtr &cnode) { + if (graph_stack_.empty()) { MS_LOG(DEBUG) << "No need save output"; return; } + MS_EXCEPTION_IF_NULL(cnode); MS_LOG(DEBUG) << "Cnode is " << cnode->DebugString(4) << " id " << obj_id; - if (py::isinstance(out_real)) { auto value = py::cast(out_real); auto size = static_cast(value.size()); @@ -1252,9 +1253,9 @@ void PynativeExecutor::SaveOutputNodeMap(const std::string &obj_id, const py::ob SetPyObjInGraphInfoMap(curr_g_, obj_id); } -void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const AnfNodePtr &node, - const py::object &out_real) { - if (!grad_flag() || node == nullptr) { +void 
GradExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const AnfNodePtr &node, + const py::object &out_real) { + if (node == nullptr) { return; } @@ -1266,8 +1267,9 @@ void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const An for (size_t i = 0; i < size; i++) { auto obj = op_exec_info->op_inputs[i]; auto obj_id = GetId(obj); - if (obj_to_forward_id_.find(obj_id) != obj_to_forward_id_.end()) { - cnode->add_input_value(PyAttrValue(obj), obj_to_forward_id_[obj_id]); + auto it = obj_to_forward_id_.find(obj_id); + if (it != obj_to_forward_id_.end()) { + cnode->add_input_value(PyAttrValue(obj), it->second); } else { cnode->add_input_value(nullptr, ""); } @@ -1287,43 +1289,7 @@ void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const An obj_to_forward_id_[out_id] = op_exec_info->op_index; } -void PynativeExecutor::GenTupleMap(const ValueTuplePtr &tuple, std::map *t_map) { - if (t_map == nullptr) { - return; - } - for (size_t i = 0; i < tuple->size(); i++) { - ValuePtr tuple_i = (*tuple)[i]; - if (tuple_i->isa()) { - auto t = tuple_i->cast(); - (*t_map)[t->id()] = t; - } else if (tuple_i->isa()) { - GenTupleMap(tuple_i->cast(), t_map); - } - } - MS_LOG(DEBUG) << "End GenTupleMap " << tuple->ToString(); -} - -ValuePtr PynativeExecutor::CleanTupleAddr(const ValueTuplePtr &tuple) { - std::vector value_list; - for (size_t i = 0; i < tuple->size(); i++) { - ValuePtr tuple_i = (*tuple)[i]; - if (tuple_i->isa()) { - auto t = tuple_i->cast(); - auto new_tensor = std::make_shared(*t); - new_tensor->set_device_address(nullptr); - value_list.emplace_back(new_tensor); - } else if (tuple_i->isa()) { - value_list.emplace_back(CleanTupleAddr(tuple_i->cast())); - } else { - MS_LOG(DEBUG) << "Tuple[i] value " << tuple_i->ToString(); - value_list.emplace_back(tuple_i); - } - } - MS_LOG(DEBUG) << "End CleanTupleAddr"; - return std::make_shared(value_list); -} - -py::tuple PynativeExecutor::RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info) { +py::tuple ForwardExecutor::RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info) { auto backend_policy = InitEnv(op_exec_info); PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE; // returns a null py::tuple on error @@ -1331,12 +1297,11 @@ py::tuple PynativeExecutor::RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_e if (status != PYNATIVE_SUCCESS) { MS_LOG(EXCEPTION) << "Failed to run " << op_exec_info->op_name; } - MS_LOG(DEBUG) << "RunOp end"; return result; } -MsBackendPolicy PynativeExecutor::InitEnv(const OpExecInfoPtr &op_exec_info) { +MsBackendPolicy ForwardExecutor::InitEnv(const OpExecInfoPtr &op_exec_info) { MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name; parse::python_adapter::set_python_env_flag(true); MsBackendPolicy backend_policy; @@ -1365,8 +1330,8 @@ MsBackendPolicy PynativeExecutor::InitEnv(const OpExecInfoPtr &op_exec_info) { return backend_policy; } -py::object PynativeExecutor::RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info, - PynativeStatusCode *const status) { +py::object ForwardExecutor::RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info, + PynativeStatusCode *status) { MS_EXCEPTION_IF_NULL(status); py::object result; switch (backend_policy) { @@ -1402,7 +1367,7 @@ py::object PynativeExecutor::RunOpWithBackendPolicy(MsBackendPolicy backend_poli return result; } -py::object PynativeExecutor::RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { +py::object 
ForwardExecutor::RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { MS_LOG(INFO) << "RunOpInVM start"; MS_EXCEPTION_IF_NULL(status); MS_EXCEPTION_IF_NULL(op_exec_info); @@ -1417,7 +1382,7 @@ py::object PynativeExecutor::RunOpInVM(const OpExecInfoPtr &op_exec_info, Pynati auto input_obj_id = GetId(input); auto tensor = py::cast(input); MS_EXCEPTION_IF_NULL(tensor); - if (obj_to_forward_id_.find(input_obj_id) == obj_to_forward_id_.end() && + if (grad()->obj_to_forward_id().find(input_obj_id) == grad()->obj_to_forward_id().end() && op_exec_info->op_name == "HookBackward") { // the input object is not a output of forward cnode, eg: parameter result[i] = tensor; @@ -1452,7 +1417,7 @@ py::object PynativeExecutor::RunOpInVM(const OpExecInfoPtr &op_exec_info, Pynati return std::move(tuple_result); } -py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { +py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { MS_EXCEPTION_IF_NULL(op_exec_info); MS_EXCEPTION_IF_NULL(status); MS_LOG(INFO) << "Start run op [" << op_exec_info->op_name << "] with backend policy ms"; @@ -1501,14 +1466,43 @@ py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynati return result; } -void PynativeExecutor::PushCurrentGraphToStack() { graph_stack_.push(curr_g_); } +void ForwardExecutor::ClearRes() { + MS_LOG(DEBUG) << "Clear forward res"; + prim_abs_list_.clear(); + node_abs_map_.clear(); + cell_op_index_with_tensor_id_.clear(); + cell_tensor_id_with_tensor_.clear(); +} + +ForwardExecutorPtr GradExecutor::forward() const { + auto forward_executor = forward_executor_.lock(); + MS_EXCEPTION_IF_NULL(forward_executor); + return forward_executor; +} + +DynamicAnalysisPtr GradExecutor::dynamic_analysis() const { + MS_EXCEPTION_IF_NULL(dynamic_analysis_); + return dynamic_analysis_; +} + +TopCellInfoPtr GradExecutor::top_cell() const { + MS_EXCEPTION_IF_NULL(top_cell_); + return top_cell_; +} -void PynativeExecutor::PushCurrentCellOpInfoToStack() { +FuncGraphPtr GradExecutor::curr_g() const { + MS_EXCEPTION_IF_NULL(curr_g_); + return curr_g_; +} + +void GradExecutor::PushCurrentGraphToStack() { graph_stack_.push(curr_g_); } + +void GradExecutor::PushCurrentCellOpInfoToStack() { std::string cell_op_info = "Cell ops: "; cell_op_info_stack_.push(cell_op_info); } -void PynativeExecutor::PopGraphStack() { +void GradExecutor::PopGraphStack() { if (graph_stack_.empty()) { MS_LOG(EXCEPTION) << "Stack graph_stack_ is empty"; } @@ -1518,38 +1512,43 @@ void PynativeExecutor::PopGraphStack() { } } -void PynativeExecutor::PopCurrentCellOpInfoFromStack() { +void GradExecutor::PopCurrentCellOpInfoFromStack() { if (cell_op_info_stack_.empty()) { MS_LOG(EXCEPTION) << "The cell op info stack is empty"; } cell_op_info_stack_.pop(); } -std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &args) { +std::string GradExecutor::GetCellId(const py::object &cell, const py::args &args) { auto cell_id = GetId(cell); for (size_t i = 0; i < args.size(); i++) { std::string arg_id = GetId(args[i]); - auto it = node_abs_map_.find(arg_id); - if (it != node_abs_map_.end()) { + auto it = forward()->node_abs_map().find(arg_id); + if (it != forward()->node_abs_map().end()) { cell_id += "_" + it->second->BuildShape()->ToString(); cell_id += it->second->BuildType()->ToString(); } else { auto abs = PyAttrValue(args[i])->ToAbstract(); auto config = abstract::AbstractBase::kBroadenTensorOnly; abs = 
abs->Broaden(config); - node_abs_map_[arg_id] = abs; + forward()->node_abs_map()[arg_id] = abs; cell_id += "_" + abs->BuildShape()->ToString(); cell_id += abs->BuildType()->ToString(); } } - return GetTensorCellId(cell_id); + return cell_id; } -std::string PynativeExecutor::GetTensorCellId(const std::string &cell_id) { - if (cell_id.find("NoShape") == std::string::npos) { - return cell_id; +void GradExecutor::SetTopCellTensorId(const std::string &cell_id) { + // Get top cell id + if (top_cell()->cell_graph_list().empty()) { + return; + } + auto top_cell_id = top_cell()->cell_graph_list().front()->cell_id(); + if (top_cell_id.find("NoShape") == std::string::npos) { + return; } - std::string key = cell_id.substr(0, PTR_LEN); + std::string key = top_cell_id.substr(0, PTR_LEN); auto fn = [](const std::string &str, std::vector &value) { size_t pos = 0; size_t pre_pos = 0; @@ -1559,59 +1558,61 @@ std::string PynativeExecutor::GetTensorCellId(const std::string &cell_id) { } value.emplace_back(str.substr(pre_pos)); }; - auto it = - std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), [&key](const CellInfoPtr &value) { - return value->cell_id.find(key) != std::string::npos && value->cell_id.find("Tensor") != std::string::npos; - }); - if (it != cell_graph_list_.end()) { - std::vector pre_cell_id; - std::vector cur_cell_id; - fn((*it)->cell_id, pre_cell_id); - fn(cell_id, cur_cell_id); - auto pre_tensor_size = pre_cell_id.size(); - if (pre_tensor_size == cur_cell_id.size()) { - size_t same_tensor_count = 0; - for (size_t i = 0; i < pre_tensor_size; ++i) { - if (cur_cell_id[i].find("NoShape") != std::string::npos || cur_cell_id[i] == pre_cell_id[i]) { - ++same_tensor_count; - } - } - if (same_tensor_count == pre_tensor_size) { - MS_LOG(DEBUG) << "Changed cell id from " << cell_id << " to " << (*it)->cell_id; - return (*it)->cell_id; + std::vector pre_cell_id; + std::vector cur_cell_id; + fn(cell_id, cur_cell_id); + fn(top_cell_id, pre_cell_id); + auto pre_tensor_size = pre_cell_id.size(); + if (pre_tensor_size == cur_cell_id.size()) { + size_t same_tensor_count = 0; + for (size_t i = 0; i < pre_tensor_size; ++i) { + if (pre_cell_id[i].find("NoShape") != std::string::npos || cur_cell_id[i] == pre_cell_id[i]) { + ++same_tensor_count; } } + if (same_tensor_count == pre_tensor_size) { + MS_LOG(DEBUG) << "Changed cell id from " << top_cell_id << " to " << cell_id; + top_cell()->cell_graph_list().front()->set_cell_id(cell_id); + } } - return cell_id; } -void PynativeExecutor::DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph) { +void GradExecutor::DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph) { if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { DumpIR(filename, graph); } } -bool PynativeExecutor::IsNestedGrad() const { +bool GradExecutor::IsNestedGrad() const { MS_LOG(DEBUG) << "Grad nested order is " << grad_order_; return grad_order_ > 1; } -bool PynativeExecutor::IsTopGraph(const std::string &cell_id) { - return std::any_of(top_cell_list_.begin(), top_cell_list_.end(), - [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id; }); +bool GradExecutor::IsTopGraph(const std::string &cell_id) { + return std::any_of(top_cell_list_.begin(), top_cell_list_.end(), [&cell_id](const TopCellInfoPtr &value) { + return value->cell_id().find(cell_id) != std::string::npos; + }); } -bool PynativeExecutor::IsTopestGraph(const std::string &cell_id) { - return std::any_of(top_cell_list_.begin(), top_cell_list_.end(), - 
[&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id && value->is_topest; }); +bool GradExecutor::IsTopestGraph(const std::string &cell_id) { + return std::any_of(top_cell_list_.begin(), top_cell_list_.end(), [&cell_id](const TopCellInfoPtr &value) { + return value->cell_id() == cell_id && value->is_topest(); + }); +} + +bool GradExecutor::TopCellIsDynamic() { + if (top_cell_ == nullptr) { + return false; + } + return CheckRealDynamicCell(top_cell_id()); } -TopCellInfoPtr PynativeExecutor::GetTopCell(const string &cell_id, bool find_nearest) { +TopCellInfoPtr GradExecutor::GetTopCell(const string &cell_id, bool find_nearest) { auto find_top_cell = [&](const string &cell_id) -> TopCellInfoPtr { - auto iter = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), [&cell_id](const TopCellInfoPtr &top_cell) { - return cell_id == top_cell->cell_id && top_cell->is_topest; - }); - if (iter != top_cell_list_.end()) { + auto iter = std::find_if( + top_cell_list_.rbegin(), top_cell_list_.rend(), + [&cell_id](const TopCellInfoPtr &top_cell) { return cell_id == top_cell->cell_id() && top_cell->is_topest(); }); + if (iter != top_cell_list_.rend()) { return *iter; } return nullptr; @@ -1619,114 +1620,171 @@ TopCellInfoPtr PynativeExecutor::GetTopCell(const string &cell_id, bool find_nea TopCellInfoPtr top_cell = find_top_cell(cell_id); // find nearest top cell if (top_cell == nullptr && find_nearest) { - for (const auto &cell_info : cell_graph_list_) { - MS_EXCEPTION_IF_NULL(cell_info); - top_cell = find_top_cell(cell_info->cell_id); - if (cell_id == cell_info->cell_id) { - break; + for (auto it = top_cell_list_.begin(); it != top_cell_list_.end(); ++it) { + MS_EXCEPTION_IF_NULL(*it); + for (const auto &cell_info : (*it)->cell_graph_list()) { + MS_EXCEPTION_IF_NULL(cell_info); + if (cell_id == cell_info->cell_id()) { + return *it; + } } } } return top_cell; } -void PynativeExecutor::UpdateTopCellInfo(const std::string &cell_id, bool vm_compiled) { +void GradExecutor::UpdateTopCellInfo(const std::string &cell_id, bool vm_compiled) { auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), - [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id; }); + [&cell_id](const TopCellInfoPtr &value) { return value->cell_id() == cell_id; }); if (it != top_cell_list_.end()) { - (*it)->do_vm_compiled = vm_compiled; - (*it)->forward_already_run = false; - (*it)->need_grad = true; - if ((*it)->is_topest) { + (*it)->set_vm_compiled(vm_compiled); + (*it)->set_forward_already_run(false); + (*it)->set_need_grad(true); + (*it)->set_is_grad(true); + if ((*it)->is_topest()) { in_grad_process_ = false; - top_cell_index_ = 0; } } } -bool PynativeExecutor::IsBpropGraph(const std::string &cell_id) { - return std::any_of(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), - [&cell_id](const CellInfoPtr &value) { - return !value->bprop_cell_id.empty() && cell_id.find(value->bprop_cell_id) != std::string::npos; - }); +bool GradExecutor::IsBpropGraph(const std::string &cell_id) { + if (top_cell_ == nullptr) { + return false; + } + return std::any_of( + top_cell_->cell_graph_list().begin(), top_cell_->cell_graph_list().end(), [&cell_id](const CellInfoPtr &value) { + return !value->bprop_cell_id().empty() && cell_id.find(value->bprop_cell_id()) != std::string::npos; + }); } -bool PynativeExecutor::IsFirstGradStep(const std::string &cell_id) { return !CheckCellGraph(cell_id, true); } +bool GradExecutor::IsFirstGradStep() { return !top_cell()->is_grad(); } + +bool 
GradExecutor::IsGradBefore(const std::string &cell_id) { + return std::any_of(top_cell_list_.begin(), top_cell_list_.end(), [&cell_id](const TopCellInfoPtr &value) { + return value->cell_id() == cell_id && value->is_grad(); + }); +} -void PynativeExecutor::SubNestedGradOrder() { +void GradExecutor::SubNestedGradOrder() { if (grad_order_ > 0) { --grad_order_; } } -bool PynativeExecutor::CheckCellGraph(const std::string &cell_id, bool is_grad) { - return std::any_of(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), - [&cell_id, is_grad](const CellInfoPtr &value) { - return value->cell_id == cell_id && (!is_grad || value->is_grad); - }); +bool GradExecutor::CheckCellGraph(const std::string &cell_id) { + if (top_cell_ == nullptr) { + for (const auto &it : top_cell_list_) { + MS_EXCEPTION_IF_NULL(it); + if (it->cell_id() == cell_id) { + set_top_cell(it); + return true; + } + } + return false; + } else { + return std::any_of(top_cell_->cell_graph_list().begin(), top_cell_->cell_graph_list().end(), + [&cell_id](const CellInfoPtr &value) { return value->cell_id() == cell_id; }); + } } -bool PynativeExecutor::CheckDynamicCell(const std::string &cell_id) { - return std::any_of(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), - [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id && value->is_dynamic; }); +bool GradExecutor::CheckDynamicCell(const std::string &cell_id) { + if (top_cell_ == nullptr) { + return false; + } + return std::any_of( + top_cell_->cell_graph_list().begin(), top_cell_->cell_graph_list().end(), + [&cell_id](const CellInfoPtr &value) { return value->cell_id() == cell_id && value->is_dynamic(); }); } -bool PynativeExecutor::CheckRealDynamicCell(const std::string &cell_id) { - return std::any_of( - cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), - [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id && value->is_real_dynamic; }); +bool GradExecutor::CheckRealDynamicCell(const std::string &cell_id) { + if (top_cell_ == nullptr) { + return false; + } + return top_cell_->is_real_dynamic(); } -void PynativeExecutor::ClearResidualRes(const std::string &cell_id) { +void GradExecutor::ClearResidualRes(const std::string &cell_id) { // Abnormal case if (top_cell_list_.empty() && !graph_stack_.empty()) { - graph_id_ = 0; - graph_info_map_.clear(); - cell_graph_list_.clear(); + ClearCellRes(); std::stack().swap(graph_stack_); } - if (CheckRealDynamicCell(cell_id)) { - if (IsTopGraph(cell_id) && graph_stack_.empty() && !IsBpropGraph(cell_id)) { - // Clear previous step resource - auto resource = GetResource(cell_id); - if (resource != nullptr && resource->results().find(pipeline::kBackend) != resource->results().end()) { - compile::BackendPtr backend = resource->results()[pipeline::kBackend].cast(); - auto ms_backend = std::dynamic_pointer_cast(backend); - ms_backend->ClearSessionGraphs(); + if (pre_top_cell_ == nullptr || !graph_stack_.empty() || !IsTopGraph(cell_id) || IsBpropGraph(cell_id)) { + return; + } + auto is_real_dynamic = pre_top_cell_->is_real_dynamic(); + if (is_real_dynamic) { + // Clear previous step resource + auto resource = GetResource(cell_id); + if (resource != nullptr && resource->results().find(pipeline::kBackend) != resource->results().end()) { + compile::BackendPtr backend = resource->results()[pipeline::kBackend].cast(); + auto ms_backend = std::dynamic_pointer_cast(backend); + ms_backend->ClearSessionGraphs(); + } + } +} + +void GradExecutor::ClearCellRes(const std::string 
&cell_id) { + // Grad clean + if (cell_id.empty()) { + for (const auto &it : top_cell_list_) { + it->clear(); + } + return; + } + if (IsTopGraph(cell_id)) { + for (auto it = top_cell_list_.begin(); it != top_cell_list_.end();) { + if ((*it)->cell_id().find(cell_id) != std::string::npos) { + (*it)->clear(); + it = top_cell_list_.erase(it); + } else { + it++; + } + } + } else { + // Clear common cell id + for (const auto &it : top_cell_list_) { + MS_EXCEPTION_IF_NULL(it); + for (auto ic = it->cell_graph_list().begin(); ic != it->cell_graph_list().end();) { + if ((*ic)->cell_id().find(cell_id) != std::string::npos) { + ic = it->cell_graph_list().erase(ic); + } else { + ++ic; + } } } } } -FuncGraphPtr PynativeExecutor::GetDfbuilder(const std::string &cell_id) { +FuncGraphPtr GradExecutor::GetDfbuilder(const std::string &cell_id) { // If top graph hold for (auto it = top_cell_list_.rbegin(); it != top_cell_list_.rend(); ++it) { - if (cell_id.find((*it)->cell_id) != std::string::npos) { - return (*it)->df_builder; + if (cell_id.find((*it)->cell_id()) != std::string::npos) { + return (*it)->df_builder(); } } // Current cell is not top graph, get first top cell if (!top_cell_list_.empty()) { - return top_cell_list_.front()->df_builder; + return top_cell_list_.front()->df_builder(); } return nullptr; } -ResourcePtr PynativeExecutor::GetResource(const std::string &cell_id) { +ResourcePtr GradExecutor::GetResource(const std::string &cell_id) { for (auto it = top_cell_list_.rbegin(); it != top_cell_list_.rend(); ++it) { - if (cell_id.find((*it)->cell_id) != std::string::npos) { - return (*it)->resource; + if (cell_id.find((*it)->cell_id()) != std::string::npos) { + return (*it)->resource(); } } // Current cell is not top graph, get first top cell if (!top_cell_list_.empty()) { - return top_cell_list_.front()->resource; + return top_cell_list_.front()->resource(); } return nullptr; } -std::string PynativeExecutor::ParseNodeName(const std::shared_ptr &ast, const py::object &node, - parse::AstMainType type) { +std::string DynamicAnalysis::ParseNodeName(const std::shared_ptr &ast, const py::object &node, + parse::AstMainType type) { MS_EXCEPTION_IF_NULL(ast); if (py::isinstance(node)) { MS_LOG(DEBUG) << "Get none type node!"; @@ -1734,7 +1792,7 @@ std::string PynativeExecutor::ParseNodeName(const std::shared_ptrGetNodeType(node); MS_EXCEPTION_IF_NULL(node_type); - // check node type + // Check node type parse::AstMainType node_main_type = node_type->main_type(); if (node_main_type != type) { MS_LOG(ERROR) << "Node type is wrong: " << node_main_type << ", it should be " << type; @@ -1745,7 +1803,7 @@ std::string PynativeExecutor::ParseNodeName(const std::shared_ptr &ast, const py::object &fn_node) { +void DynamicAnalysis::ParseInputArgs(const std::shared_ptr &ast, const py::object &fn_node) { MS_EXCEPTION_IF_NULL(ast); py::list args = ast->GetArgs(fn_node); for (size_t i = 1; i < args.size(); i++) { @@ -1755,7 +1813,7 @@ void PynativeExecutor::ParseInputArgs(const std::shared_ptr &as } } -bool PynativeExecutor::ParseIfWhileExprNode(const std::shared_ptr &ast, const py::object &node) { +bool DynamicAnalysis::ParseIfWhileExprNode(const std::shared_ptr &ast, const py::object &node) { MS_LOG(DEBUG) << "Parse if/while expr"; py::object test_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TEST); const auto &node_name = ParseNodeName(ast, test_node, parse::AST_MAIN_TYPE_EXPR); @@ -1804,7 +1862,7 @@ bool PynativeExecutor::ParseIfWhileExprNode(const std::shared_ptr &ast, const py::object 
&node) { +bool DynamicAnalysis::ParseAssignExprNode(const std::shared_ptr &ast, const py::object &node) { MS_LOG(DEBUG) << "Parse assign expr"; py::object value_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_VALUE); const auto &node_name = ParseNodeName(ast, value_node, parse::AST_MAIN_TYPE_EXPR); @@ -1832,8 +1890,8 @@ bool PynativeExecutor::ParseAssignExprNode(const std::shared_ptr &ast, const py::object &node, - const std::vector &compare_prim) { +bool DynamicAnalysis::ParseAugAssignExprNode(const std::shared_ptr &ast, const py::object &node, + const std::vector &compare_prim) { MS_LOG(DEBUG) << "Parse augassign expr"; bool ret = false; if (compare_prim.empty()) { @@ -1860,7 +1918,7 @@ bool PynativeExecutor::ParseAugAssignExprNode(const std::shared_ptr &ast, const py::object &node) { +bool DynamicAnalysis::ParseForExprNode(const std::shared_ptr &ast, const py::object &node) { MS_LOG(DEBUG) << "Parse for expr"; py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY); if (py::isinstance(body_node)) { @@ -1880,8 +1938,8 @@ bool PynativeExecutor::ParseForExprNode(const std::shared_ptr & return false; } -bool PynativeExecutor::ParseBodyContext(const std::shared_ptr &ast, const py::object &fn_node, - const std::vector &compare_prim) { +bool DynamicAnalysis::ParseBodyContext(const std::shared_ptr &ast, const py::object &fn_node, + const std::vector &compare_prim) { MS_EXCEPTION_IF_NULL(ast); py::object func_obj = parse::python_adapter::GetPyObjAttr(fn_node, parse::NAMED_PRIMITIVE_BODY); if (py::isinstance(func_obj)) { @@ -1912,7 +1970,7 @@ bool PynativeExecutor::ParseBodyContext(const std::shared_ptr & return ret; } -std::string PynativeExecutor::GetCellInfo(const py::object &cell) { +std::string DynamicAnalysis::GetCellInfo(const py::object &cell) { if (py::isinstance(cell)) { auto c_cell = py::cast(cell); MS_EXCEPTION_IF_NULL(c_cell); @@ -1922,12 +1980,12 @@ std::string PynativeExecutor::GetCellInfo(const py::object &cell) { return ""; } -bool PynativeExecutor::IsDynamicCell(const py::object &cell) { +bool DynamicAnalysis::IsDynamicCell(const py::object &cell) { std::string cell_info = GetCellInfo(cell); if (ignore_judge_dynamic_cell.find(cell_info) != ignore_judge_dynamic_cell.end()) { return false; } - // using ast parse to check whether the construct of cell will be changed + // Using ast parse to check whether the construct of cell will be changed auto ast = std::make_shared(cell); bool success = ast->InitParseAstInfo(parse::PYTHON_MOD_GET_PARSE_METHOD); if (!success) { @@ -1944,7 +2002,7 @@ bool PynativeExecutor::IsDynamicCell(const py::object &cell) { return ret; } -void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) { +void GradExecutor::NewGraphInner(py::object *ret, const py::object &cell, const py::args &args) { auto cell_id = GetCellId(cell, args); MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << " " << cell_id; // check whether cell needed to construct grad graph @@ -1954,13 +2012,11 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg CleanPreMemoryInValueNode(); op_index_map_.clear(); in_grad_process_ = true; - in_bprop_process_ = false; auto top_cell = GetTopCell(cell_id, flag); MS_EXCEPTION_IF_NULL(top_cell); - top_cell_id_ = top_cell->cell_id; - top_cell_index_ = top_cell->top_cell_index; - top_cell->forward_already_run = true; - MS_LOG(DEBUG) << "Top cell id " << top_cell_id_; + top_cell->set_forward_already_run(true); + 
set_top_cell(top_cell); + MS_LOG(DEBUG) << "Top cell id " << top_cell->cell_id(); }; if (IsTopestGraph(cell_id) && cell_op_info_stack_.empty()) { init_fn(false); @@ -1977,6 +2033,7 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg ClearResidualRes(cell_id); if (graph_stack_.empty()) { if (IsBpropGraph(cell_id)) { + in_grad_process_ = true; in_bprop_process_ = true; } else { MakeNewTopGraph(cell_id, args); @@ -1984,9 +2041,9 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg } PushCurrentGraphToStack(); PushCurrentCellOpInfoToStack(); - if (graph_info_map_.find(curr_g_) == graph_info_map_.end()) { + if (top_cell()->graph_info_map().find(curr_g_) == top_cell()->graph_info_map().end()) { auto graph_info = std::make_shared(cell_id); - graph_info_map_[curr_g_] = graph_info; + top_cell()->graph_info_map()[curr_g_] = graph_info; } for (size_t i = 0; i < args.size(); ++i) { auto param = args[i]; @@ -1996,22 +2053,22 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg SetNodeMapInGraphInfoMap(curr_g_, param_id, new_param); SetParamNodeMapInGraphInfoMap(curr_g_, param_id, new_param); } - // Check whether the construct of cell will be changed + // Check whether the construct of cell is dynamic if (!has_dynamic_cell_) { - has_dynamic_cell_ = IsDynamicCell(cell); + has_dynamic_cell_ = dynamic_analysis()->IsDynamicCell(cell); MS_LOG(DEBUG) << "cell id: " << cell_id << ", is dynamic cell: " << has_dynamic_cell_; if (has_dynamic_cell_ && IsBpropGraph(cell_id)) { - auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), - [this](const CellInfoPtr &value) { return value->cell_id == top_cell_id_; }); - while (it != cell_graph_list_.end()) { - (*it)->is_dynamic = true; + auto it = std::find_if(top_cell()->cell_graph_list().begin(), top_cell()->cell_graph_list().end(), + [this](const CellInfoPtr &value) { return value->cell_id() == top_cell_id(); }); + while (it != top_cell()->cell_graph_list().end()) { + (*it)->set_is_dynamic(true); ++it; } } } } -void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args) { +void GradExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args) { for (const auto &arg : args) { if (py::isinstance(arg)) { auto tensor = arg.cast(); @@ -2020,59 +2077,65 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar } } } - // Clear resource in old top cell - if (CheckRealDynamicCell(cell_id)) { - VectorClear>(&top_cell_list_, cell_id); - } - CleanPreMemoryInValueNode(); - // Init resource for new top cell - if (!CheckCellGraph(cell_id)) { - has_dynamic_cell_ = false; - } - op_index_map_.clear(); - top_cell_id_ = cell_id; - in_grad_process_ = true; - // update forward already run flag with previous top cell + CleanPreMemoryInValueNode(); std::string input_args_id; for (size_t i = 0; i < args.size(); ++i) { input_args_id = input_args_id + GetId(args[i]) + "_"; } - auto pre_top_cell = GetTopCell(cell_id); - if (pre_top_cell != nullptr) { - pre_top_cell->forward_already_run = true; - pre_top_cell->input_args_id = input_args_id; + auto pre_dynamic_top_cell = GetTopCell(cell_id); + bool is_real_dynamic = false; + // Dynamic top cell is not nullptr + if (pre_dynamic_top_cell != nullptr) { + has_dynamic_cell_ = true; + // Clear top cell + if (pre_dynamic_top_cell->is_real_dynamic()) { + ClearCellRes(cell_id); + is_real_dynamic = true; + pre_dynamic_top_cell = nullptr; + } else { + 
pre_dynamic_top_cell->set_forward_already_run(true); + pre_dynamic_top_cell->set_input_args_id(input_args_id); + } + } else { + has_dynamic_cell_ = false; } + op_index_map_.clear(); + in_grad_process_ = true; + + // Init resource for new top cell auto df_builder = std::make_shared(); auto graph_info = std::make_shared(cell_id); - graph_info_map_[df_builder] = graph_info; auto resource = std::make_shared(); - resource->results()[pipeline::kPynativeGraphId] = graph_id_++; - auto top_cell_info = std::make_shared(true, resource, df_builder, cell_id); - top_cell_info->forward_already_run = true; - top_cell_info->input_args_id = input_args_id; - if (!IsTopestGraph(cell_id)) { - top_cell_info->top_cell_index = cell_graph_list_.size(); - top_cell_index_ = top_cell_info->top_cell_index; - } else { + auto new_top_cell = std::make_shared(true, resource, df_builder, cell_id); + new_top_cell->graph_info_map()[df_builder] = graph_info; + new_top_cell->set_forward_already_run(true); + new_top_cell->set_input_args_id(input_args_id); + if (pre_dynamic_top_cell != nullptr) { MS_LOG(DEBUG) << "Get dynamic top cell"; - auto top_cell = GetTopCell(cell_id, true); - MS_EXCEPTION_IF_NULL(top_cell); - top_cell_info->top_cell_index = top_cell->top_cell_index; - top_cell_index_ = top_cell_info->top_cell_index; + if (pre_dynamic_top_cell->is_grad()) { + new_top_cell->set_is_grad(true); + } + new_top_cell->set_cell_graph_list(pre_dynamic_top_cell->cell_graph_list()); + new_top_cell->set_graph_info_map(pre_dynamic_top_cell->graph_info_map()); } - top_cell_list_.emplace_back(top_cell_info); + if (is_real_dynamic) { + MS_LOG(DEBUG) << "Get real dynamic"; + new_top_cell->set_is_real_dynamic(true); + } + set_top_cell(new_top_cell); + top_cell_list_.emplace_back(new_top_cell); MS_LOG(DEBUG) << "New top graph, df_builder ptr " << df_builder.get() << " resource ptr " << resource.get(); } -std::string PynativeExecutor::GetCellOpInfo() { +std::string GradExecutor::GetCellOpInfo() { if (cell_op_info_stack_.empty()) { MS_LOG(EXCEPTION) << "The cell op info stack is empty"; } return cell_op_info_stack_.top(); } -void PynativeExecutor::ReplaceCellOpInfoByCellId(const std::string &cell_id) { +void GradExecutor::ReplaceCellOpInfoByCellId(const std::string &cell_id) { if (cell_id.empty()) { MS_LOG(EXCEPTION) << "The cell id is empty"; } @@ -2083,8 +2146,8 @@ void PynativeExecutor::ReplaceCellOpInfoByCellId(const std::string &cell_id) { cell_op_info_stack_.top() = cell_op_info_stack_.top() + cell_id; } -void PynativeExecutor::SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node, - bool is_param) { +void GradExecutor::SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node, + bool is_param) { if (!py::isinstance(args) && !py::isinstance(args)) { return; } @@ -2102,9 +2165,8 @@ void PynativeExecutor::SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const p } } -void PynativeExecutor::SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, - const AnfNodePtr &node, - const std::vector &index_sequence, bool is_param) { +void GradExecutor::SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node, + const std::vector &index_sequence, bool is_param) { if (!py::isinstance(args) && !py::isinstance(args)) { return; } @@ -2124,7 +2186,7 @@ void PynativeExecutor::SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, con } } -void PynativeExecutor::EndGraphInner(const py::object &cell, const 
py::object &out, const py::args &args) { +void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const py::object &out, const py::args &args) { const auto &cell_id = GetCellId(cell, args); MS_LOG(DEBUG) << "EndGraphInner start " << args.size() << " " << cell_id; if (graph_stack_.empty() && CheckCellGraph(cell_id) && !CheckDynamicCell(cell_id)) { @@ -2134,7 +2196,7 @@ void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &o } auto out_id = GetId(out); // x =op1, y =op2, return (x, y) - auto graph_info = graph_info_map_.at(curr_g_); + auto graph_info = top_cell()->graph_info_map().at(curr_g_); MS_EXCEPTION_IF_NULL(graph_info); if (graph_info->node_map.find(out_id) == graph_info->node_map.end()) { if (py::isinstance(out) || py::isinstance(out)) { @@ -2157,8 +2219,8 @@ void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &o EndGraphByOutId(cell, cell_id, out, out_id, args); } -void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out, - const std::string &out_id, const py::args &args) { +void GradExecutor::EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out, + const std::string &out_id, const py::args &args) { AnfNodePtr output_node = GetObjNode(out, out_id); curr_g_->set_output(output_node); MS_LOG(DEBUG) << "Current graph " << curr_g_->output()->DebugString(); @@ -2166,26 +2228,20 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string MS_LOG(DEBUG) << "Get bprop function cell"; return; } - auto resource = GetResource(top_cell_id_); + auto resource = GetResource(top_cell_id()); MS_EXCEPTION_IF_NULL(resource); resource->manager()->AddFuncGraph(curr_g_); UpdateCellGraph(cell, curr_g_, cell_id, true, false); FuncGraphPtr newfg = nullptr; - auto top_cell = GetTopCell(top_cell_id_); - MS_EXCEPTION_IF_NULL(top_cell); // Cell no Change if (CheckDynamicCell(cell_id) && !CheckCellChanged(cell_id)) { - MS_LOG(DEBUG) << "Cell is not dynamic, No need make ad grad"; - top_cell->need_grad = false; - std::unordered_set node_set; - ClearCnodeRes(curr_g_->output(), &node_set); - node_set.clear(); + MS_LOG(DEBUG) << "Cell is fake dynamic, no need make ad grad"; + top_cell()->set_need_grad(false); + ClearCnodeRes(curr_g_->output()); } else { MS_LOG(DEBUG) << "Need make ad grad"; - if (!top_cell->need_grad) { - std::unordered_set node_set; - ClearCnodeRes(curr_g_->output(), &node_set); - node_set.clear(); + if (!top_cell()->need_grad()) { + ClearCnodeRes(curr_g_->output()); } newfg = MakeGradGraph(cell, curr_g_, resource, cell_id, args); } @@ -2216,10 +2272,11 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string } PopGraphStack(); PopCurrentCellOpInfoFromStack(); + ClearDynamicTopRes(cell_id); } } -bool PynativeExecutor::EndBpropGraph(const string &cell_id) { +bool GradExecutor::EndBpropGraph(const string &cell_id) { auto is_bprop_graph = IsBpropGraph(cell_id); if (is_bprop_graph) { if (!IsNestedGrad()) { @@ -2232,9 +2289,9 @@ bool PynativeExecutor::EndBpropGraph(const string &cell_id) { return false; } -bool PynativeExecutor::CheckCellChanged(const std::string &cell_id) { +bool GradExecutor::CheckCellChanged(const std::string &cell_id) { bool res = false; - if (CheckRealDynamicCell(cell_id)) { + if (top_cell()->is_real_dynamic()) { MS_LOG(DEBUG) << "Cur cell " << cell_id << " is dynamic, no need check"; return true; } @@ -2242,75 +2299,71 @@ bool PynativeExecutor::CheckCellChanged(const 
std::string &cell_id) { MS_LOG(DEBUG) << "Cell op info is empty"; return true; } - auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), - [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); - if (it == cell_graph_list_.end() || IsFirstGradStep(top_cell_id_)) { + + auto it = std::find_if(top_cell()->cell_graph_list().begin(), top_cell()->cell_graph_list().end(), + [&cell_id](const CellInfoPtr &value) { return value->cell_id() == cell_id; }); + if (it == top_cell()->cell_graph_list().end() || IsFirstGradStep()) { return true; } - MS_LOG(DEBUG) << "Cell op info " << GetCellOpInfo() << ", old " << (*it)->cell_ops_info.at((*it)->call_times); - if ((*it)->cell_ops_info.at((*it)->call_times) != GetCellOpInfo()) { + MS_LOG(DEBUG) << "Cell op info " << GetCellOpInfo() << ", old " << (*it)->cell_ops_info().at((*it)->call_times()); + if ((*it)->cell_ops_info().at((*it)->call_times()) != GetCellOpInfo()) { res = true; - UpdateCellDynamic(cell_id); + top_cell()->set_is_real_dynamic(true); MS_LOG(DEBUG) << "Cell self changed"; } - (*it)->call_times = (*it)->call_times < (*it)->cell_ops_info.size() - 1 ? (*it)->call_times + 1 : 0; + if ((*it)->call_times() < ((*it)->cell_ops_info().size() - 1)) { + (*it)->set_call_times((*it)->call_times() + 1); + } else { + (*it)->set_call_times(0); + } return res; } -void PynativeExecutor::UpdateCellDynamic(const std::string &cell_id) { - for (auto it = cell_graph_list_.begin() + top_cell_index_; it != cell_graph_list_.end(); ++it) { - if ((*it)->cell_id != cell_id) { - (*it)->is_real_dynamic = true; - continue; - } - (*it)->is_real_dynamic = true; - break; +bool GradExecutor::UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, + bool need_cloned, bool is_grad) { + if (!py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { + return false; } -} - -bool PynativeExecutor::UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, - bool need_cloned, bool is_grad) { auto update_in_endgraph = need_cloned && !is_grad; - if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { - // Bprop just save backward graph - auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), - [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); - if (it != cell_graph_list_.end()) { - (*it)->is_grad = is_grad; - if (g != (*it)->fg) { - graph_info_map_.update((*it)->fg, g); - (*it)->fg = g; - } - if (update_in_endgraph && IsFirstGradStep(top_cell_id_)) { - (*it)->cell_ops_info.emplace_back(GetCellOpInfo()); - } - MS_LOG(DEBUG) << "Update bprop bg cell id " << cell_id; + // Bprop just save backward graph + auto it = std::find_if(top_cell()->cell_graph_list().begin(), top_cell()->cell_graph_list().end(), + [&cell_id](const CellInfoPtr &value) { return value->cell_id() == cell_id; }); + if (it != top_cell()->cell_graph_list().end()) { + if (top_cell_id() == cell_id) { + top_cell()->set_is_grad(is_grad); + } + if (g != (*it)->fg()) { + top_cell()->graph_info_map().update((*it)->fg(), g); + (*it)->set_fg(g); + } + if (update_in_endgraph && IsFirstGradStep()) { + (*it)->cell_ops_info().emplace_back(GetCellOpInfo()); + } + MS_LOG(DEBUG) << "Update bprop bg cell id " << cell_id; + } else { + py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME); + auto bprop_func_cell_id = GetId(bprop_func); + MS_LOG(DEBUG) << "Add new bprop cell_id " << cell_id << " bprop func cell id " << bprop_func_cell_id + << " cell 
ops info " << GetCellOpInfo(); + auto cell_info = std::make_shared(true, has_dynamic_cell_, g, cell_id, bprop_func_cell_id); + cell_info->cell_ops_info().emplace_back(GetCellOpInfo()); + if (in_bprop_process_) { + top_cell()->cell_graph_list().emplace_back(cell_info); } else { - py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME); - auto bprop_func_cell_id = GetId(bprop_func); - MS_LOG(DEBUG) << "Add new bprop cell_id " << cell_id << " bprop func cell id " << bprop_func_cell_id - << " cell ops info " << GetCellOpInfo(); - auto cell_info = std::make_shared(true, has_dynamic_cell_, g, cell_id, bprop_func_cell_id); - cell_info->cell_ops_info.emplace_back(GetCellOpInfo()); - if (in_bprop_process_) { - cell_graph_list_.emplace_back(cell_info); - } else { - cell_graph_list_.insert(cell_graph_list_.begin() + top_cell_index_, cell_info); - } + top_cell()->cell_graph_list().insert(top_cell()->cell_graph_list().begin(), cell_info); } - return true; } - return false; + return true; } -void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, - bool need_cloned, bool is_grad) { +void GradExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, + bool need_cloned, bool is_grad) { auto update_in_endgraph = need_cloned && !is_grad; if (UpdateBpropCellGraph(cell, g, cell_id, need_cloned, is_grad)) { return; } FuncGraphPtr tmp = g; - if (!IsFirstGradStep(top_cell_id_) && CheckDynamicCell(cell_id) && !CheckRealDynamicCell(cell_id)) { + if (!IsFirstGradStep() && CheckDynamicCell(cell_id) && !CheckRealDynamicCell(cell_id)) { MS_LOG(DEBUG) << "No need cloned"; need_cloned = false; } @@ -2319,77 +2372,77 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt return; } tmp = BasicClone(g); - graph_info_map_.update(g, tmp); - std::unordered_set node_set; - ClearCnodeRes(tmp->output(), &node_set); - node_set.clear(); + top_cell()->graph_info_map().update(g, tmp); + ClearCnodeRes(tmp->output()); }; // First call or cell id not exist - if (update_in_endgraph && (IsFirstGradStep(top_cell_id_) || !CheckCellGraph(cell_id))) { + if (update_in_endgraph && (IsFirstGradStep() || !CheckCellGraph(cell_id))) { if (!CheckCellGraph(cell_id)) { clone_fn(); MS_LOG(DEBUG) << "Add new cell with cloned graph " << cell_id << " cell ops info " << GetCellOpInfo(); auto cell_info = std::make_shared(true, has_dynamic_cell_, tmp, cell_id, ""); - cell_info->cell_ops_info.emplace_back(GetCellOpInfo()); + cell_info->cell_ops_info().emplace_back(GetCellOpInfo()); if (in_bprop_process_) { - cell_graph_list_.emplace_back(cell_info); + top_cell()->cell_graph_list().emplace_back(cell_info); } else { - cell_graph_list_.insert(cell_graph_list_.begin() + top_cell_index_, cell_info); + top_cell()->cell_graph_list().insert(top_cell()->cell_graph_list().begin(), cell_info); } } else { - auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), - [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); - if (it != cell_graph_list_.end()) { - (*it)->cell_ops_info.emplace_back(GetCellOpInfo()); + auto it = std::find_if(top_cell()->cell_graph_list().begin(), top_cell()->cell_graph_list().end(), + [&cell_id](const CellInfoPtr &value) { return value->cell_id() == cell_id; }); + if (it != top_cell()->cell_graph_list().end()) { + (*it)->cell_ops_info().emplace_back(GetCellOpInfo()); } MS_LOG(DEBUG) << "Add another same cell ops info"; } return; } - for (auto it = 
cell_graph_list_.begin() + top_cell_index_; it != cell_graph_list_.end(); ++it) { - if ((*it)->cell_id != cell_id) { + for (auto &it : top_cell()->cell_graph_list()) { + if (it->cell_id() != cell_id) { continue; } - if (IsFirstGradStep(cell_id)) { - // no compute grad - (*it)->is_grad = is_grad; - } + it->set_is_dynamic(has_dynamic_cell_); if (need_cloned) { clone_fn(); - if ((*it)->fg != nullptr) { - graph_info_map_.erase((*it)->fg); + if (it->fg() != nullptr) { + top_cell()->graph_info_map().erase(it->fg()); } - MS_LOG(DEBUG) << "Update cur graph " << (*it)->fg.get() << " with cloned new " << tmp.get(); - (*it)->fg = tmp; + MS_LOG(DEBUG) << "Update cur graph " << it->fg().get() << " with cloned new " << tmp.get(); + it->set_fg(tmp); } if (!need_cloned && !is_grad) { - graph_info_map_.erase((*it)->fg); - MS_LOG(DEBUG) << "Update cur graph " << (*it)->fg.get() << " with new " << tmp.get(); - (*it)->fg = tmp; + top_cell()->graph_info_map().erase(it->fg()); + MS_LOG(DEBUG) << "Update cur graph " << it->fg().get() << " with new " << tmp.get(); + it->set_fg(tmp); } break; } } -void PynativeExecutor::ClearCnodeRes(const AnfNodePtr &node, std::unordered_set *node_set) { +void GradExecutor::ClearCnodeRes(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(node_set); - if (!node->isa() || (*node_set).find(node) != (*node_set).end()) { - return; - } - (*node_set).insert(node); - auto cnode = node->cast(); - cnode->clear_inputs_value(); - cnode->set_forward(nullptr, ""); - for (size_t i = 0; i < cnode->size(); ++i) { - auto n = cnode->input(i); - ClearCnodeRes(n, node_set); - } + std::unordered_set node_set; + std::function fn; + fn = [&fn, &node_set](const AnfNodePtr &node) { + if (!node->isa() || node_set.find(node) != node_set.end()) { + return; + } + node_set.insert(node); + auto cnode = node->cast(); + cnode->clear_inputs_value(); + cnode->set_forward(nullptr, ""); + for (size_t i = 0; i < cnode->size(); ++i) { + auto n = cnode->input(i); + fn(n); + } + }; + fn(node); + node_set.clear(); } -FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const FuncGraphPtr &g, const ResourcePtr &r, - const std::string &cell_id, const py::args &args) { +FuncGraphPtr GradExecutor::MakeGradGraph(const py::object &cell, const FuncGraphPtr &g, const ResourcePtr &r, + const std::string &cell_id, const py::args &args) { bool is_custom_bprop = py::hasattr(cell, parse::CUSTOM_BPROP_NAME); if (is_custom_bprop) { size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size(); @@ -2406,6 +2459,7 @@ FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const FuncG } auto is_top = IsTopGraph(cell_id); MS_LOG(DEBUG) << "Grad top cell " << is_top; + DumpGraphIR("fg.ir", g); // Before make grad graph, we need to run auto-monad on forward graph, // so that side effects in forward graph can be handled in grad graph. 
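
The ClearCnodeRes rewrite above folds the old recursive member function into a self-referencing lambda plus a visited set. A minimal standalone sketch of that traversal pattern, using a hypothetical Node type instead of the real AnfNode/CNode classes:

    #include <functional>
    #include <memory>
    #include <unordered_set>
    #include <vector>

    // Hypothetical stand-in for AnfNode/CNode: a graph node with shared inputs.
    struct Node {
      std::vector<std::shared_ptr<Node>> inputs;
      void ClearCachedValue() { /* drop forward-pass caches here */ }
    };
    using NodePtr = std::shared_ptr<Node>;

    void ClearNodeRes(const NodePtr &root) {
      std::unordered_set<Node *> visited;
      std::function<void(const NodePtr &)> fn = [&fn, &visited](const NodePtr &node) {
        // Skip null nodes and nodes already cleared through another path.
        if (node == nullptr || !visited.insert(node.get()).second) {
          return;
        }
        node->ClearCachedValue();
        for (const auto &input : node->inputs) {
          fn(input);
        }
      };
      fn(root);
    }

Capturing fn by reference inside its own initializer is what lets the lambda recurse; the visited set keeps the walk linear even when nodes share inputs.
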
(void)pipeline::AutoMonad(g); @@ -2430,8 +2484,8 @@ FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const FuncG return newfg; } -std::string PynativeExecutor::GetGradCellId(bool has_sens, const py::object &cell, const py::args &args, - py::object *forward_args, py::object *sens) { +std::string GradExecutor::GetGradCellId(bool has_sens, const py::object &cell, const py::args &args, + py::object *forward_args, py::object *sens) { auto size = args.size(); size_t forward_args_size = size; if (has_sens) { @@ -2451,7 +2505,7 @@ std::string PynativeExecutor::GetGradCellId(bool has_sens, const py::object &cel return cell_id; } -void PynativeExecutor::SaveAllValueNodeTensors(const FuncGraphPtr &graph) { +void GradExecutor::SaveAllValueNodeTensors(const FuncGraphPtr &graph) { std::unordered_set all_value_node_tensors; auto trace_function = [&all_value_node_tensors](const AnfNodePtr &anf_node) { auto value = GetValueNode(anf_node); @@ -2482,17 +2536,18 @@ void PynativeExecutor::SaveAllValueNodeTensors(const FuncGraphPtr &graph) { all_value_node_tensors_ = all_value_node_tensors; } -void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, - const py::args &args) { +void GradExecutor::GradNetInner(py::object *ret, const GradOperationPtr &grad, const py::object &cell, + const py::object &weights, const py::args &args) { auto size = args.size(); py::object sens = py::none(); py::object forward_args = args; const auto &cell_id = GetGradCellId(grad->sens_param(), cell, args, &forward_args, &sens); + SetTopCellTensorId(cell_id); MS_LOG(DEBUG) << "GradNet start " << size << " " << cell_id; const auto ¶ms_changed = CheckGradParamsChanged(cell_id, weights, sens); - if (!params_changed && !IsFirstGradStep(cell_id) && !CheckRealDynamicCell(cell_id)) { + if (!params_changed && IsGradBefore(cell_id) && !CheckRealDynamicCell(cell_id)) { UpdateTopCellInfo(cell_id, false); - ClearDynamicTopRes(cell_id, nullptr); + op_index_map_.clear(); MS_LOG(INFO) << "Gradgraph already compiled"; return; } @@ -2518,7 +2573,7 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje resource->set_args_spec(args_spec); // Get real grad graph DumpGraphIR("before_grad.ir", resource->func_graph()); - GradGraph(resource->func_graph(), grad, w_args, size, cell_id); + SetGradGraph(resource->func_graph(), grad, w_args, size, cell_id); DumpGraphIR("after_grad.ir", df_builder); resource->set_func_graph(df_builder); resource->manager()->KeepRoots({df_builder}); @@ -2534,48 +2589,49 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje TaskEmitAction(resource); ExecuteAction(resource); ClearUselessRes(df_builder, cell, cell_id); - UpdateCellGraph(cell, curr_g_, cell_id, false, true); UpdateTopCellInfo(cell_id, true); - ClearDynamicTopRes(cell_id, df_builder); resource->Clean(); } -void PynativeExecutor::ClearDynamicTopRes(const std::string &cell_id, const FuncGraphPtr &df_builder) { - if (df_builder == nullptr && IsTopestGraph(cell_id)) { - op_index_map_.clear(); - } +void GradExecutor::ClearDynamicTopRes(const std::string &cell_id) { // Delete unused top cell resource if (!CheckDynamicCell(cell_id)) { return; } - int same_top_cell_count = 0; - for (auto it = top_cell_list_.begin(); it != top_cell_list_.end();) { - // High order should exclude - if (graph_stack_.empty() && df_builder != nullptr && (*it)->df_builder.get() != df_builder.get()) { - MS_LOG(DEBUG) << "Delete cell id " << (*it)->cell_id; - it = 
top_cell_list_.erase(it);
+  auto count = std::count_if(top_cell_list_.begin(), top_cell_list_.end(),
+                             [&cell_id](const TopCellInfoPtr &value) { return value->cell_id() == cell_id; });
+  if (count < 2) {
+    return;
+  }
+  // Keep only one dynamic top cell
+  bool is_second_find = false;
+  for (auto it = top_cell_list_.begin(); it != top_cell_list_.end(); ++it) {
+    if ((*it)->cell_id() != cell_id) {
       continue;
     }
-    if ((*it)->cell_id == cell_id) {
-      ++same_top_cell_count;
-      if (same_top_cell_count > 1) {
-        graph_info_map_.erase((*it)->df_builder);
+
+    if (top_cell()->is_real_dynamic()) {
+      MS_LOG(DEBUG) << "Real dynamic, delete first dynamic top cell";
+      (*it)->clear();
+      it = top_cell_list_.erase(it);
+      break;
+    } else {
+      if (is_second_find) {
+        MS_LOG(DEBUG) << "Fake dynamic, delete second dynamic top cell";
+        (*it)->clear();
         it = top_cell_list_.erase(it);
-        --same_top_cell_count;
-      } else {
-        ++it;
+        break;
       }
-    } else {
-      ++it;
+      is_second_find = true;
     }
   }
 }

-bool PynativeExecutor::CheckGradParamsChanged(const std::string &cell_id, const py::object &weights,
-                                              const py::object &sens) {
+bool GradExecutor::CheckGradParamsChanged(const std::string &cell_id, const py::object &weights,
+                                          const py::object &sens) {
   bool res = false;
   auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
-                         [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id; });
+                         [&cell_id](const TopCellInfoPtr &value) { return value->cell_id() == cell_id; });
   if (it == top_cell_list_.end()) {
     return res;
   }
@@ -2600,46 +2656,44 @@ bool PynativeExecutor::CheckGradParamsChanged(const std::string &cell_id, const
     sens_id = fn(sens);
   }

-  if (!(*it)->sens_id.empty() && (*it)->sens_id != sens_id) {
-    (*it)->sens_id = sens_id;
+  if (!(*it)->sens_id().empty() && (*it)->sens_id() != sens_id) {
+    (*it)->set_sens_id(sens_id);
   }
   std::string weights_id = fn(weights);
-  if (!(*it)->weights_id.empty() && (*it)->weights_id != weights_id) {
-    (*it)->weights_id = weights_id;
+  if (!(*it)->weights_id().empty() && (*it)->weights_id() != weights_id) {
+    (*it)->set_weights_id(weights_id);
     res = true;
   }
   return res;
 }

-void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id) {
+void GradExecutor::SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id) {
   if (IsTopGraph(cell_id)) {
-    VectorClear<std::vector<TopCellInfoPtr>>(&top_cell_list_, cell_id);
+    ClearCellRes(cell_id);
   }
   ResourcePtr resource = nullptr;
   auto ia = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
-                         [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id; });
+                         [&cell_id](const TopCellInfoPtr &value) { return value->cell_id() == cell_id; });
   if (ia != top_cell_list_.end()) {
-    resource = GetResource((*ia)->cell_id);
+    resource = GetResource((*ia)->cell_id());
     MS_EXCEPTION_IF_NULL(resource);
     MS_LOG(DEBUG) << "Find old resource " << resource.get();
   }
   if (resource == nullptr) {
     resource = std::make_shared<pipeline::Resource>();
-    resource->results()[pipeline::kPynativeGraphId] = graph_id_++;
     MS_LOG(DEBUG) << "Make new resource " << resource.get();
   }
   MS_EXCEPTION_IF_NULL(resource);
   FuncGraphPtr df_builder = std::make_shared<FuncGraph>();
   auto graph_info = std::make_shared<GraphInfo>(cell_id);
-  graph_info_map_[df_builder] = graph_info;
   auto top_cell_info = std::make_shared<TopCellInfo>(false, resource, df_builder, cell_id);
-  top_cell_info->top_cell_index = top_cell_index_;
+  top_cell()->graph_info_map()[df_builder] = graph_info;
   top_cell_list_.emplace_back(top_cell_info);
   FuncGraphPtr forward_graph = nullptr;
-  auto 
ib = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), - [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); - if (ib != cell_graph_list_.end()) { - forward_graph = (*ib)->fg; + auto ib = std::find_if(top_cell()->cell_graph_list().begin(), top_cell()->cell_graph_list().end(), + [&cell_id](const CellInfoPtr &value) { return value->cell_id() == cell_id; }); + if (ib != top_cell()->cell_graph_list().end()) { + forward_graph = (*ib)->fg(); } MS_EXCEPTION_IF_NULL(forward_graph); if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { @@ -2659,27 +2713,27 @@ void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args resource->set_func_graph(newfg); } -void PynativeExecutor::ReplaceGraphParams(const FuncGraphPtr &df_builder, const FuncGraphPtr &forward_graph, - const std::string &cell_id) { +void GradExecutor::ReplaceGraphParams(const FuncGraphPtr &df_builder, const FuncGraphPtr &forward_graph, + const std::string &cell_id) { std::vector graph_before{}; bool index_find = false; - for (auto it = cell_graph_list_.begin() + top_cell_index_; it != cell_graph_list_.end(); ++it) { - if (IsBpropGraph((*it)->cell_id) || (*it)->fg == nullptr) { + for (const auto &it : top_cell()->cell_graph_list()) { + if (IsBpropGraph(it->cell_id()) || it->fg() == nullptr) { continue; } if (index_find) { - graph_before.emplace_back((*it)->fg); + graph_before.emplace_back(it->fg()); continue; } - if ((*it)->cell_id == cell_id) { + if (it->cell_id() == cell_id) { index_find = true; - graph_before.emplace_back((*it)->fg); + graph_before.emplace_back(it->fg()); } } auto manager = Manage({forward_graph}, false); for (const auto &f : graph_before) { - auto graph_info = graph_info_map_.at(f); + auto graph_info = top_cell()->graph_info_map().at(f); MS_EXCEPTION_IF_NULL(graph_info); for (const auto &it : graph_info->params) { if (!it.second->has_default()) { @@ -2695,7 +2749,7 @@ void PynativeExecutor::ReplaceGraphParams(const FuncGraphPtr &df_builder, const replace_weights_map_[forward_graph].emplace_back(std::make_pair(it.second, new_param)); MS_LOG(DEBUG) << "Param name " << new_param->name() << " ptr " << new_param.get(); - auto graph_info_of_df_builder = graph_info_map_.at(df_builder); + auto graph_info_of_df_builder = top_cell()->graph_info_map().at(df_builder); MS_EXCEPTION_IF_NULL(graph_info_of_df_builder); graph_info_of_df_builder->params[it.first] = new_param; SetParamNodeMapInGraphInfoMap(df_builder, it.first, new_param); @@ -2705,7 +2759,7 @@ void PynativeExecutor::ReplaceGraphParams(const FuncGraphPtr &df_builder, const graph_before.clear(); } -void PynativeExecutor::SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size) { +void GradExecutor::SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size) { std::vector new_params; for (size_t i = 0; i < size; i++) { ParameterPtr p = std::make_shared(df_builder); @@ -2718,7 +2772,7 @@ void PynativeExecutor::SetGradGraphParams(const FuncGraphPtr &df_builder, const resource->manager()->SetParameters(df_builder, new_params); } -std::vector PynativeExecutor::GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder) { +std::vector GradExecutor::GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder) { std::vector w_args; if (!py::hasattr(weights, "__parameter_tuple__")) { MS_LOG(DEBUG) << "No paramter_tuple get"; @@ -2731,7 +2785,7 @@ std::vector PynativeExecutor::GetWeightsArgs(const py::object &weigh 
auto param = tuple[it]; auto param_id = GetId(param); AnfNodePtr para_node = nullptr; - auto graph_info = graph_info_map_.at(df_builder); + auto graph_info = top_cell()->graph_info_map().at(df_builder); MS_EXCEPTION_IF_NULL(graph_info); if (graph_info->params.find(param_id) != graph_info->params.end() && graph_info->node_map.find(param_id) != graph_info->node_map.end()) { @@ -2754,7 +2808,7 @@ std::vector PynativeExecutor::GetWeightsArgs(const py::object &weigh return w_args; } -abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args, const FuncGraphPtr &df_builder) { +abstract::AbstractBasePtrList GradExecutor::GetArgsSpec(const py::args &args, const FuncGraphPtr &df_builder) { abstract::AbstractBasePtrList args_spec; std::size_t size = args.size(); auto df_params = df_builder->parameters(); @@ -2789,13 +2843,13 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args return args_spec; } -void PynativeExecutor::GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, - const std::vector &weights, size_t arg_size, const std::string &cell_id) { +void GradExecutor::SetGradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, + const std::vector &weights, size_t arg_size, const std::string &cell_id) { FuncGraphPtr top_g = nullptr; - auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), - [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); - if (it != cell_graph_list_.end()) { - top_g = (*it)->fg; + auto it = std::find_if(top_cell()->cell_graph_list().begin(), top_cell()->cell_graph_list().end(), + [&cell_id](const CellInfoPtr &value) { return value->cell_id() == cell_id; }); + if (it != top_cell()->cell_graph_list().end()) { + top_g = (*it)->fg(); } MS_EXCEPTION_IF_NULL(top_g); auto nparam = top_g->parameters().size(); @@ -2824,52 +2878,39 @@ void PynativeExecutor::GradGraph(const FuncGraphPtr &g, const GradOperationPtr & resource->manager()->AddFuncGraph(df_builder); } -void PynativeExecutor::ClearUselessRes(const FuncGraphPtr &df_builder, const py::object &cell, - const std::string &cell_id) { - graph_info_map_.erase(df_builder); +void GradExecutor::ClearUselessRes(const FuncGraphPtr &df_builder, const py::object &cell, const std::string &cell_id) { + top_cell()->graph_info_map().erase(df_builder); bool has_custom_bprop = py::hasattr(cell, parse::CUSTOM_BPROP_NAME); - bool is_dynamic_top_fist_grad = CheckDynamicCell(cell_id) && IsFirstGradStep(cell_id); + bool is_dynamic_top_fist_grad = CheckDynamicCell(cell_id) && IsFirstGradStep(); bool is_topmost = IsTopestGraph(cell_id); if (has_custom_bprop || is_dynamic_top_fist_grad || !is_topmost) { return; } MS_LOG(DEBUG) << "Update topmost cell graph list and graph info map"; - // Clear graph_info_map_ + // Clear grad()->top_cell()->graph_info_map() std::vector l{}; bool index_find = false; - auto it_end = cell_graph_list_.end(); - for (size_t i = 0; i < top_cell_list_.size(); ++i) { - if (top_cell_list_[i]->cell_id == cell_id) { - index_find = true; - continue; - } - if (index_find) { - it_end = cell_graph_list_.begin() + top_cell_list_[i]->top_cell_index; - break; - } - } - index_find = false; - for (auto it = cell_graph_list_.begin() + top_cell_index_; it != it_end; ++it) { - if ((*it)->fg != nullptr) { - std::unordered_set node_set; - ClearCnodeRes((*it)->fg->output(), &node_set); - node_set.clear(); - (*it)->fg = nullptr; + for (auto &it : top_cell()->cell_graph_list()) { + if (it->fg() != nullptr) { + 
ClearCnodeRes(it->fg()->output()); + it->set_fg(nullptr); } if (index_find) { - l.emplace_back((*it)->cell_id); + it->set_fg(nullptr); + l.emplace_back(it->cell_id()); continue; } - if ((*it)->cell_id == cell_id) { + if (it->cell_id() == cell_id) { index_find = true; - l.emplace_back((*it)->cell_id); + it->set_fg(nullptr); + l.emplace_back(it->cell_id()); } } for (const auto &it : l) { - for (auto ic = graph_info_map_.begin(); ic != graph_info_map_.end();) { + for (auto ic = top_cell()->graph_info_map().begin(); ic != top_cell()->graph_info_map().end();) { if (ic->second->cell_id.find(it) != std::string::npos) { - ic = graph_info_map_.erase(ic); + ic = top_cell()->graph_info_map().erase(ic); } else { ++ic; } @@ -2877,7 +2918,7 @@ void PynativeExecutor::ClearUselessRes(const FuncGraphPtr &df_builder, const py: } } -py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &args) { +py::object GradExecutor::CheckGraph(const py::object &cell, const py::args &args) { BaseRef ret = false; AddNestedGradOrder(); if (!grad_running()) { @@ -2887,14 +2928,17 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args & const auto &cell_id = GetCellId(cell, args); std::string key = cell_id.substr(0, std::min(PTR_LEN, cell_id.size())); MS_LOG(DEBUG) << "Key is " << key; - for (auto it = cell_graph_list_.begin() + top_cell_index_; it != cell_graph_list_.end(); ++it) { - MS_LOG(DEBUG) << "Cur cell id " << (*it)->cell_id; - if (key != (*it)->cell_id.substr(0, std::min(PTR_LEN, (*it)->cell_id.size()))) { + for (auto it = top_cell()->cell_graph_list().begin(); it != top_cell()->cell_graph_list().end(); ++it) { + MS_LOG(DEBUG) << "Cur cell id " << (*it)->cell_id(); + if (key != (*it)->cell_id().substr(0, std::min(PTR_LEN, (*it)->cell_id().size()))) { continue; } MS_LOG(DEBUG) << "Delete cellid from cell graph list"; - graph_info_map_.erase((*it)->fg); - cell_graph_list_.erase(it); + top_cell()->graph_info_map().erase((*it)->fg()); + top_cell()->cell_graph_list().erase(it); + if (IsTopestGraph(cell_id)) { + ClearCellRes(cell_id); + } ret = true; break; } @@ -2903,44 +2947,39 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args & py::object PynativeExecutor::CheckAlreadyRun(const py::object &cell, const py::args &args) { bool forward_run = false; - const auto &cell_id = GetCellId(cell, args); - // Checkout whether top cell has already run. 
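
The CheckAlreadyRun changes that follow hinge on comparing a concatenation of per-argument ids against the id string recorded on the top cell. A minimal sketch of that comparison, with hypothetical MakeInputArgsId/InputsChanged helpers standing in for the inlined loop:

    #include <string>
    #include <vector>

    // Build the "<id>_<id>_..." signature the patch derives from GetId(args[i]).
    std::string MakeInputArgsId(const std::vector<std::string> &arg_ids) {
      std::string input_args_id;
      for (const auto &id : arg_ids) {
        input_args_id += id + "_";
      }
      return input_args_id;
    }

    // A dynamic cell must rerun its forward pass when the recorded signature
    // no longer matches; an empty recorded signature means no run happened yet.
    bool InputsChanged(const std::string &recorded, const std::string &current) {
      return !recorded.empty() && recorded != current;
    }
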
+  const auto &cell_id = grad_executor()->GetCellId(cell, args);
   std::string input_args_id;
   for (size_t i = 0; i < args.size(); ++i) {
     input_args_id = input_args_id + GetId(args[i]) + "_";
   }
-  auto top_cell = GetTopCell(cell_id);
+  auto top_cell = grad_executor()->GetTopCell(cell_id);
   if (top_cell != nullptr) {
-    if (!top_cell->input_args_id.empty() && top_cell->input_args_id != input_args_id && top_cell->forward_already_run &&
-        CheckDynamicCell(cell_id)) {
+    if (!top_cell->input_args_id().empty() && top_cell->input_args_id() != input_args_id &&
+        top_cell->forward_already_run() && top_cell->has_dynamic_cell()) {
       MS_LOG(WARNING) << "The construct of running cell is dynamic and the input info of this cell has changed, "
                          "forward process will run again";
-      top_cell->forward_already_run = false;
-      top_cell->input_args_id = input_args_id;
+      top_cell->set_forward_already_run(false);
+      top_cell->set_input_args_id(input_args_id);
     } else {
-      forward_run = top_cell->forward_already_run;
+      forward_run = top_cell->forward_already_run() && !top_cell->is_real_dynamic();
     }
     if (forward_run) {
-      top_cell_index_ = top_cell->top_cell_index;
+      grad_executor()->set_top_cell(top_cell);
     }
+    MS_LOG(DEBUG) << "Top cell id " << top_cell->cell_id();
   }
-  MS_LOG(DEBUG) << "Graph have already run " << forward_run << " cell id " << cell_id << " top_cell_id_ "
-                << top_cell_id_;
+  MS_LOG(DEBUG) << "Graph has already run " << forward_run << " cell id " << cell_id;
   return BaseRefToPyData(forward_run);
 }

-void PynativeExecutor::RunInner(const py::object &cell, const py::tuple &args, const py::object &phase,
-                                py::object *ret) {
+void GradExecutor::RunGradGraph(py::object *ret, const py::object &cell, const py::tuple &args,
+                                const py::object &phase) {
   MS_EXCEPTION_IF_NULL(ret);
   auto cell_id = GetCellId(cell, args);
   MS_LOG(DEBUG) << "Run start cell id " << cell_id;
-  bool has_sens = false;
-  for (const auto &it : top_cell_list_) {
-    if (cell_id.find(it->cell_id) != std::string::npos && cell_id != it->cell_id) {
-      has_sens = true;
-      break;
-    }
-  }
+  auto has_sens = std::any_of(top_cell_list_.begin(), top_cell_list_.end(), [&cell_id](const TopCellInfoPtr &value) {
+    return cell_id.find(value->cell_id()) != std::string::npos && cell_id != value->cell_id();
+  });
   py::object forward_args = args;
   cell_id = GetGradCellId(has_sens, cell, args, &forward_args);
   MS_LOG(DEBUG) << "Run has sens " << has_sens << " forward cell id " << cell_id;
@@ -2967,9 +3006,9 @@ void PynativeExecutor::RunInner(const py::object &cell, const py::tuple &args, c
   set_grad_runing(false);
   MS_LOG(DEBUG) << "Eval run end " << value.ToString();
   *ret = BaseRefToPyData(value);
-  auto do_vm_compiled =
-      std::any_of(top_cell_list_.begin(), top_cell_list_.end(),
-                  [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id && value->do_vm_compiled; });
+  auto do_vm_compiled = std::any_of(
+      top_cell_list_.begin(), top_cell_list_.end(),
+      [&cell_id](const TopCellInfoPtr &value) { return value->cell_id() == cell_id && value->vm_compiled(); });
   if (do_vm_compiled) {
     if (MakeBpropNestedCnode(cell, *ret, cell_id)) {
       return;
@@ -2978,7 +3017,7 @@ void PynativeExecutor::RunInner(const py::object &cell, const py::tuple &args, c
   }
 }

-bool PynativeExecutor::MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id) {
+bool GradExecutor::MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id) {
   if (graph_stack_.empty() || !py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
     MS_LOG(DEBUG) 
<< "No nested bprop grad find"; return false; @@ -2987,7 +3026,7 @@ bool PynativeExecutor::MakeBpropNestedCnode(const py::object &cell, const py::ob std::vector inputs; inputs.emplace_back(NewValueNode(curr_g_)); PopGraphStack(); - auto graph_info = graph_info_map_.at(curr_g_); + auto graph_info = top_cell()->graph_info_map().at(curr_g_); MS_EXCEPTION_IF_NULL(graph_info); for (const auto &ig : graph_info->params) { if (!ig.second->has_default()) { @@ -3001,8 +3040,8 @@ bool PynativeExecutor::MakeBpropNestedCnode(const py::object &cell, const py::ob return true; } -void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource, - const py::object &out, bool has_sens) { +void GradExecutor::MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource, + const py::object &out, bool has_sens) { if (graph_stack_.empty()) { MS_LOG(DEBUG) << "No nested grad find"; return; @@ -3031,13 +3070,13 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString(4); } -void PynativeExecutor::RecoverGraphParams(const FuncGraphPtr &newfg, const std::string &cell_id, - std::vector *inputs) { +void GradExecutor::RecoverGraphParams(const FuncGraphPtr &newfg, const std::string &cell_id, + std::vector *inputs) { FuncGraphPtr forward_graph = nullptr; - auto ic = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), - [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); - if (ic != cell_graph_list_.end()) { - forward_graph = (*ic)->fg; + auto ic = std::find_if(top_cell()->cell_graph_list().begin(), top_cell()->cell_graph_list().end(), + [&cell_id](const CellInfoPtr &value) { return value->cell_id() == cell_id; }); + if (ic != top_cell()->cell_graph_list().end()) { + forward_graph = (*ic)->fg(); } MS_EXCEPTION_IF_NULL(forward_graph); auto param_list = replace_weights_map_.at(forward_graph); @@ -3061,90 +3100,117 @@ void PynativeExecutor::RecoverGraphParams(const FuncGraphPtr &newfg, const std:: replace_weights_map_.erase(forward_graph); } -void PynativeExecutor::Clear(const std::string &cell_id) { - if (cell_id.empty()) { - Clean(); - return; - } - - MS_LOG(DEBUG) << "Clear cell res, cell id " << cell_id; - for (auto it = graph_info_map_.begin(); it != graph_info_map_.end();) { - if (it->second->cell_id.find(cell_id) != std::string::npos) { - it = graph_info_map_.erase(it); - } else { - ++it; - } - } - // Maybe exit in runop step - auto ms_context = MsContext::GetInstance(); - if (ms_context != nullptr) { - ms_context->set_param(MS_CTX_ENABLE_PYNATIVE_INFER, false); +void GradExecutor::ClearGrad(const py::object &cell, const py::args &args) { + const auto &cell_id = GetCellId(cell, args); + if (IsTopestGraph(cell_id)) { + pre_top_cell_ = top_cell(); + set_top_cell(nullptr); + in_grad_process_ = false; } - ConfigManager::GetInstance().ResetIterNum(); - VectorClear>(&cell_graph_list_, cell_id); - VectorClear>(&top_cell_list_, cell_id); - node_abs_map_.clear(); -} - -void PynativeExecutor::Clean() { - MS_LOG(DEBUG) << "Clean"; + in_bprop_process_ = false; SubNestedGradOrder(); - node_abs_map_.clear(); + forward()->node_abs_map().clear(); obj_to_forward_id_.clear(); ad::CleanRes(); pipeline::ReclaimOptimizer(); } -void PynativeExecutor::ClearRes() { - MS_LOG(DEBUG) << "Clear all res"; - Clean(); - graph_id_ = 0; +void GradExecutor::ClearRes() { + MS_LOG(DEBUG) << "Clear grad res"; grad_order_ = 0; 
grad_flag_ = false; in_grad_process_ = false; - in_bprop_process_ = false; has_dynamic_cell_ = false; - grad_is_running_ = false; need_replace_forward_ = true; + grad_is_running_ = false; + pre_top_cell_ = nullptr; + top_cell_ = nullptr; curr_g_ = nullptr; - - top_cell_id_.clear(); - graph_info_map_.clear(); - replace_weights_map_.clear(); - cell_graph_list_.clear(); - top_cell_list_.clear(); - cell_input_args_.clear(); op_index_map_.clear(); - cell_op_index_with_tensor_id_.clear(); - cell_tensor_id_with_tensor_.clear(); - prim_abs_list_.clear(); + replace_weights_map_.clear(); all_value_node_tensors_.clear(); + obj_to_forward_id_.clear(); + ClearCellRes(); + top_cell_list_.clear(); std::stack().swap(graph_stack_); std::stack().swap(cell_op_info_stack_); } +GradExecutorPtr PynativeExecutor::grad_executor() { + MS_EXCEPTION_IF_NULL(grad_executor_); + return grad_executor_; +} +ForwardExecutorPtr PynativeExecutor::forward_executor() { + MS_EXCEPTION_IF_NULL(forward_executor_); + return forward_executor_; +} + +void PynativeExecutor::set_grad_flag(bool flag) { grad_executor()->set_grad_flag(flag); } + +bool PynativeExecutor::GetIsDynamicCell() { + if (grad_executor_ == nullptr) { + return false; + } + return grad_executor_->TopCellIsDynamic(); +} + +py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &args) { + return grad_executor()->CheckGraph(cell, args); +} + +py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args, const py::object &phase) { + py::object ret; + PynativeExecutorTry(grad_executor()->RunGraph, &ret, cell, args, phase); + return ret; +} + +void PynativeExecutor::ClearCell(const std::string &cell_id) { + MS_LOG(DEBUG) << "Clear cell res, cell id " << cell_id; + grad_executor()->ClearCellRes(cell_id); +} + +void PynativeExecutor::ClearGrad(const py::object &cell, const py::args &args) { + MS_LOG(DEBUG) << "Clear grad"; + return grad_executor()->ClearGrad(cell, args); +} + +void PynativeExecutor::ClearRes() { + MS_LOG(DEBUG) << "Clear all res"; + // Maybe exit in runop step + auto ms_context = MsContext::GetInstance(); + if (ms_context != nullptr) { + ms_context->set_param(MS_CTX_ENABLE_PYNATIVE_INFER, false); + } + ConfigManager::GetInstance().ResetIterNum(); + if (forward_executor_ != nullptr) { + forward_executor_->ClearRes(); + } + if (grad_executor_ != nullptr) { + grad_executor_->ClearRes(); + } + ad::CleanRes(); + pipeline::ReclaimOptimizer(); +} + void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) { - PynativeExecutorTry(this, &PynativeExecutor::NewGraphInner, cell, args); + py::object *ret = nullptr; + PynativeExecutorTry(grad_executor()->InitGraph, ret, cell, args); } void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) { MS_LOG(DEBUG) << "Enter end graph process."; + py::object *ret = nullptr; auto &mem_cleaner = pipeline::Resource::mem_cleaner(); mem_cleaner.EnterPynativeEndGraphProcess(); - PynativeExecutorTry(this, &PynativeExecutor::EndGraphInner, cell, out, args); + PynativeExecutorTry(grad_executor()->LinkGraph, ret, cell, out, args); mem_cleaner.LeavePynativeEndGraphProcess(); MS_LOG(DEBUG) << "Leave end graph process."; } void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args) { - PynativeExecutorTry(this, &PynativeExecutor::GradNetInner, grad, cell, weights, args); -} - -py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args, const 
py::object &phase) { - py::object ret; - PynativeExecutorTry(this, &PynativeExecutor::RunInner, cell, args, phase, &ret); - return ret; + py::object *ret = nullptr; + PynativeExecutorTry(grad_executor()->GradGraph, ret, grad, cell, weights, args); } void PynativeExecutor::Sync() { @@ -3155,19 +3221,19 @@ void PynativeExecutor::Sync() { } void PynativeExecutor::EnterConstruct(const py::object &cell) { - if (top_cell_ != nullptr) { + if (py_top_cell_ != nullptr) { return; } - top_cell_ = cell.ptr(); + py_top_cell_ = cell.ptr(); pipeline::Resource::mem_cleaner().EnterPynativeConstructProcess(); MS_LOG(DEBUG) << "Enter construct process."; } void PynativeExecutor::LeaveConstruct(const py::object &cell) { - if (top_cell_ != cell.ptr()) { + if (py_top_cell_ != cell.ptr()) { return; } - top_cell_ = nullptr; + py_top_cell_ = nullptr; pipeline::Resource::mem_cleaner().LeavePynativeConstructProcess(); MS_LOG(DEBUG) << "Leave construct process."; } @@ -3180,7 +3246,8 @@ REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) { .def("check_graph", &PynativeExecutor::CheckGraph, "pynative check a grad graph.") .def("check_run", &PynativeExecutor::CheckAlreadyRun, "pynative check graph run before.") .def("grad_net", &PynativeExecutor::GradNet, "pynative grad graph.") - .def("clear", &PynativeExecutor::Clear, "pynative clear status.") + .def("clear_cell", &PynativeExecutor::ClearCell, "pynative clear status.") + .def("clear_grad", &PynativeExecutor::ClearGrad, "pynative clear grad status.") .def("sync", &PynativeExecutor::Sync, "pynative sync stream.") .def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.") .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false), diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index 8df1a77224..94d5ab1e0a 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -17,8 +17,8 @@ #ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_ #define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_ -#include #include +#include #include #include #include @@ -68,118 +68,121 @@ struct GraphInfo { GraphInfo() = default; explicit GraphInfo(std::string id) : cell_id(std::move((id))) {} }; +using GraphInfoPtr = std::shared_ptr; class CellInfo { public: CellInfo() = default; + ~CellInfo() = default; CellInfo(bool custom_bprop, bool has_dynamic, FuncGraphPtr foward_graph, std::string cellid, std::string bprop_id) - : is_custom_bprop(custom_bprop), - is_dynamic(has_dynamic), - fg(std::move(foward_graph)), - cell_id(std::move(cellid)), - bprop_cell_id(std::move(bprop_id)) {} - - bool is_grad{false}; // Derivative is calculated - bool is_custom_bprop{false}; // Custom bprop - bool is_dynamic{false}; // Set by has_dynamic_cell - bool is_real_dynamic{false}; // Set by ops order - size_t call_times{0}; - FuncGraphPtr fg{nullptr}; // Forward graph - std::string cell_id; - std::string bprop_cell_id; - std::vector cell_ops_info; // All ops info + : is_custom_bprop_(custom_bprop), + is_dynamic_(has_dynamic), + fg_(std::move(foward_graph)), + cell_id_(std::move(cellid)), + bprop_cell_id_(std::move(bprop_id)) {} + + bool is_custom_bprop() const { return is_custom_bprop_; } + void set_is_custom_bprop(bool is_custom_bprop) { is_custom_bprop_ = is_custom_bprop; } + bool is_dynamic() const { return is_dynamic_; } + void set_is_dynamic(bool is_dynamic) { is_dynamic_ = is_dynamic; } + 
size_t call_times() const { return call_times_; } + void set_call_times(size_t call_times) { call_times_ = call_times; } + FuncGraphPtr fg() const { return fg_; } + void set_fg(FuncGraphPtr fg) { fg_ = std::move(fg); } + std::string &cell_id() { return cell_id_; } + void set_cell_id(std::string cell_id) { cell_id_ = std::move(cell_id); } + std::string &bprop_cell_id() { return bprop_cell_id_; } + std::vector &cell_ops_info() { return cell_ops_info_; } + + private: + bool is_custom_bprop_{false}; // Custom bprop + bool is_dynamic_{false}; // Set by has_dynamic_cell + size_t call_times_{0}; + FuncGraphPtr fg_{nullptr}; // Forward graph + std::string cell_id_; + std::string bprop_cell_id_; + std::vector cell_ops_info_; // All ops info }; +using CellInfoPtr = std::shared_ptr; class TopCellInfo { public: TopCellInfo() = default; + ~TopCellInfo() = default; TopCellInfo(bool topest, ResourcePtr r, FuncGraphPtr df, std::string cellid) - : is_topest(topest), resource(std::move(r)), df_builder(std::move(df)), cell_id(std::move(cellid)) {} - - bool need_grad{true}; - bool is_topest{false}; - bool do_vm_compiled{false}; - bool forward_already_run{false}; - size_t top_cell_index{0}; - ResourcePtr resource{nullptr}; - FuncGraphPtr df_builder{nullptr}; - FuncGraphPtr bg{nullptr}; // Backward graph - std::string cell_id; - std::string sens_id; - std::string weights_id; - std::string input_args_id; -}; - -using GraphInfoPtr = std::shared_ptr; -using CellInfoPtr = std::shared_ptr; -using TopCellInfoPtr = std::shared_ptr; - -class PynativeExecutor : public std::enable_shared_from_this { - public: - static std::shared_ptr GetInstance() { - std::lock_guard i_lock(instance_lock_); - if (executor_ == nullptr) { - executor_ = std::shared_ptr(new (std::nothrow) PynativeExecutor()); - } - return executor_; + : is_topest_(topest), resource_(std::move(r)), df_builder_(std::move(df)), cell_id_(std::move(cellid)) {} + + bool is_grad() const { return is_grad_; } + void set_is_grad(bool is_grad) { is_grad_ = is_grad; } + bool is_topest() const { return is_topest_; } + bool vm_compiled() const { return vm_compiled_; } + void set_vm_compiled(bool vm_compiled) { vm_compiled_ = vm_compiled; } + bool need_grad() const { return need_grad_; } + void set_need_grad(bool need_grad) { need_grad_ = need_grad; } + bool has_dynamic_cell() const { return has_dynamic_cell_; } + bool is_real_dynamic() const { return is_real_dynamic_; } + void set_is_real_dynamic(bool is_real_dynamic) { is_real_dynamic_ = is_real_dynamic; } + bool forward_already_run() const { return forward_already_run_; } + void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; } + ResourcePtr resource() { return resource_; } + FuncGraphPtr df_builder() { return df_builder_; } + std::string &cell_id() { return cell_id_; } + std::string &sens_id() { return sens_id_; } + void set_sens_id(std::string sens_id) { sens_id_ = std::move(sens_id); } + std::string &weights_id() { return weights_id_; } + void set_weights_id(std::string weights_id) { weights_id_ = std::move(weights_id); } + std::string &input_args_id() { return input_args_id_; } + void set_input_args_id(std::string input_args_id) { input_args_id_ = std::move(input_args_id); } + std::vector &cell_graph_list() { return cell_graph_list_; } + void set_cell_graph_list(const std::vector &cell_graph_list) { cell_graph_list_ = cell_graph_list; } + OrderedMap &graph_info_map() { return graph_info_map_; } + void set_graph_info_map(const OrderedMap &graph_info_map) { + 
graph_info_map_ = graph_info_map; + } + void clear() { + cell_graph_list_.clear(); + graph_info_map_.clear(); } - ~PynativeExecutor(); - PynativeExecutor(const PynativeExecutor &) = delete; - PynativeExecutor &operator=(const PynativeExecutor &) = delete; - - bool need_replace_forward() const { return need_replace_forward_; } - bool grad_flag() const { return grad_flag_; } - void set_grad_flag(bool flag) { grad_flag_ = flag; } - void EnterConstruct(const py::object &cell); - void LeaveConstruct(const py::object &cell); - py::object RunOpInner(const OpExecInfoPtr &op_exec_info); - OpExecInfoPtr GenerateOpExecInfo(const py::args &args); - void NewGraph(const py::object &cell, const py::args &args); - py::object Run(const py::object &cell, const py::tuple &args, const py::object &phase); - void RunInner(const py::object &cell, const py::tuple &args, const py::object &phase, py::object *ret); - py::object CheckGraph(const py::object &cell, const py::args &args); - py::object CheckAlreadyRun(const py::object &cell, const py::args &args); - void EndGraph(const py::object &cell, const py::object &out, const py::args &args); - void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args); + private: + bool is_grad_{false}; // Derivative is calculated + bool is_topest_{false}; + bool vm_compiled_{false}; + bool need_grad_{true}; + bool has_dynamic_cell_{false}; + bool is_real_dynamic_{false}; + bool forward_already_run_{false}; + ResourcePtr resource_{nullptr}; + FuncGraphPtr df_builder_{nullptr}; + std::string cell_id_; + std::string sens_id_; + std::string weights_id_; + std::string input_args_id_; + std::vector cell_graph_list_; + OrderedMap graph_info_map_; +}; +using TopCellInfoPtr = std::shared_ptr; - // Get info - bool GetIsDynamicCell() { return CheckRealDynamicCell(top_cell_id_); } - // Call by python - void Clear(const std::string &flag = ""); - void Clean(); - // Abnormal existed - void ClearRes(); - // Sync stream - void Sync(); +class DynamicAnalysis; +using DynamicAnalysisPtr = std::shared_ptr; - private: - PynativeExecutor() = default; +class ForwardExecutor; +using ForwardExecutorPtr = std::shared_ptr; +using ForwardExecutorWeakPtr = std::weak_ptr; - template - void MapClear(T *map, const std::string &cell_id) { - for (auto it = map->begin(); it != map->end();) { - if (it->first.find(cell_id) != std::string::npos) { - it = map->erase(it); - } else { - it++; - } - } - } +class GradExecutor; +using GradExecutorPtr = std::shared_ptr; +using GradExecutorWeakPtr = std::weak_ptr; - template - void VectorClear(T *vec, const std::string &cell_id) { - for (auto it = vec->begin(); it != vec->end();) { - if ((*it)->cell_id.find(cell_id) != std::string::npos) { - it = vec->erase(it); - } else { - it++; - } - } - } +class DynamicAnalysis { + public: + DynamicAnalysis() = default; + ~DynamicAnalysis() = default; // Check cell struct bool IsDynamicCell(const py::object &cell); + + private: std::string GetCellInfo(const py::object &cell); void ParseInputArgs(const std::shared_ptr &ast, const py::object &fn_node); bool ParseBodyContext(const std::shared_ptr &ast, const py::object &fn_node, @@ -191,78 +194,113 @@ class PynativeExecutor : public std::enable_shared_from_this { bool ParseForExprNode(const std::shared_ptr &ast, const py::object &node); std::string ParseNodeName(const std::shared_ptr &ast, const py::object &node, parse::AstMainType type); - py::object DoParamMixPrecisionCast(bool *is_cast, const py::object obj, const std::string 
&op_name, size_t index); - py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple, const std::string &op_name, - size_t index); - py::object DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name, size_t index); - void DoSignatrueCast(const PrimitivePyPtr &prim, const std::map &dst_type, - const std::vector &dtypes, const OpExecInfoPtr &op_exec_info); - // Run op + + std::unordered_set cell_input_args_; +}; + +class GradExecutor { + public: + GradExecutor() = default; + ~GradExecutor() = default; + explicit GradExecutor(const ForwardExecutorPtr &forward_executor = nullptr) + : forward_executor_(ForwardExecutorWeakPtr(forward_executor)) {} + + std::function InitGraph = [this](auto &&PH1, auto &&PH2, + auto &&PH3) { + NewGraphInner(std::forward(PH1), std::forward(PH2), std::forward(PH3)); + }; + std::function LinkGraph = + [this](auto &&PH1, auto &&PH2, auto &&PH3, auto &&PH4) { + EndGraphInner(std::forward(PH1), std::forward(PH2), + std::forward(PH3), std::forward(PH4)); + }; + std::function + GradGraph = [this](auto &&PH1, auto &&PH2, auto &&PH3, auto &&PH4, auto &&PH5) { + GradNetInner(std::forward(PH1), std::forward(PH2), std::forward(PH3), + std::forward(PH4), std::forward(PH5)); + }; + std::function RunGraph = + [this](auto &&PH1, auto &&PH2, auto &&PH3, auto &&PH4) { + RunGradGraph(std::forward(PH1), std::forward(PH2), std::forward(PH3), + std::forward(PH4)); + }; + + FuncGraphPtr curr_g() const; + TopCellInfoPtr top_cell() const; + bool TopCellIsDynamic(); + void set_top_cell(TopCellInfoPtr top_cell) { top_cell_ = std::move(top_cell); } + bool grad_flag() const { return grad_flag_; } + void set_grad_flag(bool flag) { grad_flag_ = flag; } + bool in_grad_process() const { return in_grad_process_; } + std::string top_cell_id() { return top_cell()->cell_id(); } AnfNodePtr GetInput(const py::object &obj, bool op_mask); - MsBackendPolicy InitEnv(const OpExecInfoPtr &op_exec_info); - py::tuple RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info); - void RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_exec_info); - py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); - py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); - py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info, - PynativeStatusCode *const status); - AnfNodePtr GetObjNode(const py::object &obj, const std::string &obj_id); - AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id); - void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, std::vector *inputs, - abstract::AbstractBasePtrList *args_spec_list); - AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, - abstract::AbstractBasePtrList *args_spec_list); - abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj, - const abstract::AbstractBasePtr &abs, const std::string &id, size_t index); - void GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list, - bool *is_find); + std::string GetCellId(const py::object &obj, const py::args &args); + TopCellInfoPtr GetTopCell(const string &cell_id, bool find_nearest = false); void SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const AnfNodePtr &cnode); - - // Replace for grad graph - ValuePtr CleanTupleAddr(const ValueTuplePtr &tuple); - void GenTupleMap(const ValueTuplePtr &tuple, 
std::map *t_map); void SaveAllResult(const OpExecInfoPtr &op_exec_info, const AnfNodePtr &node, const py::object &out_real); - // Update the abstract and device address info of value node and tensors in bprop graph - void UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real); - void SaveTensorsInValueNode(const ResourcePtr &resource); - void SaveAllValueNodeTensors(const FuncGraphPtr &graph); - void CleanPreMemoryInValueNode(); + py::object CheckGraph(const py::object &cell, const py::args &args); + void RunGradGraph(py::object *ret, const py::object &cell, const py::tuple &args, const py::object &phase); + bool need_construct_graph() const { return !graph_stack_.empty() && grad_flag_; } + void set_dynamic_analysis(DynamicAnalysisPtr dynamic_analysis) { dynamic_analysis_ = std::move(dynamic_analysis); } + std::stack &graph_stack() { return graph_stack_; } + std::vector &top_cell_list() { return top_cell_list_; } + bool need_replace_forward() const { return need_replace_forward_; } + std::stack &cell_op_info_stack() { return cell_op_info_stack_; } + std::unordered_map &op_index_map() { return op_index_map_; } + std::unordered_map &obj_to_forward_id() { return obj_to_forward_id_; } + void ClearGrad(const py::object &cell, const py::args &args); + void ClearRes(); + void ClearCellRes(const std::string &cell_id = ""); + + private: + ForwardExecutorPtr forward() const; + DynamicAnalysisPtr dynamic_analysis() const; + bool grad_running() const { return grad_is_running_; } + void set_grad_runing(bool grad_runing) { grad_is_running_ = grad_runing; } + void set_need_replace_forward(bool need_replace_forward) { need_replace_forward_ = need_replace_forward; } + + // Higher derivative + bool IsNestedGrad() const; + void AddNestedGradOrder() { ++grad_order_; } + void SubNestedGradOrder(); + void ReplaceGraphParams(const FuncGraphPtr &df_builder, const FuncGraphPtr &forward_graph, + const std::string &cell_id); + void SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id); + void MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource, + const py::object &out, bool has_sens); + void RecoverGraphParams(const FuncGraphPtr &newfg, const std::string &cell_id, std::vector *inputs); + bool MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id); + + // Dynamic + bool CheckDynamicCell(const std::string &cell_id); + bool CheckRealDynamicCell(const std::string &cell_id); + void ClearDynamicTopRes(const std::string &cell_id); - // Construct grad graph void PushCurrentGraphToStack(); void PopGraphStack(); void PushCurrentCellOpInfoToStack(); void PopCurrentCellOpInfoFromStack(); + std::string GetCellOpInfo(); + void ReplaceCellOpInfoByCellId(const std::string &cell_id); + FuncGraphPtr GetDfbuilder(const std::string &cell_id = ""); ResourcePtr GetResource(const std::string &cell_id = ""); - void AddNestedGradOrder() { ++grad_order_; } - void SubNestedGradOrder(); - bool IsNestedGrad() const; + bool IsFirstGradStep(); bool IsTopGraph(const std::string &cell_id); bool IsTopestGraph(const std::string &cell_id); bool IsBpropGraph(const std::string &cell_id); - bool IsFirstGradStep(const std::string &cell_id); - bool grad_running() const { return grad_is_running_; } - void set_grad_runing(bool grad_runing) { grad_is_running_ = grad_runing; } - void set_need_replace_forward(bool need_replace_forward) { need_replace_forward_ = need_replace_forward; } - bool 
need_construct_graph() { return !graph_stack_.empty() && grad_flag_; } - bool CheckCellGraph(const std::string &cell_id, bool is_grad = false); - bool CheckDynamicCell(const std::string &cell_id); - bool CheckRealDynamicCell(const std::string &cell_id); + bool IsGradBefore(const std::string &cell_id); + bool CheckCellGraph(const std::string &cell_id); bool UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, bool need_cloned, bool is_grad); void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, bool need_cloned = false, bool is_grad = false); - void ClearCnodeRes(const AnfNodePtr &node, std::unordered_set *node_set); - void UpdateCellDynamic(const std::string &cell_id); bool CheckCellChanged(const std::string &cell_id); void UpdateTopCellInfo(const std::string &cell_id, bool vm_compiled); - void ClearResidualRes(const std::string &cell_id); void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph); - void NewGraphInner(const py::object &cell, const py::args &args); + void NewGraphInner(py::object *ret, const py::object &cell, const py::args &args); void MakeNewTopGraph(const string &cell_id, const py::args &args); - TopCellInfoPtr GetTopCell(const string &cell_id, bool find_nearest = false); - void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args); + void EndGraphInner(py::object *ret, const py::object &cell, const py::object &out, const py::args &args); void EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out, const std::string &out_id, const py::args &args); bool EndBpropGraph(const string &cell_id); @@ -270,90 +308,185 @@ class PynativeExecutor : public std::enable_shared_from_this { const std::string &cell_id, const py::args &args); std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args, py::object *forward_args, py::object *sens = nullptr); - void ClearDynamicTopRes(const std::string &cell_id, const FuncGraphPtr &df_builder); - void GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, + void GradNetInner(py::object *ret, const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args); - std::string GetCellId(const py::object &obj, const py::args &args); - std::string GetTensorCellId(const std::string &cell_id); + void SetTopCellTensorId(const std::string &cell_id); bool CheckGradParamsChanged(const std::string &cell_id, const py::object &weights, const py::object &sens); void SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size); - void GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, const std::vector &weights, - size_t arg_size, const std::string &cell_id); + void SetGradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, const std::vector &weights, + size_t arg_size, const std::string &cell_id); std::vector GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder); abstract::AbstractBasePtrList GetArgsSpec(const py::args &args, const FuncGraphPtr &df_builder); void ClearUselessRes(const FuncGraphPtr &df_builder, const py::object &cell, const std::string &cell_id); - void ReplaceGraphParams(const FuncGraphPtr &df_builder, const FuncGraphPtr &forward_graph, - const std::string &cell_id); - void SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id); - void MakeNestedCnode(const std::string 
+  void NewGraphInner(py::object *ret, const py::object &cell, const py::args &args);
   void MakeNewTopGraph(const string &cell_id, const py::args &args);
-  TopCellInfoPtr GetTopCell(const string &cell_id, bool find_nearest = false);
-  void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args);
+  void EndGraphInner(py::object *ret, const py::object &cell, const py::object &out, const py::args &args);
   void EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out,
                        const std::string &out_id, const py::args &args);
   bool EndBpropGraph(const string &cell_id);
@@ -270,90 +308,185 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
                                 const std::string &cell_id, const py::args &args);
   std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args, py::object *forward_args,
                             py::object *sens = nullptr);
-  void ClearDynamicTopRes(const std::string &cell_id, const FuncGraphPtr &df_builder);
-  void GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
+  void GradNetInner(py::object *ret, const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                     const py::args &args);
-  std::string GetCellId(const py::object &obj, const py::args &args);
-  std::string GetTensorCellId(const std::string &cell_id);
+  void SetTopCellTensorId(const std::string &cell_id);
   bool CheckGradParamsChanged(const std::string &cell_id, const py::object &weights, const py::object &sens);
   void SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size);
-  void GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, const std::vector<AnfNodePtr> &weights,
-                 size_t arg_size, const std::string &cell_id);
+  void SetGradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, const std::vector<AnfNodePtr> &weights,
+                    size_t arg_size, const std::string &cell_id);
   std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder);
   abstract::AbstractBasePtrList GetArgsSpec(const py::args &args, const FuncGraphPtr &df_builder);
   void ClearUselessRes(const FuncGraphPtr &df_builder, const py::object &cell, const std::string &cell_id);
-  void ReplaceGraphParams(const FuncGraphPtr &df_builder, const FuncGraphPtr &forward_graph,
-                          const std::string &cell_id);
-  void SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id);
-  void MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource,
-                       const py::object &out, bool has_sens);
-  void RecoverGraphParams(const FuncGraphPtr &newfg, const std::string &cell_id, std::vector<AnfNodePtr> *inputs);
-  bool MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id);
+  void SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &id, const AnfNodePtr &node,
+                                      const std::vector<int64_t> &index_sequence, bool is_param = false);
+  AnfNodePtr GetObjNode(const py::object &obj, const std::string &obj_id);
+  AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);
+
+  // Memory cleanup between steps
+  void ClearResidualRes(const std::string &cell_id);
+  void ClearCnodeRes(const AnfNodePtr &node);
+  void CleanPreMemoryInValueNode();
+  void SaveTensorsInValueNode(const ResourcePtr &resource);
+  void SaveAllValueNodeTensors(const FuncGraphPtr &graph);
-  // Hold graph(forward and grad) info
-  std::string GetCellOpInfo();
-  void ReplaceCellOpInfoByCellId(const std::string &cell_id);
   void SetPyObjInGraphInfoMap(const FuncGraphPtr &g, const std::string &obj) {
-    graph_info_map_[g]->objects.push_back(obj);
+    top_cell()->graph_info_map()[g]->objects.push_back(obj);
   }
   void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
                                   bool is_param = false);
   void SetParamNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const ParameterPtr &param) {
-    graph_info_map_[g]->params[id] = param;
+    top_cell()->graph_info_map()[g]->params[id] = param;
   }
   void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node,
                                 int64_t index = -1) {
-    graph_info_map_[g]->node_map[id] = std::make_pair(node, std::vector<int64_t>{index});
+    top_cell()->graph_info_map()[g]->node_map[id] = std::make_pair(node, std::vector<int64_t>{index});
   }
   void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node,
                                 const std::vector<int64_t> &index) {
-    graph_info_map_[g]->node_map[id] = std::make_pair(node, index);
+    top_cell()->graph_info_map()[g]->node_map[id] = std::make_pair(node, index);
   }
-  void SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &id, const AnfNodePtr &node,
-                                      const std::vector<int64_t> &index_sequence, bool is_param = false);
-  static std::shared_ptr<PynativeExecutor> executor_;
-  static std::mutex instance_lock_;
-  static int64_t graph_id_;
+ private:
   size_t grad_order_{0};
-  size_t top_cell_index_{0};
-  std::string top_cell_id_;
   bool grad_flag_{false};
   bool in_bprop_process_{false};
   bool in_grad_process_{false};
   bool has_dynamic_cell_{false};
-  bool grad_is_running_{false};
   bool need_replace_forward_{true};
-  // The pointer of top python Cell object, which is always the network(inherit class Cell) ran in python test script,
-  // such as Resnet50(Cell),LeNet(Cell).This pointer is used to distinguish temporary primitives from global
-  // primitives to control memory release. Global primitives are always created in top cell's '__init__' function and
-  // temporary primitives are always created in other place.Temporary primitives will be released after executing top
-  // cell's 'construct' function but global primitives will not.
-  PyObject *top_cell_{nullptr};
-
-  // Used for construct grad graph
+  bool grad_is_running_{false};
   FuncGraphPtr curr_g_{nullptr};
+  // Previous top cell, kept so its resources can be cleared
+  TopCellInfoPtr pre_top_cell_{nullptr};
+  TopCellInfoPtr top_cell_{nullptr};
+  std::unordered_map<std::string, size_t> op_index_map_;
+  std::unordered_map<std::string, std::vector<std::pair<ParameterPtr, ParameterPtr>>> replace_weights_map_;
+  std::unordered_set<tensor::TensorPtr> all_value_node_tensors_;
+  std::unordered_map<std::string, std::string> obj_to_forward_id_;
+  // Records forward graphs; the bottom of the stack is the top graph
  std::stack<FuncGraphPtr> graph_stack_;
   // Records op info of every cell, the bottom is op info of top cell
   std::stack<std::string> cell_op_info_stack_;
   // Use vector for keep order
-  std::vector<CellInfoPtr> cell_graph_list_;
   std::vector<TopCellInfoPtr> top_cell_list_;
-  std::unordered_set<std::string> cell_input_args_;
-  // Record all info for all cells
-  OrderedMap<FuncGraphPtr, GraphInfoPtr> graph_info_map_;
-  std::unordered_map<std::string, std::vector<std::pair<ParameterPtr, ParameterPtr>>> replace_weights_map_;
+  ForwardExecutorWeakPtr forward_executor_;
+  DynamicAnalysisPtr dynamic_analysis_;
+};
+
+class ForwardExecutor {
+ public:
+  ForwardExecutor() = default;
+  ~ForwardExecutor() = default;
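+  // Wraps RunOpInner as a std::function so RunOp can hand it to PynativeExecutorTry,
+  // which supplies the unified exception handling.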
+  std::function<void(py::object *, const OpExecInfoPtr &)> RunOpS = [this](auto &&PH1, auto &&PH2) {
+    RunOpInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2));
+  };
+
+  void RunOpInner(py::object *ret, const OpExecInfoPtr &op_exec_info);
+  OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
+  void set_grad_executor(const GradExecutorPtr &grad_executor) { grad_executor_ = GradExecutorWeakPtr(grad_executor); }
+  std::unordered_map<std::string, abstract::AbstractBasePtr> &node_abs_map() { return node_abs_map_; }
+  std::unordered_map<std::string, OpIndexWithTensorId> &cell_op_index_with_tensor_id() {
+    return cell_op_index_with_tensor_id_;
+  }
+  std::unordered_map<std::string, TensorIdWithTensor> &cell_tensor_id_with_tensor() {
+    return cell_tensor_id_with_tensor_;
+  }
+  void ClearRes();
+
+ private:
+  GradExecutorPtr grad() const;
+  MsBackendPolicy InitEnv(const OpExecInfoPtr &op_exec_info);
+  py::tuple RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info);
+  py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
+  py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
+  py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info,
+                                    PynativeStatusCode *status);
+  AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
+                       abstract::AbstractBasePtrList *args_spec_list);
+  void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks, std::vector<AnfNodePtr> *inputs,
+                   abstract::AbstractBasePtrList *args_spec_list);
+  abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,
+                                            const abstract::AbstractBasePtr &abs, const std::string &id, size_t index);
+  void GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list,
+                           bool *is_find);
+  // Update the abstract and device address info of value node and tensors in bprop graph
+  void UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real);
+
+  // Mix precision
+  void RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_exec_info);
+  py::object DoParamMixPrecisionCast(bool *is_cast, const py::object &obj, const std::string &op_name, size_t index);
+  py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple &tuple, const std::string &op_name,
+                                          size_t index);
+  py::object DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name, size_t index);
+  void DoSignatrueCast(const PrimitivePyPtr &prim, const std::map<SignatureEnumDType, TypeId> &dst_type,
+                       const std::vector<SignatureEnumDType> &dtypes, const OpExecInfoPtr &op_exec_info);
+
+ private:
+  GradExecutorWeakPtr grad_executor_;
+  std::unordered_map<std::string, AbstractListMap> prim_abs_list_;
+  std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
   // Used for runop and replace forward result of grad graph
-  std::unordered_map<std::string, size_t> op_index_map_;
-  std::unordered_map<std::string, std::string> obj_to_forward_id_;
   std::unordered_map<std::string, OpIndexWithTensorId> cell_op_index_with_tensor_id_;
   std::unordered_map<std::string, TensorIdWithTensor> cell_tensor_id_with_tensor_;
-  std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
-  std::unordered_map<std::string, AbstractListMap> prim_abs_list_;
-  std::unordered_set<tensor::TensorPtr> all_value_node_tensors_;
+};
+
+class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
+ public:
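+  // Lazily creates the singleton under instance_lock_ and wires the ForwardExecutor
+  // and GradExecutor to each other.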
+  static std::shared_ptr<PynativeExecutor> GetInstance() {
+    std::lock_guard<std::mutex> i_lock(instance_lock_);
+    if (executor_ == nullptr) {
+      executor_ = std::shared_ptr<PynativeExecutor>(new (std::nothrow) PynativeExecutor());
+      forward_executor_ = std::make_shared<ForwardExecutor>();
+      grad_executor_ = std::make_shared<GradExecutor>(forward_executor_);
+      grad_executor_->set_dynamic_analysis(std::make_shared<DynamicAnalysis>());
+      forward_executor_->set_grad_executor(grad_executor_);
    }
+    return executor_;
+  }
+  ~PynativeExecutor() = default;
+  PynativeExecutor(const PynativeExecutor &) = delete;
+  PynativeExecutor &operator=(const PynativeExecutor &) = delete;
+
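+  // Track the top python Cell object while its construct() runs (see py_top_cell_ below).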
+  void EnterConstruct(const py::object &cell);
+  void LeaveConstruct(const py::object &cell);
+  GradExecutorPtr grad_executor();
+  ForwardExecutorPtr forward_executor();
+
+  void set_grad_flag(bool flag);
+  void NewGraph(const py::object &cell, const py::args &args);
+  void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
+  void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args);
+  py::object CheckGraph(const py::object &cell, const py::args &args);
+  py::object CheckAlreadyRun(const py::object &cell, const py::args &args);
+  py::object Run(const py::object &cell, const py::tuple &args, const py::object &phase);
+
+  // Used when cleaning graph resources
+  bool GetIsDynamicCell();
+  bool need_replace_forward() { return grad_executor()->need_replace_forward(); }
+  // Called by Cell's destructor
+  void ClearCell(const std::string &flag = "");
+  void ClearGrad(const py::object &cell, const py::args &args);
+  // Called on abnormal exit
+  void ClearRes();
+  // Sync stream
+  void Sync();
+
+ private:
+  PynativeExecutor() = default;
+
+  static std::shared_ptr<PynativeExecutor> executor_;
+  static std::mutex instance_lock_;
+  static ForwardExecutorPtr forward_executor_;
+  static GradExecutorPtr grad_executor_;
+  // The pointer to the top python Cell object, which is always the network (a class inheriting Cell) run in the
+  // python script, such as Resnet50(Cell) or LeNet(Cell). It is used to distinguish temporary primitives from
+  // global primitives so that memory release can be controlled: global primitives are always created in the top
+  // cell's '__init__' function, while temporary primitives are created elsewhere. Temporary primitives are
+  // released after executing the top cell's 'construct' function, but global primitives are not.
+  PyObject *py_top_cell_{nullptr};
 };
 
 using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>;
diff --git a/mindspore/common/api.py b/mindspore/common/api.py
index 486f176d0f..4babfd256c 100644
--- a/mindspore/common/api.py
+++ b/mindspore/common/api.py
@@ -367,8 +367,11 @@ class _PynativeExecutor:
     def grad(self, grad, obj, weights, *args, **kwargs):
         self._executor.grad_net(grad, obj, weights, *args, *(kwargs.values()))
 
-    def clear(self, cell_id=""):
-        self._executor.clear(cell_id)
+    def del_cell(self, cell_id=""):
+        self._executor.clear_cell(cell_id)
+
+    def clear_grad(self, obj, *args, **kwargs):
+        self._executor.clear_grad(obj, *args, *(kwargs.values()))
 
     def sync(self):
         self._executor.sync()
diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py
index 21a7f9fdda..96e58e8850 100755
--- a/mindspore/nn/cell.py
+++ b/mindspore/nn/cell.py
@@ -267,7 +267,7 @@ class Cell(Cell_):
 
     def __del__(self):
         if context.get_context is not None and context.get_context("mode") == context.PYNATIVE_MODE:
-            _pynative_exec.clear(str(id(self)))
+            _pynative_exec.del_cell(str(id(self)))
         if hasattr(self, "_create_time"):
             _executor.del_net_res(str(self._create_time))
 
diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py
index 8bda4e3768..f9004fbf7c 100644
--- a/mindspore/ops/composite/base.py
+++ b/mindspore/ops/composite/base.py
@@ -370,7 +370,7 @@ class GradOperation(GradOperation_):
                 self._pynative_forward_run(args, kwargs, fn)
                 _pynative_exec.grad(grad_, fn, weights, *args, **kwargs)
                 out = _pynative_exec(fn, *args, **kwargs)
-                _pynative_exec.clear()
+                _pynative_exec.clear_grad(fn, *args, **kwargs)
                 return out
             self.grad_fn = after_grad
             self.fn = fn
diff --git a/tests/ut/cpp/pynative/pynative_execute_test.cc b/tests/ut/cpp/pynative/pynative_execute_test.cc
index 106de71a3d..25fafa8041 100644
--- a/tests/ut/cpp/pynative/pynative_execute_test.cc
+++ b/tests/ut/cpp/pynative/pynative_execute_test.cc
@@ -65,7 +65,7 @@ OpExecInfoPtr ConstructOpExecInfo() {
   py::none py_none;
   py::args args = py::make_tuple(conv_obj, op_name, op_inputs);
   py::list args_input = args[PY_INPUTS];
-  return PynativeExecutor::GetInstance()->GenerateOpExecInfo(args);
+  return PynativeExecutor::GetInstance()->forward_executor()->GenerateOpExecInfo(args);
 }
 
 TEST_F(TestPynativeExecute, TestCreateContext) {
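---

A minimal usage sketch of the renamed PyNative cleanup API from the Python side (the toy cell, shapes, and values below are illustrative and not part of this patch): GradOperation now finishes each backward run with _pynative_exec.clear_grad(fn, *args, **kwargs), while per-cell resources are released when Cell.__del__ invokes _pynative_exec.del_cell(str(id(self))).

import numpy as np

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, context

context.set_context(mode=context.PYNATIVE_MODE)


class MulNet(nn.Cell):
    """Hypothetical two-input cell, used only for illustration."""

    def __init__(self):
        super().__init__()
        self.mul = ops.Mul()

    def construct(self, x, y):
        return self.mul(x, y)


net = MulNet()
x = Tensor(np.ones([2, 2], np.float32))
y = Tensor(np.ones([2, 2], np.float32) * 3)

# Forward run in PyNative mode: each primitive goes through RunOp eagerly.
out = net(x, y)

# GradOperation.__call__ now ends the backward run with
# _pynative_exec.clear_grad(fn, *args, **kwargs) instead of the old
# _pynative_exec.clear(); the cell's own resources are released later,
# when Cell.__del__ calls _pynative_exec.del_cell(str(id(self))).
grad_all = ops.GradOperation(get_all=True)
dx, dy = grad_all(net)(x, y)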