From bb6148661f3122271709af019fe566a130806b2d Mon Sep 17 00:00:00 2001
From: LianLiguang
Date: Tue, 17 Nov 2020 16:47:08 +0800
Subject: [PATCH] change mixed precision of pynative
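
In PyNative mode every operator runs as its own single-op graph, so when
auto mixed precision inserts a Cast in front of an operator, kernel
selection for that Cast cannot see the consuming node and may pick a
mismatched device format. This change records the consumer on the Cast
itself: the executor tags the generated cast with next_op_name /
next_input_index, ConstructSingleOpGraph copies the hints into the node
attrs kAttrPynativeNextOpName / kAttrPynativeNextIndex, and Ascend kernel
selection uses them to force a matching input/output format.

Intended flow, sketched from the hunks below (the concrete values are
illustrative examples, not part of the applied change):

    // 1. PyNative emits the cast with consumer hints:
    op_exec->is_mixed_precision_cast = true;
    op_exec->next_op_name = prim::kPrimConv2D->name();  // example consumer
    op_exec->next_input_index = 1;                      // e.g. the weight input
    // 2. ConstructSingleOpGraph copies the hints into node attrs.
    // 3. SetCastAndWeightFormat looks up
    //        kNextOpFormatList[next_op_name][next_input_index]
    //    and sets that format (e.g. FRAC_Z) as the Cast's in/out format.
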
---
 .../kernel_compiler/kernel_build_info.h       |   2 +-
 .../ccsrc/backend/session/session_basic.cc    |   4 +
 .../ccsrc/backend/session/session_basic.h     |   3 +
 mindspore/ccsrc/pipeline/pynative/base.h      |   3 +
 .../pipeline/pynative/pynative_execute.cc     | 365 +++++++++---------
 .../pipeline/pynative/pynative_execute.h      |  36 +-
 .../device/ascend/kernel_select_ascend.cc     |  58 ++-
 .../device/ascend/kernel_select_ascend.h      |   2 +-
 .../ascend/kernel_select_graph_kernel.cc      |   4 +-
 mindspore/ccsrc/utils/utils.h                 |   2 +
 10 files changed, 267 insertions(+), 212 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h
index 6c4ee24323..aa8496196b 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h
@@ -124,7 +124,7 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
     SetKernelType(kernel_build_info->kernel_type());
     SetFusionType(kernel_build_info->fusion_type());
     SetProcessor(kernel_build_info->processor());
-    OpPattern(kernel_build_info->op_pattern());
+    SetOpPattern(kernel_build_info->op_pattern());
     for (size_t index = 0; index < kernel_build_info->GetInputNum(); ++index) {
       kernel_build_info_->inputs_device_type_.emplace_back(kernel_build_info->GetInputDeviceType(index));
       kernel_build_info_->inputs_format_.emplace_back(kernel_build_info->GetInputFormat(index));
diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc
index da25eb0b10..8a5ae070a6 100644
--- a/mindspore/ccsrc/backend/session/session_basic.cc
+++ b/mindspore/ccsrc/backend/session/session_basic.cc
@@ -1396,6 +1396,10 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info,
   cnode->set_abstract(op_run_info.abstract);
   // get output dynamic shape info
   AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(op_run_info.is_dynamic_shape), cnode);
+  if (op_run_info.is_auto_mixed_precision) {
+    AnfAlgo::SetNodeAttr(kAttrPynativeNextOpName, MakeValue(op_run_info.next_op_name), cnode);
+    AnfAlgo::SetNodeAttr(kAttrPynativeNextIndex, MakeValue(op_run_info.next_input_index), cnode);
+  }
   // set execution order
   std::vector<CNodePtr> exe_order = {cnode};
   graph->set_execution_order(exe_order);
diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h
index 9553d2fcd7..4d83f7abb3 100644
--- a/mindspore/ccsrc/backend/session/session_basic.h
+++ b/mindspore/ccsrc/backend/session/session_basic.h
@@ -51,6 +51,9 @@ struct OpRunInfo {
   AbstractBasePtr abstract;
   ValuePtr value = nullptr;
   bool is_dynamic_shape = false;
+  bool is_auto_mixed_precision = false;
+  std::string next_op_name = "";
+  size_t next_input_index = 0;
 };
 using OpRunInfoPtr = std::shared_ptr<OpRunInfo>;
 class Executor;
diff --git a/mindspore/ccsrc/pipeline/pynative/base.h b/mindspore/ccsrc/pipeline/pynative/base.h
index 099588f974..0213aca229 100644
--- a/mindspore/ccsrc/pipeline/pynative/base.h
+++ b/mindspore/ccsrc/pipeline/pynative/base.h
@@ -60,6 +60,9 @@ struct OpExecInfo {
   py::list op_inputs;
   py::dict op_attrs;
   std::vector<int64_t> inputs_mask;
+  std::string next_op_name = "";
+  bool is_mixed_precision_cast = false;
+  size_t next_input_index = 0;
 };
 using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;
 OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
index b1c84bd86f..2313cd421b 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
@@ -235,58 +235,6 @@ std::string TypeIdToMsTypeStr(const TypeId &type_id) {
   return type_name->second;
 }
-py::object DoAutoCast(const py::object &arg, const TypeId &type_id) {
-  py::tuple args(3);
-  std::string module_name = "mindspore.ops.functional";
-  std::string op_name = "cast";
-  args[0] = parse::python_adapter::GetPyFn(module_name, op_name);
-  args[1] = "Cast";
-
-  std::string dst_type_str = TypeIdToMsTypeStr(type_id);
-  module_name = "mindspore.common.dtype";
-  py::object dst_type = parse::python_adapter::GetPyFn(module_name, dst_type_str);
-  py::tuple inputs(2);
-  inputs[0] = arg;
-  inputs[1] = dst_type;
-  args[2] = inputs;
-
-  return RunOp(args)[0];
-}
-
-py::object DoParamMixPrecisionCast(bool *is_cast, const py::object obj) {
-  MS_EXCEPTION_IF_NULL(is_cast);
-  auto tensor = py::cast<tensor::TensorPtr>(obj);
-  auto cast_type = tensor->cast_dtype();
-  py::object cast_output = obj;
-  if (cast_type != nullptr) {
-    auto source_element = tensor->Dtype();
-    if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
-      MS_LOG(DEBUG) << "Cast to " << cast_type->ToString();
-      cast_output = DoAutoCast(obj, cast_type->type_id());
-      *is_cast = true;
-    }
-  }
-  return cast_output;
-}
-
-py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple) {
-  MS_EXCEPTION_IF_NULL(is_cast);
-  auto tuple_size = static_cast<int64_t>(tuple.size());
-  py::tuple result(tuple_size);
-
-  for (int64_t i = 0; i < tuple_size; i++) {
-    if (py::isinstance<tensor::Tensor>(tuple[i])) {
-      MS_LOG(DEBUG) << "Call cast for item " << i;
-      result[i] = DoParamMixPrecisionCast(is_cast, tuple[i]);
-    } else if (py::isinstance<py::tuple>(tuple[i]) || py::isinstance<py::list>(tuple[i])) {
-      result[i] = DoParamMixPrecisionCastTuple(is_cast, tuple[i]);
-    } else {
-      result[i] = tuple[i];
-    }
-  }
-  return std::move(result);
-}
-
 bool GetSignatureType(const PrimitivePyPtr &prim, std::vector<SignatureEnumDType> *dtypes) {
   MS_EXCEPTION_IF_NULL(dtypes);
   auto signature = prim->signatures();
@@ -302,69 +250,6 @@ bool GetSignatureType(const PrimitivePyPtr &prim, std::vector<SignatureEnumDType> *dtypes) {
-void DoSignatrueCast(const PrimitivePyPtr &prim, const std::map<SignatureEnumDType, TypeId> &dst_type,
-                     const std::vector<SignatureEnumDType> &dtypes, const OpExecInfoPtr &op_exec_info) {
-  const auto &signature = prim->signatures();
-  auto &out_args = op_exec_info->op_inputs;
-  bool has_dtype_sig = !dtypes.empty();
-  for (size_t i = 0; i < out_args.size(); ++i) {
-    MS_LOG(DEBUG) << "Check inputs " << i;
-    auto obj = out_args[i];
-    auto sig = SignatureEnumRW::kRWDefault;
-    if (!signature.empty()) {
-      sig = signature[i].rw;
-    }
-    bool is_parameter = false;
-    TypeId arg_type_id = kTypeUnknown;
-    if (py::isinstance<tensor::MetaTensor>(obj)) {
-      auto arg = py::cast<tensor::MetaTensorPtr>(obj);
-      if (arg->is_parameter()) {
-        is_parameter = true;
-        MS_LOG(DEBUG) << "Parameter is read " << i;
-      }
-      arg_type_id = arg->data_type();
-    }
-
-    // No need to implicit cast if no dtype.
-    if (!has_dtype_sig || dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) {
-      continue;
-    }
-    auto it = dst_type.find(dtypes[i]);
-    if (it == dst_type.end() || it->second == kTypeUnknown) {
-      continue;
-    }
-    // implicit cast
-    bool is_same_type = false;
-    bool is_sig_write = (sig == SignatureEnumRW::kRWWrite);
-    if (arg_type_id != 0) {
-      is_same_type = (prim::type_map.find(arg_type_id) == prim::type_map.end() || arg_type_id == it->second);
-    }
-    if (is_sig_write) {
-      if (!is_parameter) {
-        prim::RaiseExceptionForCheckParameter(prim->name(), i, "not");
-      }
-      if (arg_type_id != 0) {
-        if (!is_same_type) {
-          prim::RaiseExceptionForConvertRefDtype(prim->name(), TypeIdToMsTypeStr(arg_type_id),
-                                                 TypeIdToMsTypeStr(it->second));
-        }
-      }
-    }
-    if (is_same_type) {
-      continue;
-    }
-
-    if (!py::isinstance<tensor::Tensor>(obj) && !py::isinstance<py::int_>(obj) && !py::isinstance<py::float_>(obj)) {
-      MS_EXCEPTION(TypeError) << "For '" << prim->name() << "', the " << i
-                              << "th input is a not support implicit conversion type: "
-                              << py::cast<std::string>(obj.attr("__class__").attr("__name__")) << ", and the value is "
-                              << py::cast<py::str>(obj) << ".";
-    }
-    py::object cast_output = DoAutoCast(out_args[i], it->second);
-    out_args[i] = cast_output;
-  }
-}
-
 void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info,
                    const abstract::AbstractBasePtrList &args_spec_list) {
   MS_LOG(DEBUG) << "Prim " << prim->name() << " input infer " << mindspore::ToString(args_spec_list);
@@ -694,8 +579,10 @@ PynativeExecutor::~PynativeExecutor() { ClearRes(); }
 py::tuple RunOp(const py::args &args) {
   auto executor = PynativeExecutor::GetInstance();
   MS_EXCEPTION_IF_NULL(executor);
+  MS_LOG(DEBUG) << "RunOp start " << args.size();
+  OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args);
   try {
-    return executor->RunOpInner(args);
+    return executor->RunOpInner(op_exec_info);
   } catch (const py::error_already_set &ex) {
     executor->Clean();
     // re-throw this exception to Python interpreter to handle it
@@ -720,12 +607,9 @@ py::tuple RunOp(const py::args &args) {
   }
 }
-py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
-  MS_LOG(DEBUG) << "RunOp start " << args.size();
-  OpExecInfoPtr op_exec_info = nullptr;
-  auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
-  auto name = py::cast<std::string>(args[PY_NAME]);
-  op_exec_info = GenerateOpExecInfo(args);
+py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
+  auto prim = op_exec_info->py_primitive;
+  auto name = op_exec_info->op_name;
   if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) {
     return RunOpWithInitBackendPolicy(op_exec_info);
   }
@@ -828,8 +712,7 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
     MS_EXCEPTION(ValueError) << op_exec_info->op_name << " inputs size " << size << " does not match the requires "
                              << "inputs size " << sig_size;
   }
-  bool is_cast_op = (op_exec_info->op_name == "Cast");
-  if (!is_cast_op) {
+  if (op_exec_info->op_name != prim::kPrimCast->name()) {
     RunParameterAutoMixPrecisionCast(op_exec_info);
   }
   MS_LOG(DEBUG) << "Make cnode for " << op_exec_info->op_name;
@@ -846,7 +729,7 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
     MS_LOG(DEBUG) << "Gen args i " << i << " " << op_exec_info->op_name << " op mask " << op_mask << " grad_flag_ "
                   << grad_flag_;
-    AnfNodePtr node = nullptr;
+    AnfNodePtr input_node = nullptr;
     abstract::AbstractBasePtr abs = nullptr;
     auto id = GetId(obj);
     auto it = node_abs_map_.find(id);
@@ -854,11 +737,11 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
       abs = it->second;
     }
     if (!graph_info_map_.empty()) {
-      node = GetInput(obj, op_mask);
+      input_node = GetInput(obj, op_mask);
     }
     // update abstract
-    if (node != nullptr && node->abstract() != nullptr) {
-      abs = node->abstract();
+    if (input_node != nullptr && input_node->abstract() != nullptr) {
+      abs = input_node->abstract();
     }
     auto const_input_index = prim->get_const_input_indexes();
@@ -880,8 +763,8 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
       node_abs_map_[id] = abs;
     }
     (*args_spec_list).emplace_back(abs);
-    if (node != nullptr) {
-      inputs.emplace_back(node);
+    if (input_node != nullptr) {
+      inputs.emplace_back(input_node);
     }
   }
@@ -893,9 +776,125 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
   return cnode;
 }
+py::object PynativeExecutor::DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name,
+                                        size_t index) {
+  py::tuple cast_args(3);
+  cast_args[PY_PRIM] = parse::python_adapter::GetPyFn(kOpsFunctionModelName, "cast");
+  cast_args[PY_NAME] = prim::kPrimCast->name();
+  std::string dst_type_str = TypeIdToMsTypeStr(type_id);
+  py::object dst_type = parse::python_adapter::GetPyFn(kMSDtypeModelName, dst_type_str);
+  py::tuple inputs(2);
+  inputs[0] = arg;
+  inputs[1] = dst_type;
+  cast_args[PY_INPUTS] = inputs;
+  auto op_exec = GenerateOpExecInfo(cast_args);
+  op_exec->is_mixed_precision_cast = true;
+  op_exec->next_op_name = op_name;
+  op_exec->next_input_index = index;
+  return RunOpInner(op_exec)[0];
+}
+
+py::object PynativeExecutor::DoParamMixPrecisionCast(bool *is_cast, const py::object obj, const std::string &op_name,
+                                                     size_t index) {
+  MS_EXCEPTION_IF_NULL(is_cast);
+  auto tensor = py::cast<tensor::TensorPtr>(obj);
+  auto cast_type = tensor->cast_dtype();
+  py::object cast_output = obj;
+  if (cast_type != nullptr) {
+    auto source_element = tensor->Dtype();
+    if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
+      MS_LOG(DEBUG) << "Cast to " << cast_type->ToString();
+      *is_cast = true;
+      return DoAutoCast(obj, cast_type->type_id(), op_name, index);
+    }
+  }
+  return cast_output;
+}
+
+py::object PynativeExecutor::DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple,
+                                                          const std::string &op_name, size_t index) {
+  MS_EXCEPTION_IF_NULL(is_cast);
+  auto tuple_size = static_cast<int64_t>(tuple.size());
+  py::tuple result(tuple_size);
+
+  for (int64_t i = 0; i < tuple_size; i++) {
+    if (py::isinstance<tensor::Tensor>(tuple[i])) {
+      MS_LOG(DEBUG) << "Call cast for item " << i;
+      result[i] = DoParamMixPrecisionCast(is_cast, tuple[i], op_name, index);
+    } else if (py::isinstance<py::tuple>(tuple[i]) || py::isinstance<py::list>(tuple[i])) {
+      result[i] = DoParamMixPrecisionCastTuple(is_cast, tuple[i], op_name, index);
+    } else {
+      result[i] = tuple[i];
+    }
+  }
+  return std::move(result);
+}
+
+void PynativeExecutor::DoSignatrueCast(const PrimitivePyPtr &prim, const std::map<SignatureEnumDType, TypeId> &dst_type,
+                                       const std::vector<SignatureEnumDType> &dtypes,
+                                       const OpExecInfoPtr &op_exec_info) {
+  const auto &signature = prim->signatures();
+  auto &out_args = op_exec_info->op_inputs;
+  for (size_t i = 0; i < out_args.size(); ++i) {
+    // No need to implicit cast if no dtype.
+    if (dtypes.empty() || dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) {
+      continue;
+    }
+    auto it = dst_type.find(dtypes[i]);
+    if (it == dst_type.end() || it->second == kTypeUnknown) {
+      continue;
+    }
+    MS_LOG(DEBUG) << "Check inputs " << i;
+    auto obj = out_args[i];
+    auto sig = SignatureEnumRW::kRWDefault;
+    if (!signature.empty()) {
+      sig = signature[i].rw;
+    }
+    bool is_parameter = false;
+    TypeId arg_type_id = kTypeUnknown;
+    if (py::isinstance<tensor::MetaTensor>(obj)) {
+      auto arg = py::cast<tensor::MetaTensorPtr>(obj);
+      if (arg->is_parameter()) {
+        is_parameter = true;
+        MS_LOG(DEBUG) << "Parameter is read " << i;
+      }
+      arg_type_id = arg->data_type();
+    }
+    // implicit cast
+    bool is_same_type = false;
+    if (arg_type_id != kTypeUnknown) {
+      is_same_type = (prim::type_map.find(arg_type_id) == prim::type_map.end() || arg_type_id == it->second);
+    }
+    if (sig == SignatureEnumRW::kRWWrite) {
+      if (!is_parameter) {
+        prim::RaiseExceptionForCheckParameter(prim->name(), i, "not");
+      }
+      if (arg_type_id != kTypeUnknown) {
+        if (!is_same_type) {
+          prim::RaiseExceptionForConvertRefDtype(prim->name(), TypeIdToMsTypeStr(arg_type_id),
+                                                 TypeIdToMsTypeStr(it->second));
+        }
+      }
+    }
+    if (is_same_type) {
+      continue;
+    }
+
+    if (!py::isinstance<tensor::Tensor>(obj) && !py::isinstance<py::int_>(obj) && !py::isinstance<py::float_>(obj)) {
+      MS_EXCEPTION(TypeError) << "For '" << prim->name() << "', the " << i
+                              << "th input is of a type that does not support implicit conversion: "
+                              << py::cast<std::string>(obj.attr("__class__").attr("__name__")) << ", and the value is "
+                              << py::cast<py::str>(obj) << ".";
    }
+    py::object cast_output = DoAutoCast(out_args[i], it->second, op_exec_info->op_name, i);
+    out_args[i] = cast_output;
+  }
+}
+
 void PynativeExecutor::RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_exec_info) {
   size_t size = op_exec_info->op_inputs.size();
   auto prim = op_exec_info->py_primitive;
+  MS_EXCEPTION_IF_NULL(prim);
   const auto &signature = prim->signatures();
   for (size_t i = 0; i < size; i++) {
     auto obj = op_exec_info->op_inputs[i];
@@ -916,10 +915,10 @@ void PynativeExecutor::RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_exec_info) {
       }
     }
     // redundant cast call if the tensor is a const Tensor.
-      cast_output = DoParamMixPrecisionCast(&is_cast, obj);
+      cast_output = DoParamMixPrecisionCast(&is_cast, obj, prim->name(), i);
     } else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
       // mix precision for tuple inputs
-      cast_output = DoParamMixPrecisionCastTuple(&is_cast, obj);
+      cast_output = DoParamMixPrecisionCastTuple(&is_cast, obj, prim->name(), i);
     }
     if (is_cast) {
       op_exec_info->op_inputs[i] = cast_output;
@@ -958,7 +957,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
         free_param->set_default_param(value);
         MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
         graph_info_map_[df_builder_].params.emplace(obj_id);
-        set_node_map(df_builder_, obj_id, free_param);
+        SetNodeMapInGraphInfoMap(df_builder_, obj_id, free_param);
         return free_param;
       }
       return graph_info_map_[df_builder_].node_map[obj_id].first;
@@ -969,7 +968,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
     // out = op(op1(x, y))
     // out = op(cell1(x, y))
     // out = op(cell1(x, y)[0])
-    node = GetObjNode(obj, obj_id);
+    return GetObjNode(obj, obj_id);
   } else if (py::isinstance<py::tuple>(obj)) {
     // out = op((x, y))
     // out = cell((x, y))
@@ -985,7 +984,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
       args.emplace_back(GetInput(tuple[i], false));
     }
     auto cnode = curr_g_->NewCNode(args);
-    set_node_map(curr_g_, GetId(obj), cnode);
+    SetNodeMapInGraphInfoMap(curr_g_, GetId(obj), cnode);
     node = cnode;
   } else {
     node = MakeValueNode(obj, obj_id);
@@ -1048,7 +1047,7 @@ AnfNodePtr PynativeExecutor::MakeValueNode(const py::object &obj, const std::string &obj_id) {
   ValuePtr converted_ret = nullptr;
   parse::ConvertData(obj, &converted_ret);
   auto node = NewValueNode(converted_ret);
-  set_node_map(curr_g_, obj_id, node);
+  SetNodeMapInGraphInfoMap(curr_g_, obj_id, node);
   return node;
 }
@@ -1083,12 +1082,12 @@ void PynativeExecutor::SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real,
   if (size > 1) {
     for (int64_t i = 0; i < size; ++i) {
       auto value_id = GetId(value[i]);
-      set_node_map(curr_g_, value_id, cnode, i);
+      SetNodeMapInGraphInfoMap(curr_g_, value_id, cnode, i);
     }
   }
-  set_node_map(curr_g_, obj_id, cnode);
-  set_pyobj(curr_g_, obj_id);
+  SetNodeMapInGraphInfoMap(curr_g_, obj_id, cnode);
+  SetPyObjInGraphInfoMap(curr_g_, obj_id);
 }
 void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out) {
@@ -1305,8 +1304,10 @@ py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
   ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors);
   // get graph info for checking it whether existing in the cache
   std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors);
-  session::OpRunInfo op_run_info = {op_exec_info->op_name, op_exec_info->py_primitive, op_exec_info->abstract,
-                                    op_exec_info->value, op_exec_info->is_dynamic_shape};
+  session::OpRunInfo op_run_info = {op_exec_info->op_name,          op_exec_info->py_primitive,
+                                    op_exec_info->abstract,         op_exec_info->value,
+                                    op_exec_info->is_dynamic_shape, op_exec_info->is_mixed_precision_cast,
+                                    op_exec_info->next_op_name,     op_exec_info->next_input_index};
   session->BuildOp(&op_run_info, graph_info, input_tensors, tensors_mask);
   EraseValueNodeTensor(tensors_mask, &input_tensors);
   VectorRef outputs;
@@ -1318,15 +1319,15 @@ py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
   return result;
 }
-void PynativeExecutor::Pushp() { graph_context_.push(curr_g_); }
+void PynativeExecutor::PushCurrentGraphToStack() { graph_stack_.push(curr_g_); }
-void PynativeExecutor::Popp() {
-  if (graph_context_.empty()) {
-    MS_LOG(EXCEPTION) << "Stack graph_context_ is empty";
+void PynativeExecutor::PopGraphStack() {
+  if (graph_stack_.empty()) {
+    MS_LOG(EXCEPTION) << "Stack graph_stack_ is empty";
   }
-  graph_context_.pop();
-  if (!graph_context_.empty()) {
-    curr_g_ = graph_context_.top();
+  graph_stack_.pop();
+  if (!graph_stack_.empty()) {
+    curr_g_ = graph_stack_.top();
   }
 }
@@ -1468,7 +1469,7 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
   auto cell_id = GetCellId(cell, args);
   MS_LOG(DEBUG) << "NewGraphInner start, args size: " << args.size() << ", cell id: " << cell_id;
   // check whether cell needed to construct grad graph
-  if (graph_context_.empty() && cell_graph_map_.find(cell_id) != cell_graph_map_.end() && !dynamic_cell_) {
+  if (graph_stack_.empty() && cell_graph_map_.find(cell_id) != cell_graph_map_.end() && !dynamic_cell_) {
     auto it = cell_resource_map_.find(cell_id);
     if (it != cell_resource_map_.end()) {
       resource_ = it->second;
@@ -1479,22 +1480,21 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
   }
   // init resource for constructing forward graph and grad graph
   auto g = std::make_shared<FuncGraph>();
-  if (graph_context_.empty()) {
+  if (graph_stack_.empty()) {
     MakeNewTopGraph(cell_id, args, g);
-  } else {
-    MS_EXCEPTION_IF_NULL(df_builder_);
-    curr_g_ = g;
   }
-  Pushp();
+  MS_EXCEPTION_IF_NULL(df_builder_);
+  curr_g_ = g;
+  PushCurrentGraphToStack();
   if (graph_info_map_.find(curr_g_) == graph_info_map_.end()) {
     graph_info_map_.emplace(curr_g_, GraphInfo());
   }
   for (size_t i = 0; i < args.size(); ++i) {
     auto param = args[i];
     auto new_param = g->add_parameter();
-    std::string param_obj = GetId(param);
-    set_node_map(curr_g_, param, new_param, true);
-    set_node_map(curr_g_, param_obj, new_param);
+    std::string param_id = GetId(param);
+    SetTupleArgsToGraphInfoMap(curr_g_, param, new_param, true);
+    SetNodeMapInGraphInfoMap(curr_g_, param_id, new_param);
   }
   // check whether the constrcut of cell will be changed
   if (!dynamic_cell_) {
@@ -1525,46 +1525,47 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g) {
   top_graph_cells_.emplace(cell_id);
 }
-void PynativeExecutor::set_node_map(const FuncGraphPtr &g, const py::object &node, const AnfNodePtr &cnode,
-                                    bool is_param) {
-  if (!py::isinstance<py::tuple>(node) && !py::isinstance<py::list>(node)) {
+void PynativeExecutor::SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
+                                                  bool is_param) {
+  if (!py::isinstance<py::tuple>(args) && !py::isinstance<py::list>(args)) {
     return;
   }
-  auto tuple = node.cast<py::tuple>();
+  auto tuple = args.cast<py::tuple>();
   auto tuple_size = static_cast<int64_t>(tuple.size());
   for (int64_t i = 0; i < tuple_size; ++i) {
     auto id = GetId(tuple[i]);
     if (is_param) {
       graph_info_map_[g].params.emplace(id);
     }
-    set_node_map(g, id, cnode, i);
-    set_tuple_node_map(g, tuple[i], cnode, std::vector<int64_t>{i}, is_param);
+    SetNodeMapInGraphInfoMap(g, id, node, i);
+    SetTupleItemArgsToGraphInfoMap(g, tuple[i], node, std::vector<int64_t>{i}, is_param);
   }
 }
-void PynativeExecutor::set_tuple_node_map(const FuncGraphPtr &g, const py::object &node, const AnfNodePtr &cnode,
-                                          const std::vector<int64_t> &idx, bool is_param) {
-  if (!py::isinstance<py::tuple>(node) && !py::isinstance<py::list>(node)) {
+void PynativeExecutor::SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args,
                                                      const AnfNodePtr &node,
+                                                      const std::vector<int64_t> &index_sequence, bool is_param) {
+  if (!py::isinstance<py::tuple>(args) && !py::isinstance<py::list>(args)) {
     return;
   }
-  auto tuple = node.cast<py::tuple>();
+  auto tuple = args.cast<py::tuple>();
   auto tuple_size = static_cast<int64_t>(tuple.size());
   for (int64_t i = 0; i < tuple_size; ++i) {
-    std::vector<int64_t> tmp = idx;
+    std::vector<int64_t> tmp = index_sequence;
     tmp.emplace_back(i);
     auto id = GetId(tuple[i]);
     if (is_param) {
       graph_info_map_[g].params.emplace(id);
     }
-    set_node_map(g, id, cnode, tmp);
-    set_tuple_node_map(g, tuple[i], cnode, tmp, is_param);
+    SetNodeMapInGraphInfoMap(g, id, node, tmp);
+    SetTupleItemArgsToGraphInfoMap(g, tuple[i], node, tmp, is_param);
   }
 }
 void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) {
   auto cell_id = GetCellId(cell, args);
   MS_LOG(DEBUG) << "EndGraphInner start " << args.size() << " " << cell_id;
-  if (graph_context_.empty() && cell_graph_map_.find(cell_id) != cell_graph_map_.end() && !dynamic_cell_) {
+  if (graph_stack_.empty() && cell_graph_map_.find(cell_id) != cell_graph_map_.end() && !dynamic_cell_) {
     MS_LOG(DEBUG) << "Endgraph already compiled";
     return;
   }
@@ -1582,8 +1583,8 @@ void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) {
       inputs.emplace_back(GetInput(tuple[i], false));
     }
     auto cnode = curr_g_->NewCNode(inputs);
-    set_node_map(curr_g_, out, cnode);
-    set_node_map(curr_g_, out_id, cnode);
+    SetTupleArgsToGraphInfoMap(curr_g_, out, cnode);
+    SetNodeMapInGraphInfoMap(curr_g_, out_id, cnode);
   } else {
     MS_LOG(DEBUG) << "Set ValueNode as output for graph, out id: " << out_id;
     MakeValueNode(out, out_id);
@@ -1601,21 +1602,21 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::object &cell,
   auto newfg = MakeGradGraph(cell, args);
-  if (graph_context_.size() > 1) {
+  if (graph_stack_.size() > 1) {
     std::vector<AnfNodePtr> inputs;
     inputs.emplace_back(NewValueNode(curr_g_));
-    Popp();
+    PopGraphStack();
     // connect the previous graph to the inside graph
-    auto graph_prev = graph_context_.top();
+    auto graph_prev = graph_stack_.top();
     for (size_t i = 0; i < args.size(); i++) {
       auto input = GetInput(args[i], false);
       inputs.emplace_back(input);
     }
     auto out_cnode = graph_prev->NewCNode(inputs);
-    set_pyobj(graph_prev, GetCellId(cell, args));
-    set_node_map(graph_prev, out, out_cnode);
-    set_node_map(graph_prev, GetId(out), out_cnode);
+    SetPyObjInGraphInfoMap(graph_prev, GetCellId(cell, args));
+    SetTupleArgsToGraphInfoMap(graph_prev, out, out_cnode);
+    SetNodeMapInGraphInfoMap(graph_prev, GetId(out), out_cnode);
   } else {
     if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
      DumpIR("before_resolve.ir", newfg);
@@ -1625,7 +1626,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::object &cell,
       DumpIR("after_resolve.ir", newfg);
     }
     resource_->set_func_graph(newfg);
-    Popp();
+    PopGraphStack();
   }
 }
@@ -1647,7 +1648,7 @@ FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const py::args &args) {
     }
   }
   // Obtain grad graph
-  auto newfg = ad::Grad(curr_g_, resource_, graph_context_.size() == 1);
+  auto newfg = ad::Grad(curr_g_, resource_, graph_stack_.size() == 1);
   graph_info_map_.erase(curr_g_);
   if (need_replace_param) {
@@ -1986,7 +1987,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
   op_id_map_.clear();
   obj_to_forward_id_.clear();
   node_abs_map_.clear();
-  std::stack<FuncGraphPtr>().swap(graph_context_);
+  std::stack<FuncGraphPtr>().swap(graph_stack_);
   ConfigManager::GetInstance().ResetIterNum();
 }
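
Note on the call chain introduced above (illustrative sketch; the concrete
op and index are examples rather than part of the change):
RunParameterAutoMixPrecisionCast walks an op's inputs, and for every
parameter carrying a cast_dtype it calls DoParamMixPrecisionCast, which
delegates to DoAutoCast. DoAutoCast re-enters RunOpInner with a synthesized
Cast op that carries the consumer hints, and MakeCNode skips the
mixed-precision pass for Cast itself, which is what keeps the recursion
from looping:

    // RunOpInner(conv_op)                          // e.g. Conv2D
    //   -> MakeCNode -> RunParameterAutoMixPrecisionCast
    //        -> DoParamMixPrecisionCast            // weight has cast_dtype()
    //             -> DoAutoCast(weight, kNumberTypeFloat16, "Conv2D", 1)
    //                  -> RunOpInner(cast_op)      // cast_op->next_op_name = "Conv2D"
    //                       -> MakeCNode           // op is Cast: no further casting
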
diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
index 9b707db0ad..7f1d6705db 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
@@ -81,7 +81,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   bool grad_flag() { return grad_flag_; }
   void set_grad_flag(bool flag) { grad_flag_ = flag; }
-  py::tuple RunOpInner(const py::args &args);
+  py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info);
   void NewGraph(const py::object &cell, const py::args &args);
   py::object Run(const py::tuple &args, const py::object &phase);
   py::object CheckGraph(const py::object &cell, const py::args &args);
@@ -108,6 +108,12 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   std::string ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
                             parse::AstMainType type);
+  py::object DoParamMixPrecisionCast(bool *is_cast, const py::object obj, const std::string &op_name, size_t index);
+  py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple, const std::string &op_name,
+                                          size_t index);
+  py::object DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name, size_t index);
+  void DoSignatrueCast(const PrimitivePyPtr &prim, const std::map<SignatureEnumDType, TypeId> &dst_type,
+                       const std::vector<SignatureEnumDType> &dtypes, const OpExecInfoPtr &op_exec_info);
   // run op
   AnfNodePtr GetInput(const py::object &obj, bool op_mask);
   MsBackendPolicy InitEnv(const OpExecInfoPtr &op_exec_info);
@@ -129,8 +135,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   void SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out);
   // construct grad graph
-  void Pushp();
-  void Popp();
+  void PushCurrentGraphToStack();
+  void PopGraphStack();
   void NewGraphInner(const py::object &cell, const py::args &args);
   void MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g);
   void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args);
@@ -148,19 +154,17 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   abstract::AbstractBasePtrList GetArgsSpec(const py::args &args);
   // hold graph(forward and grad) info
-  void set_pyobj(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); }
-  void set_node_map(const FuncGraphPtr &g, const py::object &node, const AnfNodePtr &cnode, bool is_param = false);
-  void set_node_map(const FuncGraphPtr &g, const std::string &obj, AnfNodePtr node) {
-    graph_info_map_[g].node_map[obj] = std::make_pair(node, std::vector<int64_t>{-1});
+  void SetPyObjInGraphInfoMap(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); }
+  void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
+                                  bool is_param = false);
+  void SetNodeMapInGraphInfoMap(FuncGraphPtr g, const std::string id, AnfNodePtr node, int64_t index = -1) {
+    graph_info_map_[g].node_map[id] = std::make_pair(node, std::vector<int64_t>{index});
   }
-  void set_node_map(const FuncGraphPtr &g, const std::string &obj, AnfNodePtr node, int index) {
-    graph_info_map_[g].node_map[obj] = std::make_pair(node, std::vector<int64_t>{index});
+  void SetNodeMapInGraphInfoMap(FuncGraphPtr g, const std::string id, AnfNodePtr node, std::vector<int64_t> index) {
+    graph_info_map_[g].node_map[id] = std::make_pair(node, index);
   }
-  void set_node_map(const FuncGraphPtr &g, const std::string &obj, AnfNodePtr node, std::vector<int64_t> index) {
-    graph_info_map_[g].node_map[obj] = std::make_pair(node, index);
-  }
-  void set_tuple_node_map(const FuncGraphPtr &g, const py::object &node, const AnfNodePtr &cnode,
-                          const std::vector<int64_t> &idx, bool is_param = false);
+  void SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &id, const AnfNodePtr &node,
+                                      const std::vector<int64_t> &index_sequence, bool is_param = false);
   static std::shared_ptr<PynativeExecutor> executor_;
   static std::mutex instance_lock_;
@@ -176,7 +180,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   FuncGraphPtr df_builder_{nullptr};
   ResourcePtr resource_{nullptr};
   // Records forwrad graph, the bottom is top graph
-  std::stack<FuncGraphPtr> graph_context_;
+  std::stack<FuncGraphPtr> graph_stack_;
   std::unordered_set<std::string> top_graph_cells_;
   // record all info of a graph
@@ -195,6 +199,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   std::unordered_map<std::string, std::string> obj_to_forward_id_;
   std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
   std::unordered_map<std::string, AbstractListMap> prim_abs_list_;
+  const inline static std::string kOpsFunctionModelName = "mindspore.ops.functional";
+  const inline static std::string kMSDtypeModelName = "mindspore.common.dtype";
 };
 using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>;
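
The table driving the new format selection (kNextOpFormatList in the next
file) maps a consumer op to the preferred device format of each of its
inputs. The entries below are exactly what the patch registers, laid out
here for orientation:

    // Conv2D:         input 0 (x)      -> NC1HWC0
    //                 input 1 (weight) -> FRAC_Z
    // FusedBatchNorm: inputs 0-4       -> NC1HWC0
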
diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc
index 3108471cdf..96841b06cb 100644
--- a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc
@@ -50,8 +50,10 @@ enum MatchCountPriority : int {
   MATCH_OUTPUT_DTYPE_COUNT,
   MATCH_COUNT_PRIORITY_END
 };
-
-const int kUnSupportMixedDataTypeIndex = -1;
+const std::map<std::string, std::vector<std::string>> kNextOpFormatList = {
+  {prim::kPrimConv2D->name(), {kOpFormat_NC1HWC0, kOpFormat_FRAC_Z}},
+  {prim::kPrimFusedBatchNorm->name(),
+   {kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0}}};
 bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
   MS_EXCEPTION_IF_NULL(cnode);
@@ -313,10 +315,41 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecisionMatchedKernelInfo(
   }
   return filtered_kernel_info_list;
 }
-}  // namespace
-void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
+void SetCastAndWeightFormat(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  if (!AnfAlgo::HasNodeAttr(kAttrPynativeNextIndex, kernel_node) ||
+      !AnfAlgo::HasNodeAttr(kAttrPynativeNextOpName, kernel_node)) {
+    MS_LOG(EXCEPTION) << "The node [" << kernel_node->DebugString() << "] attr of " << kAttrPynativeNextIndex << " or "
+                      << kAttrPynativeNextOpName << " has not been set yet!";
+  }
+  auto next_index = AnfAlgo::GetNodeAttr<size_t>(kernel_node, kAttrPynativeNextIndex);
+  auto next_op_name = AnfAlgo::GetNodeAttr<std::string>(kernel_node, kAttrPynativeNextOpName);
+  auto iter = kNextOpFormatList.find(next_op_name);
+  if (iter == kNextOpFormatList.end()) {
+    MS_LOG(WARNING) << "The op name " << next_op_name << " is not set in the next op format map.";
+    return;
+  }
+  if (iter->second.size() < next_index) {
+    MS_LOG(EXCEPTION) << "Next input index " << next_index << " is out of range; the max size of the next op format "
+                      << "map is " << iter->second.size();
+  }
+  if (AnfAlgo::GetCNodeName(kernel_node) != prim::kPrimCast->name()) {
+    MS_LOG(INFO) << "Only the Cast node's build info is supported to be changed!";
+    return;
+  }
+  auto format = iter->second[next_index];
+  auto info_builder =
+    std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(kernel_node));
+  info_builder->SetInputsFormat({format});
+  info_builder->SetOutputsFormat({format});
+  AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), kernel_node.get());
+}
+}  // namespace
+void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
+  auto selected_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
+  MS_EXCEPTION_IF_NULL(selected_kernel_info);
   for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
     auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
     MS_EXCEPTION_IF_NULL(input_kernel_node);
@@ -329,7 +362,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
     if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
       continue;
     }
-    if (selected_kernel_info.GetInputFormat(input_index) == kOpFormat_FRACTAL_ZN_LSTM) {
+    if (selected_kernel_info->GetInputFormat(input_index) == kOpFormat_FRACTAL_ZN_LSTM) {
      continue;
     }
     // we set special device info of a input tensor.
@@ -344,17 +377,17 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
     auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
     if (IsValueNode<tensor::Tensor>(input_kernel_node) &&
         AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) {
-      std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
+      std::vector<std::string> output_format = {selected_kernel_info->GetInputFormat(input_index)};
       builder->SetOutputsFormat(output_format);
-      std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
+      std::vector<TypeId> output_type = {selected_kernel_info->GetInputDeviceType(input_index)};
       builder->SetOutputsDeviceType(output_type);
       AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
       continue;
     }
     if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
-      std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
+      std::vector<std::string> output_format = {selected_kernel_info->GetInputFormat(input_index)};
       builder->SetOutputsFormat(output_format);
-      std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
+      std::vector<TypeId> output_type = {selected_kernel_info->GetInputDeviceType(input_index)};
       builder->SetOutputsDeviceType(output_type);
       AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
     }
@@ -388,7 +421,10 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
   // Set kernel info to the anfnode
   AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
   // Set format and data type for input tensor.
-  SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
+  if (AnfAlgo::HasNodeAttr(kAttrPynativeNextOpName, kernel_node)) {
+    SetCastAndWeightFormat(kernel_node);
+  }
+  SetTensorDeviceInfo(kernel_node);
   return select_status;
 }
@@ -428,7 +464,7 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
     auto selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, kernel_info_list);
     AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
     // Set format and data type for input tensor.
-    SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
+    SetTensorDeviceInfo(kernel_node);
   } else {
     MS_LOG(WARNING) << " <<<";
     MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString()
diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.h b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.h
index 82bf5e4f75..064d898500 100644
--- a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.h
+++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.h
@@ -29,7 +29,7 @@ enum KernelSelectStatus {
 };
 KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node,
                                     KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
-void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node);
+void SetTensorDeviceInfo(const CNodePtr &kernel_node);
 void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph);
 }  // namespace ascend
 }  // namespace device
diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_graph_kernel.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_select_graph_kernel.cc
index c9b3cab1c6..57c85d361f 100644
--- a/mindspore/ccsrc/runtime/device/ascend/kernel_select_graph_kernel.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_graph_kernel.cc
@@ -95,7 +95,7 @@ void UpdateKernelInfo(const std::vector<AnfNodePtr> &node_list) {
       auto selected_kernel_info_ptr = kernel_info_list[index];
       ResetKernelBuildInfo(cnode);
       AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info_ptr, cnode.get());
-      SetTensorDeviceInfo(*selected_kernel_info_ptr, cnode);
+      SetTensorDeviceInfo(cnode);
       break;
     }
   }
@@ -477,7 +477,7 @@ void SetGraphKernelInfo(const CNodePtr &kernel_node, const std::vector