From ccedf71de5c3ed593c6650e2c1e377cd11e2f20f Mon Sep 17 00:00:00 2001 From: zjun Date: Thu, 29 Oct 2020 10:43:09 +0800 Subject: [PATCH] Add pynative bprop nested grad Signed-off-by: zjun --- .../pipeline/pynative/pynative_execute.cc | 1759 +++++++++-------- .../pipeline/pynative/pynative_execute.h | 160 +- mindspore/common/api.py | 3 + mindspore/nn/cell.py | 3 +- mindspore/ops/composite/base.py | 3 + .../pynative_mode/test_stop_gradient.py | 2 +- 6 files changed, 1020 insertions(+), 910 deletions(-) diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index b4152e8a39..0c273c645c 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -63,7 +64,7 @@ using mindspore::tensor::TensorPy; -const char SINGLE_OP_GRAPH[] = "single_op_graph"; +const size_t PTR_LEN = 15; // primitive unable to infer value for constant input in PyNative mode const std::set vm_operators = {"make_ref", "HookBackward", "InsertGradientOf", "stop_gradient", "mixed_precision_cast"}; @@ -73,11 +74,11 @@ namespace pynative { static std::shared_ptr session = nullptr; PynativeExecutorPtr PynativeExecutor::executor_ = nullptr; std::mutex PynativeExecutor::instance_lock_; -ResourcePtr PynativeExecutor::resource_; int PynativeExecutor::graph_id_ = 0; template void PynativeExecutorTry(PynativeExecutor *const executor, void (PynativeExecutor::*method)(Args...), Args &&... args) { + MS_EXCEPTION_IF_NULL(executor); try { (executor->*method)(args...); } catch (const py::error_already_set &ex) { @@ -115,42 +116,35 @@ void PynativeExecutorTry(PynativeExecutor *const executor, void (PynativeExecuto inline ValuePtr PyAttrValue(const py::object &obj) { ValuePtr converted_ret = parse::data_converter::PyDataToValue(obj); if (!converted_ret) { - MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(obj)); + MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj)); } return converted_ret; } static std::string GetId(const py::object &obj) { - py::object to_process = obj; - std::string prefix = ""; - if (py::isinstance(to_process) || py::isinstance(to_process)) { - auto p_list = py::cast(to_process); + if (py::isinstance(obj)) { + auto tensor_ptr = py::cast(obj); + return tensor_ptr->id(); + } else if (py::isinstance(obj)) { + auto type_ptr = py::cast(obj); + return "type" + type_ptr->ToString(); + } else if (py::isinstance(obj) || py::isinstance(obj) || py::isinstance(obj)) { + return std::string(py::str(obj)); + } else if (py::isinstance(obj)) { + return "none"; + } else if (py::isinstance(obj) || py::isinstance(obj)) { + auto p_list = py::cast(obj); + string prefix = py::isinstance(obj) ? "tuple:" : "list"; if (p_list.empty()) { - return "empty"; - } - prefix = py::isinstance(to_process) ? "tuple:" : "list"; - std::string key = ""; - for (size_t i = 0; i < p_list.size(); ++i) { - key += std::string(py::str(GetId(p_list[i]))) + ":"; + prefix = "empty"; + } else { + std::string key; + for (size_t i = 0; i < p_list.size(); ++i) { + key += std::string(py::str(GetId(p_list[i]))) + ":"; + } + prefix += key; } - return prefix + key; - } - if (py::isinstance(to_process)) { - auto type_ptr = py::cast(to_process); - return "type" + type_ptr->ToString(); - } - if (py::isinstance(to_process)) { - return "s" + std::string(py::str(to_process)); - } - if (py::isinstance(to_process)) { - return prefix + std::string(py::str(to_process)); - } - if (py::isinstance(to_process)) { - return prefix + std::string(py::str(to_process)); - } - if (py::isinstance(to_process)) { - auto tensor_ptr = py::cast(to_process); - return prefix + tensor_ptr->id(); + return prefix; } py::object ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj); @@ -168,9 +162,9 @@ std::map> GetTypeIndex(const std::vector for (size_t i = 0; i < dtypes.size(); ++i) { auto it = type_indexes.find(dtypes[i]); if (it == type_indexes.end()) { - (void)type_indexes.insert(std::make_pair(dtypes[i], std::vector{i})); + (void)type_indexes.emplace(std::make_pair(dtypes[i], std::vector{i})); } else { - it->second.push_back(i); + it->second.emplace_back(i); } } return type_indexes; @@ -230,7 +224,7 @@ std::map GetDstType(const py::tuple &py_args, if (max_type == TypeId::kNumberTypeUInt8 && has_int8) { max_type = TypeId::kNumberTypeInt16; } - (void)dst_type.insert(std::make_pair(type, max_type)); + (void)dst_type.emplace(std::make_pair(type, max_type)); } return dst_type; } @@ -262,13 +256,14 @@ py::object DoAutoCast(const py::object &arg, const TypeId &type_id) { } py::object DoParamMixPrecisionCast(bool *is_cast, const py::object obj) { + MS_EXCEPTION_IF_NULL(is_cast); auto tensor = py::cast(obj); auto cast_type = tensor->cast_dtype(); py::object cast_output = obj; if (cast_type != nullptr) { auto source_element = tensor->Dtype(); if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) { - MS_LOG(DEBUG) << "cast to " << cast_type->ToString(); + MS_LOG(DEBUG) << "Cast to " << cast_type->ToString(); cast_output = DoAutoCast(obj, cast_type->type_id()); *is_cast = true; } @@ -277,12 +272,13 @@ py::object DoParamMixPrecisionCast(bool *is_cast, const py::object obj) { } py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple) { + MS_EXCEPTION_IF_NULL(is_cast); auto tuple_size = static_cast(tuple.size()); py::tuple result(tuple_size); for (int i = 0; i < tuple_size; i++) { if (py::isinstance(tuple[i])) { - MS_LOG(DEBUG) << "call cast for item " << i; + MS_LOG(DEBUG) << "Call cast for item " << i; result[i] = DoParamMixPrecisionCast(is_cast, tuple[i]); } else if (py::isinstance(tuple[i]) || py::isinstance(tuple[i])) { result[i] = DoParamMixPrecisionCastTuple(is_cast, tuple[i]); @@ -290,7 +286,7 @@ py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple) { result[i] = tuple[i]; } } - return result; + return std::move(result); } bool GetSignatureType(const PrimitivePyPtr &prim, std::vector *dtypes) { @@ -312,12 +308,12 @@ void DoSignatrueCast(const PrimitivePyPtr &prim, const std::map &dtypes, const OpExecInfoPtr &op_exec_info) { const auto &signature = prim->signatures(); auto &out_args = op_exec_info->op_inputs; - bool has_dtype_sig = (dtypes.size() > 0); + bool has_dtype_sig = !dtypes.empty(); for (size_t i = 0; i < out_args.size(); ++i) { - MS_LOG(DEBUG) << "check inputs " << i; + MS_LOG(DEBUG) << "Check inputs " << i; auto obj = out_args[i]; auto sig = SignatureEnumRW::kRWDefault; - if (signature.size() > 0) { + if (!signature.empty()) { sig = signature[i].rw; } bool is_parameter = false; @@ -326,7 +322,7 @@ void DoSignatrueCast(const PrimitivePyPtr &prim, const std::map(obj); if (arg->is_parameter()) { is_parameter = true; - MS_LOG(DEBUG) << "parameter is read " << i; + MS_LOG(DEBUG) << "Parameter is read " << i; } arg_type_id = arg->data_type(); } @@ -373,12 +369,12 @@ void DoSignatrueCast(const PrimitivePyPtr &prim, const std::mapname() << " input infer " << mindspore::ToString(args_spec_list); + MS_LOG(DEBUG) << "Prim " << prim->name() << " input infer " << mindspore::ToString(args_spec_list); prim->BeginRecordAddAttr(); AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract(); prim->EndRecordAddAttr(); op_exec_info->abstract = infer_res; - MS_LOG(DEBUG) << "prim " << prim->name() << " infer result " << op_exec_info->abstract->ToString(); + MS_LOG(DEBUG) << "Prim " << prim->name() << " infer result " << op_exec_info->abstract->ToString(); } OpExecInfoPtr GenerateOpExecInfo(const py::args &args) { @@ -391,7 +387,7 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) { op_exec_info->op_name = py::cast(args[PY_NAME]); auto prim = py::cast(args[PY_PRIM]); if (!prim->HasPyObj()) { - MS_LOG(EXCEPTION) << "pyobj is empty"; + MS_LOG(EXCEPTION) << "Pyobj is empty"; } op_exec_info->py_primitive = prim; op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs"); @@ -499,7 +495,7 @@ void PlantTensorTupleToVector(const py::tuple &tuple_inputs, const PrimitivePtr } auto tensor = py::cast(input_object); MS_EXCEPTION_IF_NULL(tensor); - input_tensors->push_back(tensor); + input_tensors->emplace_back(tensor); } op_prim->set_attr(kAttrDynInputSizes, MakeValue(std::vector{SizeToInt(tuple_inputs.size())})); } @@ -515,7 +511,7 @@ void ConvertValueTupleToTensor(const py::object &input_object, std::vectorpush_back(tensor_ptr); + input_tensors->emplace_back(tensor_ptr); } void ConvertMultiPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim, @@ -528,7 +524,7 @@ void ConvertMultiPyObjectToTensor(const py::object &input_object, const Primitiv MS_LOG(EXCEPTION) << "The input should be a tuple!"; } auto tuple_inputs = py::cast(input_object); - if (tuple_inputs.size() == 0) { + if (tuple_inputs.empty()) { MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!"; } auto inputs = py::cast(input_object); @@ -574,7 +570,7 @@ void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr MS_LOG(EXCEPTION) << "Run op inputs type is invalid!"; } MS_EXCEPTION_IF_NULL(tensor_ptr); - input_tensors->push_back(tensor_ptr); + input_tensors->emplace_back(tensor_ptr); } void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector *tensors_mask, @@ -625,7 +621,7 @@ void EraseValueNodeTensor(const std::vector &tensors_mask, std::vector new_input_tensors; for (size_t index = 0; index < tensors_mask.size(); ++index) { if (tensors_mask[index] != kValueNodeTensorMask) { - new_input_tensors.push_back(input_tensors->at(index)); + new_input_tensors.emplace_back(input_tensors->at(index)); } } *input_tensors = new_input_tensors; @@ -657,92 +653,155 @@ BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) { } } -py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { - MS_EXCEPTION_IF_NULL(op_exec_info); - MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms"; - auto ms_context = MsContext::GetInstance(); - ms_context->set_param(MS_CTX_ENABLE_PYNATIVE_INFER, true); - std::string device_target = ms_context->get_param(MS_CTX_DEVICE_TARGET); - if (device_target != kAscendDevice && device_target != kGPUDevice) { - MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode"; +size_t GetTupleSize(const py::tuple &args) { + size_t count = 0; + for (size_t i = 0; i < args.size(); i++) { + if (py::isinstance(args[i])) { + count += GetTupleSize(args[i]); + } else { + count += 1; + } } + return count; +} - if (session == nullptr) { - session = session::SessionFactory::Get().Create(device_target); - MS_EXCEPTION_IF_NULL(session); - session->Init(ms_context->get_param(MS_CTX_DEVICE_ID)); +void ConvertTupleArg(py::tuple *res, size_t *index, const py::tuple &arg) { + for (size_t i = 0; i < arg.size(); i++) { + if (py::isinstance(arg[i])) { + ConvertTupleArg(res, index, arg[i]); + } else { + (*res)[(*index)++] = arg[i]; + } } +} - std::vector input_tensors; - std::vector tensors_mask; - ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors); - // get graph info for checking it whether existing in the cache - std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors); - session::OpRunInfo op_run_info = {op_exec_info->op_name, op_exec_info->py_primitive, op_exec_info->abstract, - op_exec_info->value}; - session->BuildOp(&op_run_info, graph_info, input_tensors, tensors_mask); - EraseValueNodeTensor(tensors_mask, &input_tensors); - VectorRef outputs; - session->RunOp(&op_run_info, graph_info, input_tensors, &outputs); - auto result = BaseRefToPyData(outputs); - ms_context->set_param(MS_CTX_ENABLE_PYNATIVE_INFER, false); - *status = PYNATIVE_SUCCESS; - MS_LOG(INFO) << "End run op[" << op_exec_info->op_name << "] with backend policy ms"; - return result; +py::tuple ConvertArgs(const py::tuple &args) { + size_t tuple_size = GetTupleSize(args); + py::tuple res(tuple_size); + size_t index = 0; + for (size_t i = 0; i < args.size(); i++) { + if (py::isinstance(args[i])) { + ConvertTupleArg(&res, &index, args[i]); + } else { + res[index++] = args[i]; + } + } + return res; } -py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info, - PynativeStatusCode *const status) { - MS_EXCEPTION_IF_NULL(status); - py::object result; - switch (backend_policy) { - case kMsBackendVmOnly: { - // use vm only - MS_LOG(INFO) << "RunOp use VM only backend"; - result = RunOpInVM(op_exec_info, status); - break; +void ClearPyNativeSession() { session = nullptr; } + +PynativeExecutor::~PynativeExecutor() { ClearRes(); } + +py::tuple RunOp(const py::args &args) { + auto executor = PynativeExecutor::GetInstance(); + MS_EXCEPTION_IF_NULL(executor); + try { + return executor->RunOpInner(args); + } catch (const py::error_already_set &ex) { + executor->Clean(); + // re-throw this exception to Python interpreter to handle it + throw(py::error_already_set(ex)); + } catch (const py::type_error &ex) { + executor->Clean(); + throw py::type_error(ex); + } catch (const py::value_error &ex) { + executor->Clean(); + throw py::value_error(ex); + } catch (const py::index_error &ex) { + executor->Clean(); + throw py::index_error(ex); + } catch (const std::exception &ex) { + executor->Clean(); + // re-throw this exception to Python interpreter to handle it + throw(std::runtime_error(ex.what())); + } catch (...) { + executor->Clean(); + std::string exName(abi::__cxa_current_exception_type()->name()); + MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName; + } +} + +py::tuple PynativeExecutor::RunOpInner(const py::args &args) { + MS_LOG(DEBUG) << "RunOp start " << args.size(); + OpExecInfoPtr op_exec_info = nullptr; + auto prim = py::cast(args[PY_PRIM]); + auto name = py::cast(args[PY_NAME]); + op_exec_info = GenerateOpExecInfo(args); + if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) { + return RunOpWithInitBackendPolicy(op_exec_info); + } + + abstract::AbstractBasePtrList args_spec_list; + std::vector op_masks; + auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, &op_masks, &args_spec_list); + bool is_find = false; + if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) { + auto abs_list = prim_abs_list_[prim->id()]; + MS_LOG(DEBUG) << "Match prim input args " << op_exec_info->op_name << mindspore::ToString(args_spec_list); + if (abs_list.find(args_spec_list) != abs_list.end()) { + MS_LOG(DEBUG) << "Match prim ok " << op_exec_info->op_name; + op_exec_info->abstract = abs_list[args_spec_list].abs; + prim->set_evaluate_added_attrs(abs_list[args_spec_list].attrs); + is_find = true; } - case kMsBackendGePrior: { -#ifdef ENABLE_GE - // use GE first, use vm when GE fails - MS_LOG(INFO) << "RunOp use GE first backend"; - result = RunOpInGE(op_exec_info, status); - if (*status != PYNATIVE_SUCCESS) { - result = RunOpInVM(op_exec_info, status); - } -#endif - break; + } + + if (op_exec_info->abstract == nullptr) { + // use python infer method + if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) { + PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get(), args_spec_list); } - case kMsBackendMsPrior: { - // use Ms fisrt,use others when ms failed - MS_LOG(INFO) << "RunOp use Ms first backend"; - result = RunOpInMs(op_exec_info, status); - if (*status != PYNATIVE_SUCCESS) { - MS_LOG(ERROR) << "RunOp use Ms backend failed!!!"; - } - break; + } + + if (cnode != nullptr) { + cnode->set_abstract(op_exec_info->abstract); + } + + op_exec_info->inputs_mask = op_masks; + MS_EXCEPTION_IF_NULL(op_exec_info); + if (op_exec_info->abstract != nullptr) { + MS_LOG(DEBUG) << "Run op infer " << name << " " << op_exec_info->abstract->ToString(); + py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract); + if (!output["value"].is_none()) { + py::tuple value_ret(1); + value_ret[0] = output["value"]; + return value_ret; + } + if (op_exec_info->py_primitive->is_const_prim()) { + py::tuple value_ret(1); + value_ret[0] = ""; + return value_ret; } - default: - MS_LOG(ERROR) << "No backend configured for run op"; } - return result; -} -ValuePtr PynativeExecutor::GetForwardValue(const OpExecInfoPtr &op_exec_info) { - auto id = GetOpId(op_exec_info); - int graph_id = resource_->results()[pipeline::kPynativeGraphId].cast(); - auto op = std::to_string(graph_id) + id; - op.append(std::to_string(op_id_map_[id])); - auto iter = op_forward_map_.find(op); - if (iter != op_forward_map_.end()) { - ++op_id_map_[id]; - MS_LOG(DEBUG) << "Get: " << op_exec_info->op_name << "(" << op << "), " << iter->second; - return iter->second; + if (!is_find) { + // const_value need infer every step + auto &out = prim_abs_list_[prim->id()]; + out[args_spec_list].abs = op_exec_info->abstract; + out[args_spec_list].attrs = prim->evaluate_added_attrs(); + MS_LOG(DEBUG) << "Set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list); } - if (!first_grad_step_) { - ++op_id_map_[id]; + + if (PynativeExecutor::GetInstance()->grad_flag()) { + op_exec_info->value = PynativeExecutor::GetInstance()->GetForwardValue(op_exec_info); + } else { + (void)GetOpId(op_exec_info); } - return nullptr; + + auto result = RunOpWithInitBackendPolicy(op_exec_info); + py::object out_real = result; + if (result.size() == 1) { + MS_LOG(DEBUG) << "Output size is 1"; + out_real = result[0]; + } + std::string obj_id = GetId(out_real); + node_abs_map_[obj_id] = op_exec_info->abstract; + PynativeExecutor::GetInstance()->SaveOutputNodeMap(obj_id, out_real, cnode); + if (cnode != nullptr) { + PynativeExecutor::GetInstance()->SaveAllResult(op_exec_info, cnode->cast(), result); + } + return result; } AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, @@ -751,13 +810,12 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v MS_EXCEPTION_IF_NULL(args_spec_list); MS_EXCEPTION_IF_NULL(op_exec_info); CNodePtr cnode = nullptr; - std::vector inputs; auto prim = op_exec_info->py_primitive; - const auto &signature = prim->signatures(); - - inputs.push_back(NewValueNode(prim)); + std::vector inputs; + inputs.emplace_back(NewValueNode(prim)); + const auto &signature = prim->signatures(); auto sig_size = signature.size(); auto size = op_exec_info->op_inputs.size(); // ignore signature for cast op @@ -769,7 +827,7 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v if (!is_cast_op) { RunParameterAutoMixPrecisionCast(op_exec_info); } - MS_LOG(DEBUG) << "make cnode for " << op_exec_info->op_name; + MS_LOG(DEBUG) << "Make cnode for " << op_exec_info->op_name; for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) { const auto &obj = op_exec_info->op_inputs[i]; bool op_mask = false; @@ -779,19 +837,21 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v op_mask = meta_tensor->is_parameter(); } } - (*op_masks).push_back(op_mask); - MS_LOG(DEBUG) << "gen args i " << i << " " << op_exec_info->op_name << " op mask " << op_mask << " grad_flag_ " + (*op_masks).emplace_back(op_mask); + MS_LOG(DEBUG) << "Gen args i " << i << " " << op_exec_info->op_name << " op mask " << op_mask << " grad_flag_ " << grad_flag_; AnfNodePtr node = nullptr; abstract::AbstractBasePtr abs = nullptr; auto id = GetId(obj); - if (node_abs_map_.find(id) != node_abs_map_.end()) { - abs = node_abs_map_[id]; + auto it = node_abs_map_.find(id); + if (it != node_abs_map_.end()) { + abs = it->second; } if (!graph_info_map_.empty()) { node = GetInput(obj, op_mask); } + // update abstract if (node != nullptr && node->abstract() != nullptr) { abs = node->abstract(); } @@ -801,180 +861,155 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v bool is_const_prim = prim->is_const_prim(); MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value " << prim->is_const_prim(); - bool is_const_input = have_const_input && std::count(const_input_index.begin(), const_input_index.end(), i); + bool is_const_input = + have_const_input && std::find(const_input_index.begin(), const_input_index.end(), i) != const_input_index.end(); if (abs == nullptr || is_const_prim || is_const_input) { - MS_LOG(DEBUG) << "MakeCnode get node no in map" << id; + MS_LOG(DEBUG) << "MakeCnode get node no in map " << id; ValuePtr input_value = PyAttrValue(obj); abs = input_value->ToAbstract(); if (!is_const_prim && !is_const_input) { auto config = abstract::AbstractBase::kBroadenTensorOnly; abs = abs->Broaden(config); - MS_LOG(DEBUG) << "broaden for " << prim->ToString() << " " << config; + MS_LOG(DEBUG) << "Broaden for " << prim->ToString() << " " << config; } node_abs_map_[id] = abs; } - (*args_spec_list).push_back(abs); - inputs.push_back(node); + (*args_spec_list).emplace_back(abs); + if (node != nullptr) { + inputs.emplace_back(node); + } } MS_LOG(DEBUG) << "MakeCnode args end"; - if (grad_flag_) { - if (curr_g_ != nullptr) { - cnode = curr_g_->NewCNode(inputs); - MS_LOG(DEBUG) << "MakeCnode set node " << cnode->DebugString(4); - } + if (grad_flag_ && curr_g_ != nullptr) { + cnode = curr_g_->NewCNode(inputs); + MS_LOG(DEBUG) << "Runop MakeCnode, new node is " << cnode->DebugString(4); } - return cnode; } -void PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::object &out_real, - const AnfNodePtr &cnode) { - if (!grad_flag_ || graph_info_map_.empty()) { - MS_LOG(DEBUG) << "no graph cnode"; - return; - } - - std::string obj_id = GetId(out_real); - MS_EXCEPTION_IF_NULL(cnode); - MS_LOG(DEBUG) << "MakeCnode set obj node id " << cnode->DebugString(4) << "id " << obj_id; - - if (py::isinstance(out_real)) { - auto value = py::cast(out_real); - if (value.size() > 1) { - for (int i = 0; i < static_cast(value.size()); i++) { - auto value_id = GetId(value[i]); - MS_LOG(DEBUG) << "MakeCnode set node id " << value_id; - set_obj_node_map(curr_g_, value_id, cnode, i); +void PynativeExecutor::RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_exec_info) { + size_t size = op_exec_info->op_inputs.size(); + auto prim = op_exec_info->py_primitive; + const auto &signature = prim->signatures(); + for (size_t i = 0; i < size; i++) { + auto obj = op_exec_info->op_inputs[i]; + auto sig = SignatureEnumRW::kRWDefault; + if (!signature.empty()) { + sig = signature[i].rw; + } + MS_LOG(DEBUG) << "Check mix precision " << op_exec_info->op_name << " input " << i << " " + << std::string(py::repr(obj)); + // mix precision for non param + bool is_cast = false; + py::object cast_output; + if (py::isinstance(obj)) { + auto meta_tensor = obj.cast(); + if (meta_tensor && meta_tensor->is_parameter()) { + if (sig != SignatureEnumRW::kRWRead) { + continue; + } } + // redundant cast call if the tensor is a const Tensor. + cast_output = DoParamMixPrecisionCast(&is_cast, obj); + } else if (py::isinstance(obj) || py::isinstance(obj)) { + // mix precision for tuple inputs + cast_output = DoParamMixPrecisionCastTuple(&is_cast, obj); + } + if (is_cast) { + op_exec_info->op_inputs[i] = cast_output; } } - set_obj_node_map(curr_g_, obj_id, cnode); - set_pyobj(curr_g_, obj_id); -} + std::vector dtypes; -void GenTupleMap(const ValueTuplePtr &tuple, std::map *t_map) { - if (t_map == nullptr) { - return; + bool has_dtype_sig = GetSignatureType(prim, &dtypes); + std::map dst_types; + if (has_dtype_sig) { + // fetch info for implicit cast + auto type_indexes = GetTypeIndex(dtypes); + dst_types = GetDstType(op_exec_info->op_inputs, type_indexes); } - for (size_t i = 0; i < tuple->size(); i++) { - ValuePtr tuple_i = (*tuple)[i]; - if (tuple_i->isa()) { - auto t = tuple_i->cast(); - (*t_map)[t->id()] = t; - } else if (tuple_i->isa()) { - GenTupleMap(tuple_i->cast(), t_map); - } - } - MS_LOG(DEBUG) << "End GenTupleMap" << tuple->ToString(); + MS_LOG(DEBUG) << "Do signature for " << op_exec_info->op_name; + DoSignatrueCast(prim, dst_types, dtypes, op_exec_info); } -ValuePtr CleanTupleAddr(const ValueTuplePtr &tuple) { - std::vector value_list; - for (size_t i = 0; i < tuple->size(); i++) { - ValuePtr tuple_i = (*tuple)[i]; - if (tuple_i->isa()) { - auto t = tuple_i->cast(); - auto new_tensor = std::make_shared(*t); - new_tensor->set_device_address(nullptr); - value_list.push_back(new_tensor); - } else if (tuple_i->isa()) { - value_list.push_back(CleanTupleAddr(tuple_i->cast())); - } else { - MS_LOG(DEBUG) << "in value" << tuple_i->ToString(); - value_list.push_back(tuple_i); - } - } - MS_LOG(DEBUG) << "End CleanTupleAddr"; - return std::make_shared(value_list); -} +AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) { + AnfNodePtr node = nullptr; + std::string obj_id = GetId(obj); -void PynativeExecutor::SaveOpForwardValue(const std::string &id, const ValuePtr &value, - std::map *t_map) { - if (op_forward_map_.find(id) != op_forward_map_.end()) { - if (op_forward_map_[id]->isa()) { - // for one op have multi outputs but save only one tensor - if (value->isa()) { - auto tuple = op_forward_map_[id]->cast(); - auto value_t = value->cast(); - for (size_t i = 0; i < tuple->size(); i++) { - if ((*tuple)[i]->isa()) { - auto tuple_t = (*tuple)[i]->cast(); - if (value_t->id() == tuple_t->id()) { - tuple_t->set_device_address(value_t->device_address()); - MS_LOG(DEBUG) << "After Saveop " << tuple_t->ToString(); - break; - } - } - } - } + if (op_mask) { + MS_LOG(DEBUG) << "Cell paramsters(weights)"; + // get the parameter name from parameter object + auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(obj, "name"); + if (py::isinstance(name_attr)) { + MS_LOG(EXCEPTION) << "Parameter object should have name attribute"; } - - if (value->isa() && t_map != nullptr) { - GenTupleMap(op_forward_map_[id]->cast(), t_map); + auto param_name = py::cast(name_attr); + if (graph_info_map_[df_builder_].params.find(obj_id) == graph_info_map_[df_builder_].params.end()) { + auto free_param = df_builder_->add_parameter(); + free_param->set_name(param_name); + free_param->debug_info()->set_name(param_name); + auto value = py::cast(obj); + free_param->set_default_param(value); + MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id; + graph_info_map_[df_builder_].params.emplace(obj_id); + set_node_map(df_builder_, obj_id, free_param); + return free_param; } - MS_LOG(DEBUG) << "Save op forward value: " - << "(" << id << "), " << op_forward_map_[id]->ToString(); - return; + return graph_info_map_[df_builder_].node_map[obj_id].first; } - if (value->isa() && t_map == nullptr) { - // make cnode gen all tuple node and set device_address be null - op_forward_map_[id] = CleanTupleAddr(value->cast()); - } else { - op_forward_map_[id] = value; - } - MS_LOG(DEBUG) << "Save op forward value: " - << "(" << id << "), " << value->ToString(); -} - -void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out) { - if (!grad_flag_ || op_exec_info->value != nullptr) { - return; - } - py::object out_real = out; - if (out.size() == 1) { - out_real = out[0]; - } - auto value = PyAttrValue(out_real); - if (cnode != nullptr) { - size_t size = op_exec_info->op_inputs.size(); - for (size_t i = 0; i < size; i++) { - auto obj = op_exec_info->op_inputs[i]; - auto obj_id = GetId(obj); - if (obj_to_forward_id_.find(obj_id) != obj_to_forward_id_.end()) { - cnode->add_input_value(PyAttrValue(obj), obj_to_forward_id_[obj_id]); - } else { - cnode->add_input_value(nullptr, ""); - } + if (graph_info_map_[curr_g_].node_map.find(obj_id) != graph_info_map_[curr_g_].node_map.end()) { + // op(x, y) + // out = op(op1(x, y)) + // out = op(cell1(x, y)) + // out = op(cell1(x, y)[0]) + node = GetObjNode(obj, obj_id); + } else if (py::isinstance(obj)) { + // out = op((x, y)) + // out = cell((x, y)) + auto tuple = obj.cast(); + // cell((1,2)): support not mix (scalar, tensor) + if (!tuple.empty() && !py::isinstance(tuple[0])) { + return MakeValueNode(obj, obj_id); } - std::string id = GetOpId(op_exec_info); - int graph_id = resource_->results()[pipeline::kPynativeGraphId].cast(); - auto op_id = std::to_string(graph_id) + id; - op_id.append(std::to_string(op_id_map_[id])); - cnode->set_forward(value, op_id); - ++op_id_map_[id]; - auto out_id = GetId(out_real); - if (py::isinstance(out_real)) { - auto tuple_item = py::cast(out_real); - for (size_t i = 0; i < tuple_item.size(); i++) { - auto tuple_item_id = GetId(tuple_item[i]); - obj_to_forward_id_[tuple_item_id] = op_id; - } - SaveOpForwardValue(op_id, value, nullptr); + std::vector args; + args.emplace_back(NewValueNode(prim::kPrimMakeTuple)); + auto tuple_size = tuple.size(); + for (size_t i = 0; i < tuple_size; i++) { + args.emplace_back(GetInput(tuple[i], false)); } - obj_to_forward_id_[out_id] = op_id; + auto cnode = curr_g_->NewCNode(args); + set_node_map(curr_g_, GetId(obj), cnode); + node = cnode; + } else { + node = MakeValueNode(obj, obj_id); } + node == nullptr ? MS_LOG(DEBUG) << "Get node is nullptr" + : MS_LOG(DEBUG) << "Now getinput node " << node->ToString() << obj_id; + return node; } -AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) { - auto id = GetId(obj); - auto &out = graph_info_map_[curr_g_].obj_node_map[id]; +AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj, const std::string &obj_id) { + auto &out = graph_info_map_[curr_g_].node_map[obj_id]; if (out.second.size() == 1 && out.second[0] == -1) { return out.first; } + MS_LOG(DEBUG) << "Output size " << out.second.size(); + + // Params node + if (graph_info_map_[curr_g_].params.find(obj_id) != graph_info_map_[curr_g_].params.end()) { + auto para_node = out.first; + for (auto &idx : out.second) { + std::vector tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), para_node, + NewValueNode(idx)}; + para_node = curr_g_->NewCNode(tuple_get_item_inputs); + } + return para_node; + } + + // Normal node CNodePtr node = out.first->cast(); - MS_LOG(DEBUG) << "output size " << out.second.size() << node->DebugString(); auto abs = node->abstract(); ValuePtr out_obj = nullptr; if (node->forward().first != nullptr) { @@ -993,46 +1028,173 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) { } if (abs != nullptr && abs->isa()) { auto prim_abs = dyn_cast(abs)->elements()[idx]; - MS_LOG(DEBUG) << "set tuple getitem abs" << prim_abs->ToString(); + MS_LOG(DEBUG) << "Set tuple getitem abs " << prim_abs->ToString(); node->set_abstract(prim_abs); } } if (node->abstract() != nullptr) { - node_abs_map_[id] = node->abstract(); + node_abs_map_[obj_id] = node->abstract(); } - MS_LOG(DEBUG) << "GetObjNode output" << node->DebugString(6); + MS_LOG(DEBUG) << "GetObjNode output " << node->DebugString(6); return node; } -AnfNodePtr PynativeExecutor::GetParamNode(const py::object &obj) { - auto id = GetId(obj); - auto ¶m = graph_info_map_[curr_g_].param_map[id]; - if (param.second.size() == 1 && param.second[0] == -1) { - return param.first; +AnfNodePtr PynativeExecutor::MakeValueNode(const py::object &obj, const std::string &obj_id) { + ValuePtr converted_ret = nullptr; + parse::ConvertData(obj, &converted_ret); + auto node = NewValueNode(converted_ret); + set_node_map(curr_g_, obj_id, node); + return node; +} + +ValuePtr PynativeExecutor::GetForwardValue(const OpExecInfoPtr &op_exec_info) { + auto id = GetOpId(op_exec_info); + int graph_id = resource_->results()[pipeline::kPynativeGraphId].cast(); + auto op = std::to_string(graph_id) + id; + op.append(std::to_string(op_id_map_[id])); + auto iter = op_forward_map_.find(op); + if (iter != op_forward_map_.end()) { + ++op_id_map_[id]; + MS_LOG(DEBUG) << "Get: " << op_exec_info->op_name << "(" << op << "), " << iter->second; + return iter->second; } - auto para_node = param.first; - for (auto &idx : param.second) { - std::vector tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), para_node, NewValueNode(idx)}; - para_node = curr_g_->NewCNode(tuple_get_item_inputs); + if (!first_grad_step_) { + ++op_id_map_[id]; } - return para_node; + return nullptr; } -std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &args) { - auto cell_id = GetId(cell); - for (size_t i = 0; i < args.size(); i++) { - std::string arg_id = GetId(args[i]); - if (node_abs_map_.find(arg_id) != node_abs_map_.end()) { - cell_id += node_abs_map_[arg_id]->ToString(); +void PynativeExecutor::SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, + const AnfNodePtr &cnode) { + if (!grad_flag_ || graph_info_map_.empty()) { + MS_LOG(DEBUG) << "No need save output"; + return; + } + MS_LOG(DEBUG) << "Cnode is " << cnode->DebugString(4) << " id " << obj_id; + + if (py::isinstance(out_real)) { + auto value = py::cast(out_real); + auto size = static_cast(value.size()); + if (size > 1) { + for (int i = 0; i < size; ++i) { + auto value_id = GetId(value[i]); + set_node_map(curr_g_, value_id, cnode, i); + } + } + } + set_node_map(curr_g_, obj_id, cnode); + set_pyobj(curr_g_, obj_id); +} + +void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out) { + if (!grad_flag_ || op_exec_info->value != nullptr || cnode == nullptr) { + return; + } + py::object out_real = out; + if (out.size() == 1) { + out_real = out[0]; + } + + auto value = PyAttrValue(out_real); + size_t size = op_exec_info->op_inputs.size(); + for (size_t i = 0; i < size; i++) { + auto obj = op_exec_info->op_inputs[i]; + auto obj_id = GetId(obj); + if (obj_to_forward_id_.find(obj_id) != obj_to_forward_id_.end()) { + cnode->add_input_value(PyAttrValue(obj), obj_to_forward_id_[obj_id]); } else { - auto abs = PyAttrValue(args[i])->ToAbstract(); - auto config = abstract::AbstractBase::kBroadenTensorOnly; - abs = abs->Broaden(config); - cell_id += abs->ToString(); - node_abs_map_[arg_id] = abs; + cnode->add_input_value(nullptr, ""); } } - return cell_id; + std::string id = GetOpId(op_exec_info); + int graph_id = resource_->results()[pipeline::kPynativeGraphId].cast(); + auto op_id = std::to_string(graph_id) + id; + op_id.append(std::to_string(op_id_map_[id])); + cnode->set_forward(value, op_id); + ++op_id_map_[id]; + auto out_id = GetId(out_real); + if (py::isinstance(out_real)) { + auto tuple_item = py::cast(out_real); + for (size_t i = 0; i < tuple_item.size(); i++) { + auto tuple_item_id = GetId(tuple_item[i]); + obj_to_forward_id_[tuple_item_id] = op_id; + } + SaveOpForwardValue(op_id, value, nullptr); + } + obj_to_forward_id_[out_id] = op_id; +} + +void PynativeExecutor::SaveOpForwardValue(const std::string &id, const ValuePtr &value, + std::map *t_map) { + if (op_forward_map_.find(id) != op_forward_map_.end()) { + // for one op have multi outputs but save only one tensor + if (op_forward_map_[id]->isa() && value->isa()) { + auto tuple = op_forward_map_[id]->cast(); + auto value_t = value->cast(); + for (size_t i = 0; i < tuple->size(); i++) { + if ((*tuple)[i]->isa()) { + auto tuple_t = (*tuple)[i]->cast(); + if (value_t->id() == tuple_t->id()) { + tuple_t->set_device_address(value_t->device_address()); + MS_LOG(DEBUG) << "After Saveop " << tuple_t->ToString(); + break; + } + } + } + } + + if (value->isa() && t_map != nullptr) { + GenTupleMap(op_forward_map_[id]->cast(), t_map); + } + MS_LOG(DEBUG) << "Save op forward value: " + << "(" << id << "), " << op_forward_map_[id]->ToString(); + return; + } + + if (value->isa() && t_map == nullptr) { + // make cnode gen all tuple node and set device_address be null + op_forward_map_[id] = CleanTupleAddr(value->cast()); + } else { + op_forward_map_[id] = value; + } + MS_LOG(DEBUG) << "Save op forward value: " + << "(" << id << "), " << value->ToString(); +} + +void PynativeExecutor::GenTupleMap(const ValueTuplePtr &tuple, std::map *t_map) { + if (t_map == nullptr) { + return; + } + for (size_t i = 0; i < tuple->size(); i++) { + ValuePtr tuple_i = (*tuple)[i]; + if (tuple_i->isa()) { + auto t = tuple_i->cast(); + (*t_map)[t->id()] = t; + } else if (tuple_i->isa()) { + GenTupleMap(tuple_i->cast(), t_map); + } + } + MS_LOG(DEBUG) << "End GenTupleMap " << tuple->ToString(); +} + +ValuePtr PynativeExecutor::CleanTupleAddr(const ValueTuplePtr &tuple) { + std::vector value_list; + for (size_t i = 0; i < tuple->size(); i++) { + ValuePtr tuple_i = (*tuple)[i]; + if (tuple_i->isa()) { + auto t = tuple_i->cast(); + auto new_tensor = std::make_shared(*t); + new_tensor->set_device_address(nullptr); + value_list.emplace_back(new_tensor); + } else if (tuple_i->isa()) { + value_list.emplace_back(CleanTupleAddr(tuple_i->cast())); + } else { + MS_LOG(DEBUG) << "Tuple[i] value " << tuple_i->ToString(); + value_list.emplace_back(tuple_i); + } + } + MS_LOG(DEBUG) << "End CleanTupleAddr"; + return std::make_shared(value_list); } py::tuple PynativeExecutor::RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info) { @@ -1050,130 +1212,144 @@ py::tuple PynativeExecutor::RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_e return result; } -py::tuple PynativeExecutor::RunOpInner(const py::args &args) { - MS_LOG(DEBUG) << "RunOp start " << args.size(); - OpExecInfoPtr op_exec_info = nullptr; - auto prim = py::cast(args[PY_PRIM]); - auto name = py::cast(args[PY_NAME]); - abstract::AbstractBasePtrList args_spec_list; - std::vector op_masks; - op_exec_info = GenerateOpExecInfo(args); - if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) { - return RunOpWithInitBackendPolicy(op_exec_info); - } - auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, &op_masks, &args_spec_list); - bool is_find = false; - if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) { - auto abs_list = prim_abs_list_[prim->id()]; - MS_LOG(DEBUG) << "match prim input args " << op_exec_info->op_name << mindspore::ToString(args_spec_list); - if (abs_list.find(args_spec_list) != abs_list.end()) { - MS_LOG(DEBUG) << "match prim ok" << op_exec_info->op_name; - op_exec_info->abstract = abs_list[args_spec_list].abs; - prim->set_evaluate_added_attrs(abs_list[args_spec_list].attrs); - is_find = true; +MsBackendPolicy PynativeExecutor::InitEnv(const OpExecInfoPtr &op_exec_info) { + MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name; + mindspore::parse::python_adapter::set_python_env_flag(true); + MsBackendPolicy backend_policy; +#if (!defined ENABLE_GE) + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (!context::IsTsdOpened(ms_context)) { + if (!context::OpenTsd(ms_context)) { + MS_LOG(EXCEPTION) << "Open tsd failed"; } } - - if (op_exec_info->abstract == nullptr) { - // use python infer method - if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) { - PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get(), args_spec_list); - } + if (ms_context->backend_policy() == "ms") { + backend_policy = kMsBackendMsPrior; + } else { + backend_policy = kMsBackendVmOnly; } - - if (cnode != nullptr) { - cnode->set_abstract(op_exec_info->abstract); - MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << cnode->DebugString(); +#else + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + context::PynativeInitGe(ms_context); + backend_policy = kMsBackendGeOnly; +#endif + if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) { + backend_policy = kMsBackendVmOnly; } + return backend_policy; +} - op_exec_info->inputs_mask = op_masks; - MS_EXCEPTION_IF_NULL(op_exec_info); - if (op_exec_info->abstract != nullptr) { - MS_LOG(DEBUG) << "run op infer" << name << op_exec_info->abstract->ToString(); - py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract); - if (!output["value"].is_none()) { - py::tuple value_ret(1); - value_ret[0] = output["value"]; - return value_ret; +py::object PynativeExecutor::RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info, + PynativeStatusCode *const status) { + MS_EXCEPTION_IF_NULL(status); + py::object result; + switch (backend_policy) { + case kMsBackendVmOnly: { + // use vm only + MS_LOG(INFO) << "RunOp use VM only backend"; + result = RunOpInVM(op_exec_info, status); + break; } - if (op_exec_info->py_primitive->is_const_prim()) { - py::tuple value_ret(1); - value_ret[0] = ""; - return value_ret; + case kMsBackendGePrior: { +#ifdef ENABLE_GE + // use GE first, use vm when GE fails + MS_LOG(INFO) << "RunOp use GE first backend"; + result = RunOpInGE(op_exec_info, status); + if (*status != PYNATIVE_SUCCESS) { + result = RunOpInVM(op_exec_info, status); + } +#endif + break; + } + case kMsBackendMsPrior: { + // use Ms fisrt,use others when ms failed + MS_LOG(INFO) << "RunOp use Ms first backend"; + result = RunOpInMs(op_exec_info, status); + if (*status != PYNATIVE_SUCCESS) { + MS_LOG(ERROR) << "RunOp use Ms backend failed!!!"; + } + break; } + default: + MS_LOG(ERROR) << "No backend configured for run op"; } + return result; +} - if (!is_find) { - // const_value need infer every step - auto &out = prim_abs_list_[prim->id()]; - out[args_spec_list].abs = op_exec_info->abstract; - out[args_spec_list].attrs = prim->evaluate_added_attrs(); - MS_LOG(DEBUG) << "set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list); +py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { + MS_EXCEPTION_IF_NULL(op_exec_info); + MS_EXCEPTION_IF_NULL(status); + MS_LOG(INFO) << "Start run op [" << op_exec_info->op_name << "] with backend policy ms"; + auto ms_context = MsContext::GetInstance(); + ms_context->set_param(MS_CTX_ENABLE_PYNATIVE_INFER, true); + std::string device_target = ms_context->get_param(MS_CTX_DEVICE_TARGET); + if (device_target != kAscendDevice && device_target != kGPUDevice) { + MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode"; } - if (PynativeExecutor::GetInstance()->grad_flag()) { - op_exec_info->value = PynativeExecutor::GetInstance()->GetForwardValue(op_exec_info); - } else { - (void)GetOpId(op_exec_info); + if (session == nullptr) { + session = session::SessionFactory::Get().Create(device_target); + MS_EXCEPTION_IF_NULL(session); + session->Init(ms_context->get_param(MS_CTX_DEVICE_ID)); } - auto result = RunOpWithInitBackendPolicy(op_exec_info); - py::object out_real = result; - if (result.size() == 1) { - MS_LOG(DEBUG) << "MakeCnode out size is one."; - out_real = result[0]; - } - std::string obj_id = GetId(out_real); - node_abs_map_[obj_id] = op_exec_info->abstract; - PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, out_real, cnode); - if (cnode != nullptr) { - PynativeExecutor::GetInstance()->SaveAllResult(op_exec_info, cnode->cast(), result); - } + std::vector input_tensors; + std::vector tensors_mask; + ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors); + // get graph info for checking it whether existing in the cache + std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors); + session::OpRunInfo op_run_info = {op_exec_info->op_name, op_exec_info->py_primitive, op_exec_info->abstract, + op_exec_info->value}; + session->BuildOp(&op_run_info, graph_info, input_tensors, tensors_mask); + EraseValueNodeTensor(tensors_mask, &input_tensors); + VectorRef outputs; + session->RunOp(&op_run_info, graph_info, input_tensors, &outputs); + auto result = BaseRefToPyData(outputs); + ms_context->set_param(MS_CTX_ENABLE_PYNATIVE_INFER, false); + *status = PYNATIVE_SUCCESS; + MS_LOG(INFO) << "End run op [" << op_exec_info->op_name << "] with backend policy ms"; return result; } -py::tuple RunOp(const py::args &args) { - try { - return PynativeExecutor::GetInstance()->RunOpInner(args); - } catch (const py::error_already_set &ex) { - PynativeExecutor::GetInstance()->Clean(); - // re-throw this exception to Python interpreter to handle it - throw(py::error_already_set(ex)); - } catch (const py::type_error &ex) { - PynativeExecutor::GetInstance()->Clean(); - throw py::type_error(ex); - } catch (const py::value_error &ex) { - PynativeExecutor::GetInstance()->Clean(); - throw py::value_error(ex); - } catch (const py::index_error &ex) { - PynativeExecutor::GetInstance()->Clean(); - throw py::index_error(ex); - } catch (const std::exception &ex) { - PynativeExecutor::GetInstance()->Clean(); - // re-throw this exception to Python interpreter to handle it - throw(std::runtime_error(ex.what())); - } catch (...) { - PynativeExecutor::GetInstance()->Clean(); - std::string exName(abi::__cxa_current_exception_type()->name()); - MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName; +void PynativeExecutor::Pushp() { graph_context_.push(curr_g_); } + +void PynativeExecutor::Popp() { + if (graph_context_.empty()) { + MS_LOG(EXCEPTION) << "Stack graph_context_ is empty"; + } + graph_context_.pop(); + if (!graph_context_.empty()) { + curr_g_ = graph_context_.top(); } } -void ClearPyNativeSession() { session = nullptr; } - -PynativeExecutor::~PynativeExecutor() { ClearRes(); } - -PynativeExecutor::PynativeExecutor() { - grad_flag_ = false; - first_grad_step_ = false; +std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &args) { + auto cell_id = GetId(cell); + for (size_t i = 0; i < args.size(); i++) { + std::string arg_id = GetId(args[i]); + auto it = node_abs_map_.find(arg_id); + if (it != node_abs_map_.end()) { + cell_id += it->second->ToString(); + } else { + auto abs = PyAttrValue(args[i])->ToAbstract(); + auto config = abstract::AbstractBase::kBroadenTensorOnly; + abs = abs->Broaden(config); + cell_id += abs->ToString(); + node_abs_map_[arg_id] = abs; + } + } + return cell_id; } void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) { auto cell_id = GetCellId(cell, args); - // judge graph_context_.empty() to create sperate graphs except for the top - if (cell_graph_map_.count(cell_id) != 0 && graph_context_.empty()) { - if (cell_resource_map_.find(cell_id) != cell_resource_map_.end()) { - resource_ = cell_resource_map_[cell_id]; + MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << " " << cell_id; + if (!dynamic_shape && graph_context_.empty() && cell_graph_map_.find(cell_id) != cell_graph_map_.end()) { + auto it = cell_resource_map_.find(cell_id); + if (it != cell_resource_map_.end()) { + resource_ = it->second; } MS_LOG(DEBUG) << "Newgraph already compiled"; return; @@ -1181,183 +1357,115 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg auto g = std::make_shared(); if (graph_context_.empty()) { - for (auto arg : args) { - if (py::isinstance(arg)) { - auto tensor = arg.cast(); - if (tensor && tensor->is_parameter()) { - MS_EXCEPTION(TypeError) << "The inputs could not be Parameter."; - } - } - } - // a df builder is built for every top function graph - df_builder_ = std::make_shared(); - df_builder_map_[cell_id] = df_builder_; - top_g_ = curr_g_ = g; - resource_ = std::make_shared(); - resource_->results()[pipeline::kPynativeGraphId] = graph_id_++; - cell_resource_map_[cell_id] = resource_; - MS_LOG(DEBUG) << "First new graph" << top_g_.get(); - first_grad_step_ = true; - top_graph_cells_.insert(cell_id); + MakeNewTopGraph(cell_id, args, g); } else { - if (df_builder_ == nullptr) { - MS_LOG(EXCEPTION) << "In NewGraphInner, got df builder is nullptr"; - } + MS_EXCEPTION_IF_NULL(df_builder_); curr_g_ = g; } Pushp(); - if (graph_info_map_.count(g) == 0) { - graph_info_map_[g] = GraphInfo(); + if (graph_info_map_.find(curr_g_) == graph_info_map_.end()) { + graph_info_map_.emplace(curr_g_, GraphInfo()); } - for (size_t i = 0; i < args.size(); i++) { + for (size_t i = 0; i < args.size(); ++i) { auto param = args[i]; auto new_param = g->add_parameter(); std::string param_obj = GetId(param); - if (py::isinstance(param)) { - auto tuple = param.cast(); - auto tuple_size = static_cast(tuple.size()); - for (int j = 0; j < tuple_size; j++) { - set_param_map(curr_g_, GetId(tuple[j]), new_param, j); - SetTupleParam(tuple[j], new_param, std::vector{j}); - } - } - set_param_map(curr_g_, param_obj, new_param); + set_node_map(curr_g_, param, new_param, true); + set_node_map(curr_g_, param_obj, new_param); } } -AnfNodePtr PynativeExecutor::MakeValueNode(const py::object &obj, const std::string &obj_id) { - ValuePtr converted_ret = nullptr; - parse::ConvertData(obj, &converted_ret); - auto node = NewValueNode(converted_ret); - set_obj_node_map(curr_g_, obj_id, node); - return node; -} - -AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) { - AnfNodePtr node = nullptr; - std::string obj_id = GetId(obj); - - if (op_mask) { - MS_LOG(DEBUG) << "Topgraph free parameter"; - // get the parameter name from parameter object - auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(obj, "name"); - if (py::isinstance(name_attr)) { - MS_LOG(EXCEPTION) << "Parameter object should have name attribute"; - } - auto param_name = py::cast(name_attr); - if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) { - auto free_param = df_builder_->add_parameter(); - free_param->set_name(param_name); - free_param->debug_info()->set_name(param_name); - auto value = py::cast(obj); - free_param->set_default_param(value); - MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id; - set_param_map(df_builder_, obj_id, free_param); - return free_param; +void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g) { + for (auto arg : args) { + if (py::isinstance(arg)) { + auto tensor = arg.cast(); + if (tensor && tensor->is_parameter()) { + MS_EXCEPTION(TypeError) << "The inputs could not be Parameter."; + } } - return graph_info_map_[df_builder_].param_map[obj_id].first; } - // if input is graph output - if (graph_info_map_[curr_g_].param_map.count(obj_id) != 0) { - // op(x, y) - node = GetParamNode(obj); - } else if (graph_info_map_[curr_g_].obj_node_map.count(obj_id) != 0) { - // out = op(op1(x, y)) - // out = op(cell1(x, y)) - // out = op(cell1(x, y)[0]) - node = GetObjNode(obj); - } else if (py::isinstance(obj)) { - // out = op((x, y)) - // out = cell((x, y)) - auto tuple = obj.cast(); - // cell((1,2)): support not mix (scalar, tensor) - if (!tuple.empty() && !py::isinstance(tuple[0])) { - return MakeValueNode(obj, obj_id); + if (dynamic_shape) { + auto it = df_builder_map_.find(cell_id); + if (it != df_builder_map_.end()) { + df_builder_map_.erase(cell_id); } - std::vector args; - args.push_back(NewValueNode(prim::kPrimMakeTuple)); - auto tuple_size = static_cast(tuple.size()); - for (int i = 0; i < tuple_size; i++) { - args.push_back(GetInput(tuple[i], false)); + auto ic = cell_resource_map_.find(cell_id); + if (ic != cell_resource_map_.end()) { + cell_resource_map_.erase(cell_id); } - auto cnode = curr_g_->NewCNode(args); - set_obj_node_map(curr_g_, GetId(obj), cnode); - node = cnode; - } else { - node = MakeValueNode(obj, obj_id); } - - MS_LOG(DEBUG) << "Now getinput node " << node->ToString() << obj_id; - return node; + // a df builder is built for every top function graph + df_builder_ = std::make_shared(); + df_builder_map_.emplace(cell_id, std::make_pair(df_builder_, nullptr)); + top_g_ = curr_g_ = g; + resource_ = std::make_shared(); + resource_->results()[pipeline::kPynativeGraphId] = graph_id_++; + cell_resource_map_.emplace(cell_id, resource_); + MS_LOG(DEBUG) << "New top graph for " << cell_id; + first_grad_step_ = true; + top_graph_cells_.emplace(cell_id); } -// for output[0][1] need getitem multi -void PynativeExecutor::SetTupleOutput(const py::object &obj, const AnfNodePtr &cnode, std::vector idx) { - if (py::isinstance(obj)) { - auto tuple = obj.cast(); - for (int i = 0; i < static_cast(tuple.size()); i++) { - std::vector tmp = idx; - tmp.push_back(i); - set_obj_node_map(curr_g_, GetId(tuple[i]), cnode, tmp); - SetTupleOutput(tuple[i], cnode, tmp); - } +void PynativeExecutor::set_node_map(const FuncGraphPtr &g, const py::object &node, const AnfNodePtr &cnode, + bool is_param) { + if (!py::isinstance(node)) { + return; } -} - -// for param ((a, (b, c)), d) need multi getitem -void PynativeExecutor::SetTupleParam(const py::object &obj, const AnfNodePtr ¶_node, std::vector idx) { - if (py::isinstance(obj)) { - auto tuple = obj.cast(); - for (int i = 0; i < static_cast(tuple.size()); i++) { - std::vector tmp = idx; - tmp.push_back(i); - set_param_map(curr_g_, GetId(tuple[i]), para_node, tmp); - SetTupleParam(tuple[i], para_node, tmp); + auto tuple = node.cast(); + auto tuple_size = static_cast(tuple.size()); + for (int i = 0; i < tuple_size; ++i) { + auto id = GetId(tuple[i]); + if (is_param) { + graph_info_map_[g].params.emplace(id); } + set_node_map(g, id, cnode, i); + set_tuple_node_map(g, tuple[i], cnode, std::vector{i}, is_param); } } -void PynativeExecutor::Pushp() { graph_context_.push(curr_g_); } - -void PynativeExecutor::Popp() { - if (graph_context_.empty()) { - MS_LOG(EXCEPTION) << "Stack graph_context_ is empty"; +void PynativeExecutor::set_tuple_node_map(const FuncGraphPtr &g, const py::object &node, const AnfNodePtr &cnode, + const std::vector &idx, bool is_param) { + if (!py::isinstance(node)) { + return; } - graph_context_.pop(); - if (!graph_context_.empty()) { - curr_g_ = graph_context_.top(); + auto tuple = node.cast(); + auto tuple_size = static_cast(tuple.size()); + for (int i = 0; i < tuple_size; ++i) { + std::vector tmp = idx; + tmp.emplace_back(i); + auto id = GetId(tuple[i]); + if (is_param) { + graph_info_map_[g].params.emplace(id); + } + set_node_map(g, id, cnode, tmp); + set_tuple_node_map(g, tuple[i], cnode, tmp, is_param); } } void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) { auto cell_id = GetCellId(cell, args); - if (cell_graph_map_.count(cell_id) != 0 && graph_context_.empty()) { + MS_LOG(DEBUG) << "EndGraphInner start " << args.size() << " " << cell_id; + if (!dynamic_shape && graph_context_.empty() && cell_graph_map_.find(cell_id) != cell_graph_map_.end()) { MS_LOG(DEBUG) << "Endgraph already compiled"; return; } - cell_graph_map_[cell_id] = curr_g_; + cell_graph_map_.emplace(cell_id, std::make_pair(curr_g_, false)); auto out_id = GetId(out); - if (!graph_info_map_[curr_g_].obj_node_map.count(out_id) && !graph_info_map_[curr_g_].param_map.count(out_id)) { - // cell construct return x, y + // x =op1, y =op2, return (x, y) + if (graph_info_map_[curr_g_].node_map.find(out_id) == graph_info_map_[curr_g_].node_map.end()) { if (py::isinstance(out)) { - std::vector args; - args.push_back(NewValueNode(prim::kPrimMakeTuple)); - auto tuple = out.cast(); - MS_LOG(DEBUG) << "End graph start tuple size" << tuple.size(); auto tuple_size = static_cast(tuple.size()); - auto cnode = curr_g_->NewCNode(args); - for (int i = 0; i < tuple_size; i++) { - args.push_back(GetInput(tuple[i], false)); - } - cnode->set_inputs(args); + std::vector inputs; + inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); for (int i = 0; i < tuple_size; i++) { - set_obj_node_map(curr_g_, GetId(tuple[i]), cnode, i); - SetTupleOutput(tuple[i], cnode, std::vector{i}); + inputs.emplace_back(GetInput(tuple[i], false)); } - set_obj_node_map(curr_g_, out_id, cnode); + auto cnode = curr_g_->NewCNode(inputs); + set_node_map(curr_g_, out, cnode); + set_node_map(curr_g_, out_id, cnode); } else { MS_LOG(DEBUG) << "Set ValueNode as output for graph, out id: " << out_id; MakeValueNode(out, out_id); @@ -1368,17 +1476,42 @@ void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &o void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out, const py::args &args) { - AnfNodePtr output_node; - if (graph_info_map_[curr_g_].param_map.count(out_id)) { - output_node = GetParamNode(out); - } else { - output_node = GetObjNode(out); - } + AnfNodePtr output_node = GetObjNode(out, out_id); curr_g_->set_output(output_node); - std::vector inputs; - inputs.push_back(NewValueNode(curr_g_)); - MS_LOG(DEBUG) << "Current graph" << curr_g_->output()->DebugString(); + MS_LOG(DEBUG) << "Current graph " << curr_g_->output()->DebugString(); resource_->manager()->AddFuncGraph(curr_g_); + + auto newfg = MakeGradGraph(cell, args); + + if (graph_context_.size() > 1) { + std::vector inputs; + inputs.emplace_back(NewValueNode(curr_g_)); + + Popp(); + // connect the previous graph to the inside graph + auto graph_prev = graph_context_.top(); + for (size_t i = 0; i < args.size(); i++) { + auto input = GetInput(args[i], false); + inputs.emplace_back(input); + } + auto out_cnode = graph_prev->NewCNode(inputs); + set_pyobj(graph_prev, GetCellId(cell, args)); + set_node_map(graph_prev, out, out_cnode); + set_node_map(graph_prev, GetId(out), out_cnode); + } else { + if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { + DumpIR("before_resolve.ir", newfg); + } + parse::ResolveFuncGraph(newfg, resource_); + if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { + DumpIR("after_resolve.ir", newfg); + } + resource_->set_func_graph(newfg); + Popp(); + } +} + +FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const py::args &args) { // custom bprop debug bool need_replace_param = false; if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { @@ -1388,14 +1521,17 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje MS_LOG(EXCEPTION) << "When user defines the net bprop, there are " << par_number << " parameters that is not supported in the net."; } - MS_LOG(DEBUG) << "Use cell custom bprop function."; + MS_LOG(INFO) << "Use cell custom bprop function."; FuncGraphPtr bprop_graph = parse::ConvertToBpropCut(cell); if (bprop_graph != nullptr) { - (void)curr_g_->transforms().insert(std::make_pair(parse::CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph))); - (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(curr_g_))); + (void)curr_g_->transforms().emplace(std::make_pair(parse::CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph))); + (void)bprop_graph->transforms().emplace(std::make_pair("primal", FuncGraphTransform(curr_g_))); } } + // Obtain grad graph auto newfg = ad::Grad(curr_g_, resource_, graph_context_.size() == 1); + graph_info_map_.erase(curr_g_); + if (need_replace_param) { auto params = newfg->parameters(); auto manager = Manage({newfg}, false); @@ -1409,68 +1545,189 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje manager->Replace(params[i], v_node); } } - graph_info_map_.erase(curr_g_); - if (graph_context_.size() > 1) { - Popp(); - // connect the previous graph to the inside graph - auto graph_prev = graph_context_.top(); - for (size_t i = 0; i < args.size(); i++) { - auto input = GetInput(args[i], false); - inputs.push_back(input); - } - auto out_cnode = graph_prev->NewCNode(inputs); - set_pyobj(graph_prev, GetCellId(cell, args)); - if (py::isinstance(out)) { - auto out_list = py::cast(out); - auto out_size = static_cast(out_list.size()); - for (int i = 0; i < out_size; i++) { - set_obj_node_map(graph_prev, GetId(out_list[i]), out_cnode, i); - SetTupleOutput(out_list[i], out_cnode, std::vector{i}); - } + return newfg; +} + +void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, + const py::args &args) { + MS_LOG(INFO) << "GradNet start " << args.size(); + auto size = args.size(); + std::pair sens_weights_changed(false, false); + std::string cell_id = CheckCellChanged(grad, cell, weights, args, &sens_weights_changed); + MS_LOG(DEBUG) << "GradNetInner cell_id " << cell_id; + + if (!dynamic_shape && !sens_weights_changed.first && !sens_weights_changed.second && + cell_graph_map_.find(cell_id) != cell_graph_map_.end() && cell_graph_map_[cell_id].second) { + if (cell_resource_map_.find(cell_id) == cell_resource_map_.end()) { + MS_LOG(EXCEPTION) << "Can not find resource"; + } + resource_ = cell_resource_map_[cell_id]; + MS_EXCEPTION_IF_NULL(resource_); + MS_LOG(INFO) << "GradNetInner already compiled"; + return; + } + + // set all params(input+weights) + SetGradGraphParams(size, cell_id, sens_weights_changed); + + // get params(weights) require derivative + auto w_args = GetWeightsArgs(weights); + + // get the parameters items and add the value to args_spec + auto args_spec = GetArgsSpec(args); + resource_->set_args_spec(args_spec); + MS_LOG(DEBUG) << "Args_spec size " << args_spec.size(); + + // Only need to set it at first time + if (df_builder_map_[cell_id].second == nullptr) { + auto cloned_df_builder = BasicClone(df_builder_); + auto cloned_df_newfg = BasicClone(resource_->func_graph()); + df_builder_map_[cell_id] = std::make_pair(cloned_df_builder, cloned_df_newfg); + } else { + resource_->set_func_graph(df_builder_map_[cell_id].second); + } + + // get real grad graph + GradGraph(resource_->func_graph(), grad, w_args, size); + + if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { + DumpIR("befor_grad.ir", resource_->func_graph()); + DumpIR("after_grad.ir", df_builder_); + } + + resource_->set_func_graph(df_builder_); + resource_->manager()->KeepRoots({df_builder_}); + resource_->results()[pipeline::kBackend] = compile::CreateBackend(); + + MS_LOG(DEBUG) << "Start opt"; + PynativeOptimizeAction(resource_); + TaskEmitAction(resource_); + ExecuteAction(resource_); + cell_graph_map_[cell_id].second = true; + + resource_->Clean(); + ad::CleanRes(); + pipeline::ReclaimOptimizer(); +} + +std::string PynativeExecutor::CheckCellChanged(const GradOperationPtr &grad, const py::object &cell, + const py::object &weights, const py::args &args, + std::pair *sens_weights_changed) { + MS_EXCEPTION_IF_NULL(sens_weights_changed); + auto fn = [](const py::object &arg) { + std::string arg_id; + if (py::isinstance(arg)) { + auto tensor_ptr = py::cast(arg); + auto dtype = tensor_ptr->data_type(); + auto shape = tensor_ptr->shape(); + std::stringstream ss; + std::for_each(shape.begin(), shape.end(), [&ss](int i) { ss << i; }); + arg_id = ss.str() + std::to_string(dtype); + } else { + arg_id = std::string(py::str(arg)); } - set_obj_node_map(graph_prev, GetId(out), out_cnode); + return arg_id; + }; + + std::string sens_id = "sens"; + std::string cell_id; + if (grad->sens_param()) { + size_t size = args.size(); + size_t forward_args_count = size; + if (size >= 1) { + forward_args_count = size - 1; + const py::object &sens = args[forward_args_count]; + sens_id = fn(sens); + } + py::tuple forward_args(forward_args_count); + for (size_t i = 0; i < forward_args_count; ++i) { + forward_args[i] = args[i]; + } + cell_id = GetCellId(cell, forward_args); } else { - if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { - DumpIR("before_resolve.ir", newfg); - } - parse::ResolveFuncGraph(newfg, resource_); - if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { - DumpIR("after_resolve.ir", newfg); - } - resource_->set_func_graph(newfg); - Popp(); + cell_id = GetCellId(cell, args); + } + + std::string wigths_id = fn(weights); + + // Check whether sens or weights changed + auto it = cell_sw_map_.find(cell_id); + if (it != cell_sw_map_.end() && it->second.first != sens_id) { + MS_LOG(DEBUG) << "Sens_id, cur is " << it->second.first << " new is " << sens_id; + (*sens_weights_changed).first = true; + } + if (it != cell_sw_map_.end() && it->second.second != wigths_id) { + MS_LOG(DEBUG) << "Wigths_id, cur is " << it->second.first << " new is " << wigths_id; + (*sens_weights_changed).second = true; + } + cell_sw_map_[cell_id] = std::make_pair(sens_id, wigths_id); + return cell_id; +} + +void PynativeExecutor::SetGradGraphParams(size_t size, const std::string &cell_id, + const std::pair &sens_weights_changed) { + auto ic = cell_resource_map_.find(cell_id); + if (ic == cell_resource_map_.end()) { + MS_LOG(EXCEPTION) << "Can not find resource"; } + MS_EXCEPTION_IF_NULL(ic->second); + resource_ = ic->second; + + auto it = df_builder_map_.find(cell_id); + if (it == df_builder_map_.end()) { + MS_LOG(EXCEPTION) << "Can not find df_builder"; + } + MS_EXCEPTION_IF_NULL(it->second.first); + df_builder_ = it->second.first; + + top_g_ = cell_graph_map_[cell_id].first; + if (sens_weights_changed.first) { + MS_LOG(INFO) << "Sens changed, no need reset df_builder params"; + return; + } + + std::vector new_params; + for (size_t i = 0; i < size; i++) { + ParameterPtr p = std::make_shared(df_builder_); + new_params.emplace_back(p); + } + MS_LOG(DEBUG) << "GradNet weight param size " << df_builder_->parameters().size(); + // df_builder_->parameters() set in GetInput, which are weights params + new_params.insert(new_params.end(), df_builder_->parameters().begin(), df_builder_->parameters().end()); + df_builder_->set_parameters(new_params); + resource_->manager()->SetParameters(df_builder_, new_params); } std::vector PynativeExecutor::GetWeightsArgs(const py::object &weights) { std::vector w_args; - if (py::hasattr(weights, "__parameter_tuple__")) { - auto tuple = weights.cast(); - MS_LOG(DEBUG) << "GradNet start weights tuple size" << tuple.size(); - w_args.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (size_t it = 0; it < tuple.size(); ++it) { - auto param = tuple[it]; - auto param_id = GetId(param); - AnfNodePtr para_node = nullptr; - if (graph_info_map_[df_builder_].param_map.count(param_id)) { - para_node = graph_info_map_[df_builder_].param_map[param_id].first; - } else { - auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(param, "name"); - if (py::isinstance(name_attr)) { - MS_LOG(EXCEPTION) << "Parameter object should have name attribute"; - } - auto param_name = py::cast(name_attr); - auto free_param = df_builder_->add_parameter(); - free_param->set_name(param_name); - auto value = py::cast(param); - free_param->set_default_param(value); - free_param->debug_info()->set_name(param_name); - para_node = free_param; + if (!py::hasattr(weights, "__parameter_tuple__")) { + MS_LOG(DEBUG) << "No paramter_tuple get"; + return {}; + } + auto tuple = weights.cast(); + MS_LOG(DEBUG) << "GradNet start weights tuple size " << tuple.size(); + w_args.emplace_back(NewValueNode(prim::kPrimMakeTuple)); + for (size_t it = 0; it < tuple.size(); ++it) { + auto param = tuple[it]; + auto param_id = GetId(param); + AnfNodePtr para_node = nullptr; + if (graph_info_map_[df_builder_].params.find(param_id) != graph_info_map_[df_builder_].params.end() && + graph_info_map_[df_builder_].node_map.find(param_id) != graph_info_map_[df_builder_].node_map.end()) { + para_node = graph_info_map_[df_builder_].node_map[param_id].first; + } else { + auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(param, "name"); + if (py::isinstance(name_attr)) { + MS_LOG(EXCEPTION) << "Parameter object should have name attribute"; } - w_args.push_back(para_node); + auto param_name = py::cast(name_attr); + auto free_param = df_builder_->add_parameter(); + free_param->set_name(param_name); + auto value = py::cast(param); + free_param->set_default_param(value); + free_param->debug_info()->set_name(param_name); + para_node = free_param; } - } else { - MS_LOG(DEBUG) << "training not paramter_tuple"; + w_args.emplace_back(para_node); } return w_args; } @@ -1478,6 +1735,7 @@ std::vector PynativeExecutor::GetWeightsArgs(const py::object &weigh abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args) { abstract::AbstractBasePtrList args_spec; std::size_t size = args.size(); + // input params for (std::size_t i = 0; i < size; i++) { ValuePtr converted = nullptr; bool succ = parse::ConvertData(args[i], &converted); @@ -1486,110 +1744,91 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args } bool broaden = true; auto abs = abstract::FromValue(converted, broaden); - args_spec.push_back(abs); + args_spec.emplace_back(abs); auto param_node = std::static_pointer_cast(df_builder_->parameters()[i]); param_node->set_abstract(abs); } - + // weights params for (const auto ¶m : df_builder_->parameters()) { auto param_node = std::static_pointer_cast(param); if (param_node->has_default()) { ValuePtr value = param_node->default_param(); auto ptr = value->ToAbstract(); - if (ptr == nullptr) { - MS_LOG(EXCEPTION) << "Args convert error"; - } - args_spec.push_back(ptr); + MS_EXCEPTION_IF_NULL(ptr); + args_spec.emplace_back(ptr); param_node->set_abstract(ptr); } } - return args_spec; } -void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, - const py::args &args) { - MS_LOG(INFO) << "GradNet start" << args.size(); - std::size_t size = args.size(); - std::string cell_id = GetCellId(cell, args); - if (graph_map_.count(cell_id) != 0) { - MS_LOG(DEBUG) << "GradNet already compiled"; - return; - } - size_t forward_args_count = args.size(); - if (grad->sens_param()) { - forward_args_count = forward_args_count - 1; - } - py::tuple forward_args(forward_args_count); - for (size_t i = 0; i < forward_args_count; i++) { - forward_args[i] = args[i]; - } - std::string forward_cell_id = GetCellId(cell, forward_args); - MS_LOG(DEBUG) << "Forward cell_id:" << forward_cell_id; - if (df_builder_map_.find(forward_cell_id) == df_builder_map_.end()) { - MS_LOG(EXCEPTION) << "Cannot find df builder"; - } - df_builder_ = df_builder_map_[forward_cell_id]; - if (df_builder_ == nullptr) { - MS_LOG(EXCEPTION) << "Got unexpected null df builder"; - } +void PynativeExecutor::GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op, + const std::vector &weights, size_t arg_size) { + auto nparam = top_g_->parameters().size(); + std::ostringstream ss; + ss << "grad{" << nparam << "}"; + df_builder_->set_flag(FUNC_GRAPH_FLAG_CORE, true); + df_builder_->debug_info()->set_name(ss.str()); - if (cell_resource_map_.find(forward_cell_id) == cell_resource_map_.end()) { - MS_LOG(EXCEPTION) << "Cannot find resource for " << forward_cell_id; + auto df = grad_op->GetGrad(NewValueNode(g), nullptr, top_g_->parameters(), weights); + std::vector inputs = {NewValueNode(df)}; + for (size_t i = 0; i < arg_size; ++i) { + inputs.emplace_back(df_builder_->parameters()[i]); } - MS_LOG(DEBUG) << "GradNet first compiled"; - resource_ = cell_resource_map_[forward_cell_id]; + auto out = df_builder_->NewCNode(inputs); + df_builder_->set_output(out); + resource_->manager()->AddFuncGraph(df); + resource_->manager()->AddFuncGraph(df_builder_); +} - std::vector new_params; - for (size_t i = 0; i < size; i++) { - ParameterPtr p = std::make_shared(df_builder_); - new_params.push_back(p); +py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &args) { + BaseRef ret = false; + if (!grad_is_running) { + MS_LOG(DEBUG) << "Grad not running yet"; + return BaseRefToPyData(ret); } - MS_LOG(DEBUG) << "GradNet start weight size" << df_builder_->parameters().size(); - new_params.insert(new_params.end(), df_builder_->parameters().begin(), df_builder_->parameters().end()); - df_builder_->set_parameters(new_params); - resource_->manager()->SetParameters(df_builder_, new_params); - - std::vector w_args = GetWeightsArgs(weights); - MS_EXCEPTION_IF_NULL(resource_->func_graph()); - if (cell_graph_map_.find(forward_cell_id) == cell_graph_map_.end()) { - MS_LOG(EXCEPTION) << "Could not find top graph by cellid: " << forward_cell_id; + auto cell_id = GetCellId(cell, args); + string key = cell_id.substr(0, std::min(PTR_LEN, cell_id.size())); + MS_LOG(DEBUG) << "Key is " << key; + for (auto it = cell_graph_map_.begin(); it != cell_graph_map_.end(); ++it) { + MS_LOG(DEBUG) << "Cur cell id " << it->first; + if (key != it->first.substr(0, std::min(PTR_LEN, it->first.size()))) { + continue; + } + MS_LOG(DEBUG) << "Delete cellid from cell_graph_map_"; + cell_graph_map_.erase(it->first); + ret = true; + break; } - top_g_ = cell_graph_map_[forward_cell_id]; - if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { - DumpIR("before_grad.ir", resource_->func_graph()); + return BaseRefToPyData(ret); +} + +py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase) { + VectorRef arg_list; + py::tuple converted_args = ConvertArgs(args); + pipeline::ProcessVmArgInner(converted_args, resource_, &arg_list); + if (resource_->results().find(pipeline::kOutput) == resource_->results().end()) { + MS_LOG(EXCEPTION) << "Can't find run graph output"; } - auto g = GradGraph(resource_->func_graph(), grad, w_args, size); - if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { - DumpIR("after_grad.ir", g); + if (!resource_->results()[pipeline::kOutput].is()) { + MS_LOG(EXCEPTION) << "Run graph is not VmEvalFuncPtr"; } - resource_->set_func_graph(g); - resource_->manager()->KeepRoots({g}); - - // get the parameters items and add the value to args_spec - abstract::AbstractBasePtrList args_spec = GetArgsSpec(args); - MS_LOG(DEBUG) << "Args_spec size" << args_spec.size(); - - resource_->set_args_spec(args_spec); - MS_LOG(DEBUG) << "Start opt"; - - // Create backend and session - resource_->results()[pipeline::kBackend] = compile::CreateBackend(); + compile::VmEvalFuncPtr run = resource_->results()[pipeline::kOutput].cast(); + MS_EXCEPTION_IF_NULL(run); - graph_map_[cell_id] = g; - PynativeOptimizeAction(resource_); - TaskEmitAction(resource_); - ExecuteAction(resource_); - resource_->Clean(); - ad::CleanRes(); - pipeline::ReclaimOptimizer(); + std::string backend = MsContext::GetInstance()->backend_policy(); + MS_LOG(DEBUG) << "Eval run " << backend; + grad_is_running = true; + BaseRef value = (*run)(arg_list); + grad_is_running = false; + MS_LOG(DEBUG) << "Run end " << value.ToString(); + return BaseRefToPyData(value); } template void MapClear(T *map, const std::string &flag) { for (auto it = map->begin(); it != map->end();) { if (it->first.find(flag) != std::string::npos) { - it->second = nullptr; it = map->erase(it); } else { it++; @@ -1599,11 +1838,10 @@ void MapClear(T *map, const std::string &flag) { void PynativeExecutor::Clear(const std::string &flag) { if (!flag.empty()) { - MS_LOG(DEBUG) << "Clear res"; - MapClear>(&graph_map_, flag); - MapClear>(&cell_graph_map_, flag); + MS_LOG(DEBUG) << "Clear cell res"; + MapClear>>(&cell_graph_map_, flag); MapClear>(&cell_resource_map_, flag); - MapClear>(&df_builder_map_, flag); + MapClear>>(&df_builder_map_, flag); // Maybe exit in the pynative runing op, so need reset pynative flag. auto ms_context = MsContext::GetInstance(); @@ -1639,96 +1877,13 @@ void PynativeExecutor::Clean() { pipeline::ReclaimOptimizer(); } -template -void MapErase(T *map) { - for (auto it = map->begin(); it != map->end();) { - it = map->erase(it++); - } -} - void PynativeExecutor::ClearRes() { - MapErase>(&graph_map_); - MapErase>(&cell_graph_map_); - MapErase>(&cell_resource_map_); - MapErase>(&node_abs_map_); + MS_LOG(DEBUG) << "PynativeExecutor destruct"; + df_builder_map_.clear(); + cell_graph_map_.clear(); + cell_resource_map_.clear(); + node_abs_map_.clear(); Clean(); - resource_.reset(); -} - -size_t GetTupleSize(const py::tuple &args) { - size_t count = 0; - for (size_t i = 0; i < args.size(); i++) { - if (py::isinstance(args[i])) { - count += GetTupleSize(args[i]); - } else { - count += 1; - } - } - return count; -} - -void ConvertTupleArg(py::tuple *res, size_t *index, const py::tuple &arg) { - for (size_t i = 0; i < arg.size(); i++) { - if (py::isinstance(arg[i])) { - ConvertTupleArg(res, index, arg[i]); - } else { - (*res)[(*index)++] = arg[i]; - } - } -} - -py::tuple ConvertArgs(const py::tuple &args) { - size_t tuple_size = GetTupleSize(args); - py::tuple res(tuple_size); - size_t index = 0; - for (size_t i = 0; i < args.size(); i++) { - if (py::isinstance(args[i])) { - ConvertTupleArg(&res, &index, args[i]); - } else { - res[index++] = args[i]; - } - } - return res; -} - -py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase) { - VectorRef arg_list; - py::tuple converted_args = ConvertArgs(args); - pipeline::ProcessVmArgInner(converted_args, resource_, &arg_list); - if (resource_->results().find(pipeline::kOutput) == resource_->results().end() || - !resource_->results()[pipeline::kOutput].is()) { - MS_LOG(EXCEPTION) << "Can't find run graph func for "; - } - compile::VmEvalFuncPtr run = resource_->results()[pipeline::kOutput].cast(); - if (run == nullptr) { - MS_LOG(EXCEPTION) << "Can't find run graph func for "; - } - - std::string backend = MsContext::GetInstance()->backend_policy(); - MS_LOG(DEBUG) << "Eval run" << backend; - BaseRef value = (*run)(arg_list); - MS_LOG(DEBUG) << "Run end" << value.ToString(); - return BaseRefToPyData(value); -} - -FuncGraphPtr PynativeExecutor::GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op, - const std::vector &weights, size_t arg_size) { - auto nparam = top_g_->parameters().size(); - std::ostringstream ss; - ss << "grad{" << nparam << "}"; - df_builder_->set_flag(FUNC_GRAPH_FLAG_CORE, true); - df_builder_->debug_info()->set_name(ss.str()); - - auto df = grad_op->GetGrad(NewValueNode(g), nullptr, top_g_->parameters(), weights); - std::vector inputs = {NewValueNode(df)}; - for (size_t i = 0; i < arg_size; ++i) { - inputs.push_back(df_builder_->parameters()[i]); - } - auto out = df_builder_->NewCNode(inputs); - df_builder_->set_output(out); - resource_->manager()->AddFuncGraph(df); - resource_->manager()->AddFuncGraph(df_builder_); - return df_builder_; } void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) { @@ -1744,86 +1899,12 @@ void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &c PynativeExecutorTry(this, &PynativeExecutor::GradNetInner, grad, cell, weights, args); } -MsBackendPolicy PynativeExecutor::InitEnv(const OpExecInfoPtr &op_exec_info) { - MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name; - mindspore::parse::python_adapter::set_python_env_flag(true); - MsBackendPolicy backend_policy; -#if (!defined ENABLE_GE) - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - if (!context::IsTsdOpened(ms_context)) { - if (!context::OpenTsd(ms_context)) { - MS_LOG(EXCEPTION) << "Open tsd failed"; - } - } - if (ms_context->backend_policy() == "ms") { - backend_policy = kMsBackendMsPrior; - } else { - backend_policy = kMsBackendVmOnly; - } -#else - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - context::PynativeInitGe(ms_context); - backend_policy = kMsBackendGeOnly; -#endif - if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) { - backend_policy = kMsBackendVmOnly; - } - return backend_policy; -} - -void PynativeExecutor::RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_exec_info) { - size_t size = op_exec_info->op_inputs.size(); - auto prim = op_exec_info->py_primitive; - const auto &signature = prim->signatures(); - auto sig_size = signature.size(); - for (size_t i = 0; i < size; i++) { - auto obj = op_exec_info->op_inputs[i]; - auto sig = SignatureEnumRW::kRWDefault; - if (sig_size > 0) { - sig = signature[i].rw; - } - MS_LOG(DEBUG) << "check mix precision " << op_exec_info->op_name << " input " << i << " " - << std::string(py::repr(obj)); - // mix precision for non param - bool is_cast = false; - py::object cast_output; - if (py::isinstance(obj)) { - auto meta_tensor = obj.cast(); - if (meta_tensor && meta_tensor->is_parameter()) { - if (sig != SignatureEnumRW::kRWRead) { - continue; - } - } - // redundant cast call if the tensor is a const Tensor. - cast_output = DoParamMixPrecisionCast(&is_cast, obj); - } else if (py::isinstance(obj) || py::isinstance(obj)) { - // mix precision for tuple inputs - cast_output = DoParamMixPrecisionCastTuple(&is_cast, obj); - } - if (is_cast) { - op_exec_info->op_inputs[i] = cast_output; - } - } - std::vector dtypes; - - bool has_dtype_sig = GetSignatureType(prim, &dtypes); - std::map dst_types; - if (has_dtype_sig) { - // fetch info for implicit cast - auto type_indexes = GetTypeIndex(dtypes); - dst_types = GetDstType(op_exec_info->op_inputs, type_indexes); - } - MS_LOG(DEBUG) << "do signature for " << op_exec_info->op_name; - DoSignatrueCast(prim, dst_types, dtypes, op_exec_info); -} - REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) { (void)py::class_>(*m, "PynativeExecutor_") .def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.") .def("new_graph", &PynativeExecutor::NewGraph, "pynative new a graph.") .def("end_graph", &PynativeExecutor::EndGraph, "pynative end a graph.") + .def("check_graph", &PynativeExecutor::CheckGraph, "pynative check a grad graph.") .def("grad_net", &PynativeExecutor::GradNet, "pynative grad graph.") .def("clear", &PynativeExecutor::Clear, "pynative clear status.") .def("__call__", &PynativeExecutor::Run, py::arg("args"), py::arg("phase") = py::str(""), diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index c235265533..ac55fad774 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -61,113 +62,134 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tupl void ClearPyNativeSession(); struct GraphInfo { - std::unordered_map>> param_map; - std::unordered_map>> obj_node_map; + std::unordered_set params; // hold inpout parameters and cell weigths + std::unordered_map>> node_map; AnfNodePtr output; std::vector objects; }; class PynativeExecutor : public std::enable_shared_from_this { - private: - MsBackendPolicy InitEnv(const OpExecInfoPtr &op_exec_info); - py::tuple RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info); - AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, - abstract::AbstractBasePtrList *args_spec_list); - void RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_exec_info); - public: static std::shared_ptr GetInstance() { std::lock_guard i_lock(instance_lock_); if (executor_ == nullptr) { executor_ = std::shared_ptr(new (std::nothrow) PynativeExecutor()); - resource_ = std::make_shared(); } return executor_; } + ~PynativeExecutor(); + + bool grad_flag() { return grad_flag_; } + void set_grad_flag(bool flag) { grad_flag_ = flag; } + + py::tuple RunOpInner(const py::args &args); void NewGraph(const py::object &cell, const py::args &args); - void NewGraphInner(const py::object &cell, const py::args &args); + py::object Run(const py::tuple &args, const py::object &phase); + py::object CheckGraph(const py::object &cell, const py::args &args); void EndGraph(const py::object &cell, const py::object &out, const py::args &args); - void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args); - void EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out, const py::args &args); - std::vector GetWeightsArgs(const py::object &weights); - abstract::AbstractBasePtrList GetArgsSpec(const py::args &args); void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args); - void GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, - const py::args &args); + void SaveOpForwardValue(const std::string &id, const ValuePtr &value, + std::map *t_map); + + // Call by python void Clear(const std::string &flag = ""); + // Abnormal existed void Clean(); + // Destrcut call void ClearRes(); - bool grad_flag() { return grad_flag_; } - void set_grad_flag(bool flag) { grad_flag_ = flag; } + + private: + PynativeExecutor() = default; + PynativeExecutor(const PynativeExecutor &) = delete; + PynativeExecutor &operator=(const PynativeExecutor &) = delete; + + // run op AnfNodePtr GetInput(const py::object &obj, bool op_mask); - AnfNodePtr GetObjNode(const py::object &obj); - AnfNodePtr GetParamNode(const py::object &obj); - std::string GetCellId(const py::object &obj, const py::args &args); - FuncGraphPtr curr_g() { return curr_g_; } - void set_pyobj(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); } - void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node) { - graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, std::vector{-1}); - } - void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, int index) { - graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, std::vector{index}); - } - void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, std::vector index) { - graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index); - } + MsBackendPolicy InitEnv(const OpExecInfoPtr &op_exec_info); + py::tuple RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info); + void RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_exec_info); + py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); + py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info, + PynativeStatusCode *const status); + AnfNodePtr GetObjNode(const py::object &obj, const std::string &obj_id); + AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id); + AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, + abstract::AbstractBasePtrList *args_spec_list); + void SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const AnfNodePtr &cnode); - void set_param_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node) { - graph_info_map_[g].param_map[obj] = std::make_pair(node, std::vector{-1}); - } - void set_param_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, int index) { - graph_info_map_[g].param_map[obj] = std::make_pair(node, std::vector{index}); - } - void set_param_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, std::vector index) { - graph_info_map_[g].param_map[obj] = std::make_pair(node, index); - } - void MakeCNode(const OpExecInfoPtr &op_exec_info, const py::object &out, const AnfNodePtr &cnode); + // replace for grad graph + ValuePtr CleanTupleAddr(const ValueTuplePtr &tuple); ValuePtr GetForwardValue(const OpExecInfoPtr &op_exec_info); - void SaveOpForwardValue(const std::string &id, const ValuePtr &value, - std::map *t_map); void SaveForwardResult(const CNodePtr &cnode, const py::object &out); + void GenTupleMap(const ValueTuplePtr &tuple, std::map *t_map); void SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out); - py::object Run(const py::tuple &args, const py::object &phase); - + // construct grad graph void Pushp(); void Popp(); - FuncGraphPtr GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op, const std::vector &weights, - size_t arg_size); - void SetTupleOutput(const py::object &obj, const AnfNodePtr &cnode, std::vector idx); - void SetTupleParam(const py::object &obj, const AnfNodePtr ¶_node, std::vector idx); - AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id); - py::tuple RunOpInner(const py::args &args); + void NewGraphInner(const py::object &cell, const py::args &args); + void MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g); + void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args); + void EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out, const py::args &args); + FuncGraphPtr MakeGradGraph(const py::object &cell, const py::args &args); + void GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, + const py::args &args); + std::string GetCellId(const py::object &obj, const py::args &args); + std::string CheckCellChanged(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, + const py::args &args, std::pair *sens_weights_changed); + void SetGradGraphParams(size_t size, const std::string &cell_id, const std::pair &sens_weights_changed); + void GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op, const std::vector &weights, + size_t arg_size); + std::vector GetWeightsArgs(const py::object &weights); + abstract::AbstractBasePtrList GetArgsSpec(const py::args &args); - ~PynativeExecutor(); + // hold graph(forward and grad) info + void set_pyobj(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); } + void set_node_map(const FuncGraphPtr &g, const py::object &node, const AnfNodePtr &cnode, bool is_param = false); + void set_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node) { + graph_info_map_[g].node_map[obj] = std::make_pair(node, std::vector{-1}); + } + void set_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, int index) { + graph_info_map_[g].node_map[obj] = std::make_pair(node, std::vector{index}); + } + void set_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, std::vector index) { + graph_info_map_[g].node_map[obj] = std::make_pair(node, index); + } + void set_tuple_node_map(const FuncGraphPtr &g, const py::object &node, const AnfNodePtr &cnode, + const std::vector &idx, bool is_param = false); - private: - PynativeExecutor(); static std::shared_ptr executor_; static std::mutex instance_lock_; - static ResourcePtr resource_; static int graph_id_; - bool grad_flag_; - bool first_grad_step_; - std::unordered_map graph_map_; - std::unordered_map cell_graph_map_; - std::unordered_map cell_resource_map_; + bool grad_flag_{false}; + bool first_grad_step_{false}; + bool grad_is_running{false}; + bool dynamic_shape{false}; + + // Used for construct grad graph + FuncGraphPtr top_g_{nullptr}; + FuncGraphPtr curr_g_{nullptr}; + FuncGraphPtr df_builder_{nullptr}; + ResourcePtr resource_{nullptr}; + // Records forwrad graph, the bottom is top graph + std::stack graph_context_; + std::unordered_set top_graph_cells_; + + // record all info of a graph std::unordered_map graph_info_map_; + std::unordered_map cell_resource_map_; + std::unordered_map> cell_graph_map_; + // key: cell_id, value: (send_id, weigths_id), cache for sens and weight change + std::unordered_map> cell_sw_map_; + // key: cell_id, value: (forward graph, grad graph) + std::unordered_map> df_builder_map_; + + // used for runop and replace forward result of grad graph std::unordered_map op_forward_map_; std::unordered_map op_id_map_; std::unordered_map obj_to_forward_id_; std::unordered_map node_abs_map_; - std::unordered_map df_builder_map_; - // the stack that records the context of graph created, the bottom is the top graph - std::stack graph_context_; - FuncGraphPtr top_g_; - FuncGraphPtr df_builder_; - FuncGraphPtr curr_g_; std::unordered_map prim_abs_list_; - std::set top_graph_cells_; }; using PynativeExecutorPtr = std::shared_ptr; diff --git a/mindspore/common/api.py b/mindspore/common/api.py index 6643993614..1ae058e1a6 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -298,6 +298,9 @@ class _PynativeExecutor: def end_graph(self, obj, output, *args, **kwargs): self._executor.end_graph(obj, output, *args, *(kwargs.values())) + def check_graph(self, obj, *args, **kwargs): + return self._executor.check_graph(obj, *args, *(kwargs.values())) + def grad(self, grad, obj, weights, *args, **kwargs): self._executor.grad_net(grad, obj, weights, *args, *(kwargs.values())) diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 5c944a7533..65e85b8890 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -244,7 +244,8 @@ class Cell(Cell_): raise AttributeError("'{}' object has no attribute '{}'.".format(type(self).__name__, name)) def __del__(self): - _pynative_exec.clear(str(id(self))) + if context.get_context("mode") == context.PYNATIVE_MODE: + _pynative_exec.clear(str(id(self))) if hasattr(self, "_create_time"): _executor.del_net_res(str(self._create_time)) diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py index 11e670909b..8ce301dc3e 100644 --- a/mindspore/ops/composite/base.py +++ b/mindspore/ops/composite/base.py @@ -337,6 +337,9 @@ class GradOperation(GradOperation_): else: @_wrap_func def after_grad(*args, **kwargs): + if _pynative_exec.check_graph(fn, *args, **kwargs): + print("Another grad step is running") + fn.already_run = False self._pynative_forward_run(args, kwargs, fn) _pynative_exec.grad(grad_, fn, weights, *args, **kwargs) out = _pynative_exec(*args, **kwargs) diff --git a/tests/ut/python/pynative_mode/test_stop_gradient.py b/tests/ut/python/pynative_mode/test_stop_gradient.py index 57adfb6fa4..b045abdcbc 100644 --- a/tests/ut/python/pynative_mode/test_stop_gradient.py +++ b/tests/ut/python/pynative_mode/test_stop_gradient.py @@ -260,7 +260,7 @@ def test_stop_gradient_4(): def stop_test(x): return stop_gradient(x) - assert grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (1,) + assert grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (0,) def test_stop_gradient_5():