diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.cc
index 05a4255703..7067ab74c2 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.cc
@@ -199,6 +199,5 @@ void SetAkgKernelAttrs(const AnfNodePtr &anf_node) {
     it->second(anf_node);
   }
 }
-
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.h b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.h
index c20849989b..dc4bc760fb 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.h
@@ -20,9 +20,7 @@
 namespace mindspore {
 namespace kernel {
-
 void SetAkgKernelAttrs(const AnfNodePtr &anf_node);
-
 }  // namespace kernel
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_ATTRS_PROCESS_H
diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc
index acaa2dea54..88a790bc1e 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc
@@ -42,153 +42,276 @@
 namespace mindspore {
 namespace kernel {
 namespace {
-ValuePtr ParseValue(const nlohmann::json &attr_json, const std::string &type) {
-  if (type == "str") {
-    std::string value = attr_json[kJsonKeyValue];
-    return MakeValue(value);
-  } else if (type == "int") {
-    int value = attr_json[kJsonKeyValue];
-    return MakeValue(value);
-  } else if (type == "bool") {
-    bool value = attr_json[kJsonKeyValue];
-    return MakeValue(value);
-  } else if (type == "float") {
-    float value = attr_json[kJsonKeyValue];
-    return MakeValue(value);
-  } else if (type == "listInt") {
-    std::vector<int> value = attr_json[kJsonKeyValue];
-    return MakeValue(value);
-  } else if (type == "listStr") {
-    std::vector<std::string> value = attr_json[kJsonKeyValue];
-    return MakeValue(value);
-  } else {
-    MS_LOG(ERROR) << "Unknown type of attr: " << type << ", json: \n" << attr_json;
-    return nullptr;
-  }
-}
+constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
+constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
-bool DecodeAttrs(const nlohmann::json &attrs_json, std::map<std::string, ValuePtr> *attrs) {
-  MS_EXCEPTION_IF_NULL(attrs);
-  MS_LOG(DEBUG) << "start decode attrs, " << attrs_json;
-  // decode attrs.
-  if (attrs_json.find(kJsonKeyAttr) == attrs_json.end() || attrs_json[kJsonKeyAttr].is_null()) {
-    // attrs maybe empty
-    return true;
+class CNodeDecoder {
+ public:
+  explicit CNodeDecoder(std::map<std::string, AnfNodePtr> *nodes_map) : nodes_map_(*nodes_map) {}
+  ~CNodeDecoder() = default;
+  CNodePtr DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph, kernel::Processor processor) {
+    MS_LOG(DEBUG) << "start decode cnode, " << cnode_json;
+    // decode attrs.
+ if (!DecodeAttrs(cnode_json)) { + MS_LOG(ERROR) << "Decode attrs failed."; + return nullptr; + } + if (!DecodeInputDesc(cnode_json, func_graph) || cnode_ == nullptr) { + MS_LOG(ERROR) << "Decode inputs failed."; + return nullptr; + } + if (!DecodeOutputDesc(cnode_json, func_graph)) { + MS_LOG(ERROR) << "Decode outputs failed."; + return nullptr; + } + CreateKernelInfo(processor); + return cnode_; } - std::vector attr_descs = attrs_json[kJsonKeyAttr]; - for (const auto &attr_desc : attr_descs) { - std::string name = attr_desc[kJsonKeyName]; - std::string type = attr_desc[kJsonKeyDataType]; - auto value = ParseValue(attr_desc, type); - if (value == nullptr) { - return false; + private: + ValuePtr ParseValue(const nlohmann::json &attr_json, const std::string &type) { + if (type == "str") { + std::string value = attr_json[kJsonKeyValue]; + return MakeValue(value); + } else if (type == "int") { + int value = attr_json[kJsonKeyValue]; + return MakeValue(value); + } else if (type == "bool") { + bool value = attr_json[kJsonKeyValue]; + return MakeValue(value); + } else if (type == "float") { + float value = attr_json[kJsonKeyValue]; + return MakeValue(value); + } else if (type == "listInt") { + std::vector value = attr_json[kJsonKeyValue]; + return MakeValue(value); + } else if (type == "listStr") { + std::vector value = attr_json[kJsonKeyValue]; + return MakeValue(value); + } else { + MS_LOG(ERROR) << "Unknown type of attr: " << type << ", json: \n" << attr_json; + return nullptr; } - (*attrs)[name] = value; } - return true; -} + bool DecodeAttrs(const nlohmann::json &attrs_json) { + MS_LOG(DEBUG) << "start decode attrs, " << attrs_json; + // attrs maybe empty + if (attrs_json.find(kJsonKeyAttr) == attrs_json.end() || attrs_json[kJsonKeyAttr].is_null()) { + return true; + } -// python utils. -constexpr auto kGetPythonOpFunc = "_get_python_op"; -constexpr auto kParallelUtilsModule = "mindspore.parallel._utils"; -// almost all ops are defined in this path. -constexpr auto kOperationsModule = "mindspore.ops.operations"; + std::vector attr_descs = attrs_json[kJsonKeyAttr]; + for (const auto &attr_desc : attr_descs) { + std::string name = attr_desc[kJsonKeyName]; + std::string type = attr_desc[kJsonKeyDataType]; + auto value = ParseValue(attr_desc, type); + if (value == nullptr) { + return false; + } + cnode_attrs_[name] = value; + } + return true; + } -const std::map> op_attrs_map = { - {kReduceSumOpName, std::vector{kAttrKeepDims}}, - {kReduceMaxOpName, std::vector{kAttrKeepDims}}, - {kReduceMinOpName, std::vector{kAttrKeepDims}}, -}; + bool DecodeInputDesc(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph) { + std::string op_name = cnode_json[kJsonKeyName]; + // new primitive. + auto primitive = GetPrimitive(op_name); + if (primitive == nullptr) { + MS_LOG(ERROR) << "Create primitive failed."; + return false; + } -ValuePtr CreatOpInstance(const std::string &op_name, const std::vector &attrs) { - py::module mod = py::module::import(kOperationsModule); - if (!py::hasattr(mod, op_name.c_str())) { - MS_LOG(ERROR) << kOperationsModule << " don't have attr: " << op_name; - return nullptr; + // collect inputs. 
+ auto primitive_v = NewValueNode(primitive); + func_graph->AddValueNode(primitive_v); + std::vector inputs{primitive_v}; + std::vector input_descs = cnode_json[kJsonKeyInputDesc]; + for (size_t i = 0; i < input_descs.size(); ++i) { + nlohmann::json input_desc = input_descs[i][0]; + std::string name = input_desc[kJsonKeyTensorName]; + if (input_desc.find(kJsonKeyValue) != input_desc.end()) { + inputs.push_back(DecodeValueNode(input_desc, func_graph)); + } else if (nodes_map_.count(name) == 0) { + MS_LOG(ERROR) << "Input: " << name << " of: " << op_name << " not found."; + return false; + } else { + inputs.push_back(nodes_map_[name]); + } + input_formats_.push_back(input_desc[kJsonKeyFormat]); + input_types_.push_back(DtypeToTypeId(input_desc[kJsonKeyDataType])); + } + // new cnode. + cnode_ = func_graph->NewCNode(inputs); + func_graph->AddNode(cnode_); + return true; } - std::vector arg_list; - (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list), - [](const ValuePtr &attr) { return ValuePtrToPyData(attr); }); - py::object obj = parse::python_adapter::CallPyFn(kParallelUtilsModule, kGetPythonOpFunc, op_name, kOperationsModule, - op_name, arg_list); - ValuePtr op_instance = nullptr; - bool succ = parse::ConvertData(obj, &op_instance); - if (!succ) { - MS_LOG(ERROR) << "Get python op " << op_name << " from " << kOperationsModule << " failed."; - return nullptr; + + bool DecodeOutputDesc(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph) { + std::vector output_descs = cnode_json[kJsonKeyOutputDesc]; + AbstractBasePtr abstract(nullptr); + if (output_descs.empty()) { + MS_LOG(ERROR) << "No outputs found."; + return false; + } else if (output_descs.size() == 1) { + // single output. + nlohmann::json output_desc = output_descs[0]; + output_formats_.push_back(output_desc[kJsonKeyFormat]); + output_types_.push_back(DtypeToTypeId(output_desc[kJsonKeyDataType])); + nodes_map_[output_desc[kJsonKeyTensorName]] = cnode_; + } else { + // multi outputs. + for (size_t j = 0; j < output_descs.size(); ++j) { + nlohmann::json output_desc = output_descs[j]; + output_formats_.push_back(output_desc[kJsonKeyFormat]); + output_types_.push_back(DtypeToTypeId(output_desc[kJsonKeyDataType])); + auto get_item = + func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode_, NewValueNode(SizeToInt(j))}); + func_graph->AddNode(get_item); + nodes_map_[output_desc[kJsonKeyTensorName]] = get_item; + } + } + return true; } - return op_instance; -} -PrimitivePtr GetPrimitive(const std::string &op_name, const std::map &attrs_val) { - PrimitivePtr primitive{nullptr}; - if (op_attrs_map.count(op_name) == 0) { - // no attrs for op instance. - primitive = CreatOpInstance(op_name, std::vector{})->cast(); - } else { - // make attrs for op instance. 
- std::vector op_attrs; - const auto &attr_names = op_attrs_map.at(op_name); - for (const auto &attr_name : attr_names) { - if (attrs_val.count(attr_name) == 0) { - MS_LOG(ERROR) << "Attr: " << attr_name << " for: " << op_name << " not found."; - return nullptr; + void CreateKernelInfo(kernel::Processor processor) { + auto kernel_info = std::make_shared(); + std::vector feature_map_input_indexs; + // if the node only has the primitive(such as getNext) or the node's input has a feature map input + // then the node's output is a feature map output + const auto &inputs = cnode_->inputs(); + for (size_t index = 1; index < inputs.size(); ++index) { + auto node = AnfAlgo::VisitKernel(inputs[index], 0); + if (AnfAlgo::IsFeatureMapOutput(node.first)) { + feature_map_input_indexs.push_back(index); } - op_attrs.push_back(attrs_val.at(attr_name)); } - primitive = CreatOpInstance(op_name, op_attrs)->cast(); + if (AnfAlgo::GetCNodeName(cnode_) == prim::kPrimCast->name()) { + AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode_); + } + if (inputs.size() == 1 || !feature_map_input_indexs.empty()) { + kernel_info->SetFeatureMapFlag(true); + } + if (AnfAlgo::IsRealCNodeKernel(cnode_)) { + AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode_); + AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode_); + } + cnode_->set_kernel_info(kernel_info); + // create kernel_build_info. + auto builder = std::make_shared(); + builder->SetInputsFormat(input_formats_); + builder->SetInputsDeviceType(input_types_); + builder->SetOutputsFormat(output_formats_); + builder->SetOutputsDeviceType(output_types_); + builder->SetProcessor(processor); + builder->SetKernelType(KernelType::AKG_KERNEL); + builder->SetFusionType(kernel::FusionType::OPAQUE); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), cnode_.get()); } - if (primitive != nullptr) { - for (const auto &attr : attrs_val) { - primitive->AddAttr(attr.first, attr.second); + ValuePtr CreatOpInstance(const std::string &op_name, const std::vector &attrs) { + // python utils. + constexpr auto kGetPythonOpFunc = "_get_python_op"; + constexpr auto kParallelUtilsModule = "mindspore.parallel._utils"; + // almost all ops are defined in this path. 
+ constexpr auto kOperationsModule = "mindspore.ops.operations"; + py::module mod = py::module::import(kOperationsModule); + if (!py::hasattr(mod, op_name.c_str())) { + MS_LOG(ERROR) << kOperationsModule << " don't have attr: " << op_name; + return nullptr; } + std::vector arg_list; + (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list), + [](const ValuePtr &attr) { return ValuePtrToPyData(attr); }); + py::object obj = parse::python_adapter::CallPyFn(kParallelUtilsModule, kGetPythonOpFunc, op_name, kOperationsModule, + op_name, arg_list); + ValuePtr op_instance = nullptr; + bool succ = parse::ConvertData(obj, &op_instance); + if (!succ) { + MS_LOG(ERROR) << "Get python op " << op_name << " from " << kOperationsModule << " failed."; + return nullptr; + } + return op_instance; } - return primitive; -} -} // namespace + const std::map> op_attrs_map_ = { + {kReduceSumOpName, std::vector{kAttrKeepDims}}, + {kReduceMaxOpName, std::vector{kAttrKeepDims}}, + {kReduceMinOpName, std::vector{kAttrKeepDims}}, + }; -constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput"; -constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList"; + PrimitivePtr GetPrimitive(const std::string &op_name) { + PrimitivePtr primitive{nullptr}; + if (op_attrs_map_.count(op_name) == 0) { + // no attrs for op instance. + primitive = CreatOpInstance(op_name, std::vector{})->cast(); + } else { + // make attrs for op instance. + std::vector op_attrs; + const auto &attr_names = op_attrs_map_.at(op_name); + for (const auto &attr_name : attr_names) { + if (cnode_attrs_.count(attr_name) == 0) { + MS_LOG(ERROR) << "Attr: " << attr_name << " for: " << op_name << " not found."; + return nullptr; + } + op_attrs.push_back(cnode_attrs_.at(attr_name)); + } + primitive = CreatOpInstance(op_name, op_attrs)->cast(); + } + if (primitive != nullptr) { + for (const auto &attr : cnode_attrs_) { + primitive->AddAttr(attr.first, attr.second); + } + } + return primitive; + } -ScalarPtr AkgKernelJsonDecoder::DecodeScalar(const nlohmann::json &scalar_json) { - auto type_id = DtypeToTypeId(scalar_json[kJsonKeyDataType]); - switch (type_id) { - case kNumberTypeFloat16: - case kNumberTypeFloat32: - return std::make_shared(scalar_json[kJsonKeyValue]); - case kNumberTypeInt32: - return std::make_shared(scalar_json[kJsonKeyValue]); - default: - MS_LOG(ERROR) << "Unknown type: " << scalar_json[kJsonKeyDataType]; - break; + ScalarPtr DecodeScalar(const nlohmann::json &scalar_json) { + auto type_id = DtypeToTypeId(scalar_json[kJsonKeyDataType]); + switch (type_id) { + case kNumberTypeFloat16: + case kNumberTypeFloat32: + return std::make_shared(scalar_json[kJsonKeyValue]); + case kNumberTypeInt32: + return std::make_shared(scalar_json[kJsonKeyValue]); + default: + MS_LOG(ERROR) << "Unknown type: " << scalar_json[kJsonKeyDataType]; + break; + } + return nullptr; } - return nullptr; -} -ValueNodePtr AkgKernelJsonDecoder::DecodeValueNode(const nlohmann::json &value_json, const FuncGraphPtr &func_graph) { - MS_LOG(DEBUG) << "start decode value node, " << value_json; - auto scalar = DecodeScalar(value_json); - auto tensor = ScalarToTensor(scalar); + ValueNodePtr DecodeValueNode(const nlohmann::json &value_json, const FuncGraphPtr &func_graph) { + MS_LOG(DEBUG) << "start decode value node, " << value_json; + auto scalar = DecodeScalar(value_json); + auto tensor = ScalarToTensor(scalar); - auto value_node = std::make_shared(tensor); - value_node->set_abstract(tensor->ToAbstract()); - // create kernel_info fo new value node. 
-  auto kernel_info = std::make_shared();
-  value_node->set_kernel_info(kernel_info);
-  // create kernel_build_info for new value node.
-  auto builder = std::make_shared();
-  // layout info.
-  builder->SetOutputsFormat(std::vector<std::string>{value_json[kJsonKeyFormat]});
-  builder->SetOutputsDeviceType(std::vector<TypeId>{DtypeToTypeId(value_json[kJsonKeyDataType])});
-  AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), value_node.get());
-  func_graph->AddValueNode(value_node);
-  MS_LOG(DEBUG) << "decode value node success, " << value_node->DebugString(2);
-  return value_node;
-}
+    auto value_node = std::make_shared<ValueNode>(tensor);
+    value_node->set_abstract(tensor->ToAbstract());
+    // create kernel_info for new value node.
+    auto kernel_info = std::make_shared();
+    value_node->set_kernel_info(kernel_info);
+    // create kernel_build_info for new value node.
+    auto builder = std::make_shared();
+    // layout info.
+    builder->SetOutputsFormat(std::vector<std::string>{value_json[kJsonKeyFormat]});
+    builder->SetOutputsDeviceType(std::vector<TypeId>{DtypeToTypeId(value_json[kJsonKeyDataType])});
+    AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), value_node.get());
+    func_graph->AddValueNode(value_node);
+    MS_LOG(DEBUG) << "decode value node success, " << value_node->DebugString(2);
+    return value_node;
+  }
+
+  std::map<std::string, AnfNodePtr> &nodes_map_;
+  std::map<std::string, ValuePtr> cnode_attrs_;
+  std::vector<std::string> input_formats_;
+  std::vector<std::string> output_formats_;
+  std::vector<TypeId> input_types_;
+  std::vector<TypeId> output_types_;
+  CNodePtr cnode_{nullptr};
+};
+}  // namespace
 ParameterPtr AkgKernelJsonDecoder::DecodeParameter(const nlohmann::json &parameter_json,
                                                    const FuncGraphPtr &func_graph) {
@@ -208,118 +331,35 @@ ParameterPtr AkgKernelJsonDecoder::DecodeParameter(const nlohmann::json &paramet
 CNodePtr AkgKernelJsonDecoder::DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph,
                                            const std::string &processor) {
+  CNodeDecoder decoder(&nodes_map_);
   Processor p = kernel::GetProcessor(processor);
-  MS_LOG(DEBUG) << "start decode cnode, " << cnode_json;
-  // decode attrs.
-  std::map<std::string, ValuePtr> cnode_attrs;
-  if (!DecodeAttrs(cnode_json, &cnode_attrs)) {
-    MS_LOG(ERROR) << "Error decode attrs.";
-    return nullptr;
-  }
-  std::string op_name = cnode_json[kJsonKeyName];
-  // new primitive.
-  auto primitive = GetPrimitive(op_name, cnode_attrs);
-  if (primitive == nullptr) {
-    MS_LOG(ERROR) << "Create primitive failed.";
-    return nullptr;
-  }
-
-  // data layout info.
-  std::vector<std::string> input_formats;
-  std::vector<TypeId> input_types;
-  std::vector<std::string> output_formats;
-  std::vector<TypeId> output_types;
+  return decoder.DecodeCNode(cnode_json, func_graph, p);
+}
-  // collect inputs.
- auto primitive_v = NewValueNode(primitive); - func_graph->AddValueNode(primitive_v); - std::vector inputs{primitive_v}; - std::vector input_descs = cnode_json[kJsonKeyInputDesc]; - for (size_t i = 0; i < input_descs.size(); ++i) { - nlohmann::json input_desc = input_descs[i][0]; - std::string name = input_desc[kJsonKeyTensorName]; - if (input_desc.find(kJsonKeyValue) != input_desc.end()) { - inputs.push_back(DecodeValueNode(input_desc, func_graph)); - } else if (nodes_map_.count(name) == 0) { - MS_LOG(ERROR) << "Input: " << name << " of: " << op_name << " not found."; +AnfNodePtr AkgKernelJsonDecoder::DecodeOutput(const std::vector &output_descs, + const FuncGraphPtr &func_graph) { + std::vector outputs{NewValueNode(prim::kPrimMakeTuple)}; + for (const auto &output_desc : output_descs) { + std::string name = output_desc[kJsonKeyTensorName]; + if (nodes_map_.count(name) == 0) { + MS_LOG(ERROR) << "Output: " << name << " of graph not found."; return nullptr; - } else { - inputs.push_back(nodes_map_[name]); } - input_formats.push_back(input_desc[kJsonKeyFormat]); - input_types.push_back(DtypeToTypeId(input_desc[kJsonKeyDataType])); + outputs.push_back(nodes_map_[name]); } - MS_LOG(DEBUG) << "decode inputs success."; - - // new cnode. - auto cnode = func_graph->NewCNode(inputs); - func_graph->AddNode(cnode); - - // decode outputs. - std::vector output_descs = cnode_json[kJsonKeyOutputDesc]; - AbstractBasePtr abstract(nullptr); - if (output_descs.empty()) { - MS_LOG(ERROR) << "No outputs found."; - return nullptr; - } else if (output_descs.size() == 1) { - // single output. - nlohmann::json output_desc = output_descs[0]; - output_formats.push_back(output_desc[kJsonKeyFormat]); - output_types.push_back(DtypeToTypeId(output_desc[kJsonKeyDataType])); - nodes_map_[output_desc[kJsonKeyTensorName]] = cnode; + if (outputs.size() == 2) { + func_graph->set_output(outputs[1]); } else { - // multi outputs. - for (size_t j = 0; j < output_descs.size(); ++j) { - nlohmann::json output_desc = output_descs[j]; - output_formats.push_back(output_desc[kJsonKeyFormat]); - output_types.push_back(DtypeToTypeId(output_desc[kJsonKeyDataType])); - auto get_item = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, NewValueNode(SizeToInt(j))}); - func_graph->AddNode(get_item); - nodes_map_[output_desc[kJsonKeyTensorName]] = get_item; - } - } - MS_LOG(DEBUG) << "decode outputs success."; - - // create kernel_info. - auto kernel_info = std::make_shared(); - std::vector feature_map_input_indexs; - // if the node only has the primitive(such as getNext) or the node's input has a feature map input - // then the node's output is a feature map output - for (size_t index = 1; index < inputs.size(); ++index) { - auto node = AnfAlgo::VisitKernel(inputs[index], 0); - if (AnfAlgo::IsFeatureMapOutput(node.first)) { - feature_map_input_indexs.push_back(index); - } + auto output = func_graph->NewCNode(outputs); + func_graph->AddNode(output); + func_graph->set_output(output); } - if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) { - AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode); - } - if (inputs.size() == 1 || !feature_map_input_indexs.empty()) { - kernel_info->SetFeatureMapFlag(true); - } - if (AnfAlgo::IsRealCNodeKernel(cnode)) { - AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode); - AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode); - } - cnode->set_kernel_info(kernel_info); - // create kernel_build_info. 
- auto builder = std::make_shared(); - builder->SetInputsFormat(input_formats); - builder->SetInputsDeviceType(input_types); - builder->SetOutputsFormat(output_formats); - builder->SetOutputsDeviceType(output_types); - builder->SetProcessor(p); - builder->SetKernelType(KernelType::AKG_KERNEL); - builder->SetFusionType(kernel::FusionType::OPAQUE); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), cnode.get()); - return cnode; + return func_graph->output(); } FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const nlohmann::json &kernel_json) { MS_LOG(DEBUG) << "start decode, " << kernel_json; - // clear cache. nodes_map_.clear(); - // create a graph. auto graph = std::make_shared(); // decode parameters. @@ -331,10 +371,7 @@ FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const nlohmann::json &kernel for (size_t i = 0; i < input_descs.size(); ++i) { std::vector input_desc = input_descs[i]; auto parameter = DecodeParameter(input_desc[0], graph); - if (parameter == nullptr) { - MS_LOG(ERROR) << "Error decode parameter."; - return nullptr; - } + MS_EXCEPTION_IF_NULL(parameter); } MS_LOG(DEBUG) << "decode parameters success."; @@ -346,10 +383,7 @@ FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const nlohmann::json &kernel } for (const auto &op_desc : op_node_descs) { auto op_node = DecodeCNode(op_desc, graph, kernel_json[kJsonKeyProcess]); - if (op_node == nullptr) { - MS_LOG(ERROR) << "Error decode cnode."; - return nullptr; - } + MS_EXCEPTION_IF_NULL(op_node); } MS_LOG(DEBUG) << "decode cnodes success."; @@ -359,22 +393,8 @@ FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const nlohmann::json &kernel MS_LOG(ERROR) << "Error decode outputs, no outputs for graph."; return nullptr; } - std::vector outputs{NewValueNode(prim::kPrimMakeTuple)}; - for (const auto &output_desc : output_descs) { - std::string name = output_desc[kJsonKeyTensorName]; - if (nodes_map_.count(name) == 0) { - MS_LOG(ERROR) << "Output: " << name << " of graph not found."; - return nullptr; - } - outputs.push_back(nodes_map_[name]); - } - if (outputs.size() == 2) { - graph->set_output(outputs[1]); - } else { - auto output = graph->NewCNode(outputs); - graph->AddNode(output); - graph->set_output(output); - } + auto output = DecodeOutput(output_descs, graph); + MS_EXCEPTION_IF_NULL(output); MS_LOG(DEBUG) << "decode success, " << kernel_json; return graph; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.h b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.h index 103454e678..ac3e9331d3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.h +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.h @@ -37,11 +37,10 @@ class AkgKernelJsonDecoder { AnfNodePtrList *res_graphs); private: - ScalarPtr DecodeScalar(const nlohmann::json &scalar_json); - ValueNodePtr DecodeValueNode(const nlohmann::json &value_json, const FuncGraphPtr &func_graph); ParameterPtr DecodeParameter(const nlohmann::json ¶meter_json, const FuncGraphPtr &func_graph); CNodePtr DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph, const std::string &processor); - std::map nodes_map_{}; + AnfNodePtr DecodeOutput(const std::vector &output_descs, const FuncGraphPtr &func_graph); + std::map nodes_map_; }; } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc index 
d290606cac..840a7f13ff 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.cc @@ -79,22 +79,16 @@ inline std::string AkgKernelJsonGenerator::GetOutputFormat(const AnfNodePtr &anf bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, const std::shared_ptr &op_info, nlohmann::json *const inputs_json) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(op_info); - MS_EXCEPTION_IF_NULL(inputs_json); - // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input. std::vector> inputs_ptr = op_info->inputs_ptr(); if (inputs_ptr.empty()) { - MS_LOG(DEBUG) << "Kernel [" << anf_node->fullname_with_scope() << "] regist info has no input info"; - return true; + MS_LOG(ERROR) << "Kernel [" << anf_node->fullname_with_scope() << "] regist info has no input info"; + return false; } // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input. auto dyn_input_sizes = GetDynInputSize(anf_node); - size_t real_input_index = 0; - std::vector input_list; for (size_t i = 0; i < inputs_ptr.size(); i++) { std::shared_ptr input_ptr = inputs_ptr[i]; if (input_ptr == nullptr) { @@ -102,10 +96,8 @@ bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, con return false; } - auto op_input_name = input_ptr->name(); size_t input_tensor_num = dyn_input_sizes.empty() ? 1 : IntToSize(dyn_input_sizes[i]); - - input_list.clear(); + std::vector input_list; for (size_t input_i = 0; input_i < input_tensor_num; input_i++) { auto type_id = this->GetInputDataType(anf_node, real_input_index); std::string dtype = TypeId2String(type_id, dump_option_.is_before_select_kernel); @@ -117,7 +109,7 @@ bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, con nlohmann::json input_desc_json; input_desc_json[kJsonKeyDataType] = dtype; input_desc_json[kJsonKeyFormat] = this->GetInputFormat(anf_node, real_input_index); - input_desc_json[kJsonKeyName] = op_input_name; + input_desc_json[kJsonKeyName] = input_ptr->name(); input_desc_json[kJsonKeyTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index)); auto input_shape = this->GetInputShape(anf_node, real_input_index); if (anf_node->func_graph() != nullptr && anf_node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && @@ -204,77 +196,56 @@ void AkgKernelJsonGenerator::GetJson(const AnfNodePtr &anf_node, const std::vect bool AkgKernelJsonGenerator::CreateAttrDescJson(const AnfNodePtr &anf_node, const std::shared_ptr &op_info, nlohmann::json *const attrs_json) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(op_info); - MS_EXCEPTION_IF_NULL(attrs_json); std::vector> attrs = op_info->attrs_ptr(); if (attrs.empty()) { - MS_LOG(INFO) << "Apply kernel [" << anf_node->fullname_with_scope() << "] op info attrs is empty"; + MS_LOG(DEBUG) << "Apply kernel [" << anf_node->fullname_with_scope() << "] op info attrs is empty"; return true; } - std::vector> inputs = op_info->inputs_ptr(); - - std::vector dyn_input_sizes; + auto dyn_input_sizes = GetDynInputSize(anf_node); auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) { - dyn_input_sizes = GetValue>(primitive->GetAttr(kAttrDynInputSizes)); - } - - if (inputs.empty()) { - MS_LOG(ERROR) << "Apply kernel [" << anf_node->fullname_with_scope() << "] op info inputs is empty"; - 
return false; - } // create input name list for "x_shape" in attr with "x" in primitive. - std::map op_info_shape_name; - for (size_t op_info_input_i = 0; op_info_input_i < inputs.size(); op_info_input_i++) { - std::string input_name = inputs[op_info_input_i]->name(); - std::string x_shape_name = input_name + "_shape"; - static_cast(op_info_shape_name.insert(make_pair(op_info_input_i, x_shape_name))); + std::vector> inputs = op_info->inputs_ptr(); + std::map op_info_shape_name; + for (size_t i = 0; i < inputs.size(); i++) { + op_info_shape_name[inputs[i]->name() + "_shape"] = i; } for (const auto &op_attr : attrs) { nlohmann::json attr_json; ValuePtr attr_value = primitive->GetAttr(op_attr->name()); if (attr_value == nullptr && op_attr->name() != kArgDataformat) { - if (op_attr->param_type() == "required") { - // match "x_shape" in att with "x" in primitive. - std::string attr_name = op_attr->name(); - auto find_item = std::find_if( - op_info_shape_name.begin(), op_info_shape_name.end(), - [attr_name](const std::map::value_type item) { return item.second == attr_name; }); - if (find_item != op_info_shape_name.end()) { - if (!dyn_input_sizes.empty()) { - if (find_item->first >= dyn_input_sizes.size() - 1) { - MS_LOG(EXCEPTION) << "dyn_input_sizes list index:" << find_item->first - << " is out of range:" << dyn_input_sizes.size() - 1 << "."; - return false; - } - size_t tensor_idx = IntToSize(std::accumulate(&dyn_input_sizes[0], &dyn_input_sizes[find_item->first], 0)); - for (int input_i = 0; input_i < dyn_input_sizes[find_item->first]; input_i++) { - attr_json[kJsonKeyValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, tensor_idx); - attr_json[kJsonKeyName] = op_attr->name(); - attrs_json->push_back(attr_json); - tensor_idx++; - } - } else { - attr_json[kJsonKeyValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, find_item->first); + if (op_attr->param_type() != "required") continue; + // match "x_shape" in attr with "x" in primitive. 
+ auto find_item = op_info_shape_name.find(op_attr->name()); + if (find_item != op_info_shape_name.end()) { + if (!dyn_input_sizes.empty()) { + if (find_item->second >= dyn_input_sizes.size() - 1) { + MS_LOG(EXCEPTION) << "dyn_input_sizes list index:" << find_item->second + << " is out of range:" << dyn_input_sizes.size() - 1 << "."; + return false; + } + size_t tensor_idx = IntToSize(std::accumulate(&dyn_input_sizes[0], &dyn_input_sizes[find_item->second], 0)); + for (int input_i = 0; input_i < dyn_input_sizes[find_item->second]; input_i++) { + attr_json[kJsonKeyValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, tensor_idx); attr_json[kJsonKeyName] = op_attr->name(); attrs_json->push_back(attr_json); + tensor_idx++; } } else { - MS_LOG(ERROR) << "op [" << anf_node->fullname_with_scope() << "] should have attr :" << op_attr->name(); - return false; + attr_json[kJsonKeyValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, find_item->second); + attr_json[kJsonKeyName] = op_attr->name(); + attrs_json->push_back(attr_json); } + } else { + MS_LOG(ERROR) << "op [" << anf_node->fullname_with_scope() << "] should have attr :" << op_attr->name(); + return false; } - continue; + } else { + GetJson(anf_node, dyn_input_sizes, op_attr, &attr_json, attr_value); + attr_json[kJsonKeyName] = op_attr->name(); + attrs_json->push_back(attr_json); } - - GetJson(anf_node, dyn_input_sizes, op_attr, &attr_json, attr_value); - - attr_json[kJsonKeyName] = op_attr->name(); - attrs_json->push_back(attr_json); } return true; } @@ -485,7 +456,47 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector &anf MS_LOG(INFO) << "Fusion nodes: [" << output_list.size() << "], input_list: [" << anf_nodes.size() << "], output_list: [" << input_list.size() << "]."; std::map node_json_map; + if (!GenSingleJsons(anf_nodes, &node_json_map)) return false; + + UpdateTensorName(anf_nodes, &node_json_map); + + std::vector node_json_desc; + std::transform(anf_nodes.begin(), anf_nodes.end(), std::back_inserter(node_json_desc), + [&node_json_map](const AnfNodePtr &anf_node) { return node_json_map[anf_node]; }); + (*kernel_json)[kJsonKeyOpDesc] = node_json_desc; + + auto inputs_json = CreateInputsJson(anf_nodes, input_list, node_json_map); + (*kernel_json)[kJsonKeyInputDesc] = inputs_json; + (*kernel_json)[kJsonKeyOutputDesc] = + CreateOutputsJson(anf_nodes, input_list, output_list, inputs_json, node_json_map); + size_t hash_id = std::hash()(kernel_json->dump()); + kernel_name_ = "Fused_"; + auto fg = anf_nodes[0]->func_graph(); + MS_EXCEPTION_IF_NULL(fg); + auto attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); + if (attr_val != nullptr) { + auto fg_attr = GetValue(attr_val); + (void)kernel_name_.append(fg_attr).append("_"); + } + (void)kernel_name_.append(std::to_string(hash_id)); + (*kernel_json)[kJsonKeyId] = GetOpCntInc(); + (*kernel_json)[kJsonKeyOp] = kernel_name_; + (*kernel_json)[kJsonKeyPlatform] = "AKG"; + (*kernel_json)[kJsonKeyProcess] = GetProcessorStr(anf_nodes[0]); + (*kernel_json)[kJsonKeyComposite] = true; + (*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString(); + + if (!GetIOSize(*kernel_json, &input_size_list_, &output_size_list_)) { + MS_LOG(ERROR) << "Cal mem size failed."; + return false; + } + + return true; +} + +bool AkgKernelJsonGenerator::GenSingleJsons(const std::vector &anf_nodes, + std::map *node_json_map) { for (auto const &anf_node : anf_nodes) { MS_EXCEPTION_IF_NULL(anf_node); if (!AnfAlgo::IsRealKernel(anf_node)) { @@ -507,9 +518,13 @@ bool 
AkgKernelJsonGenerator::CollectFusedJson(const std::vector &anf node_json["fusion"] = primitive->GetAttr("fusion")->ToString(); } - node_json_map[anf_node] = node_json; + (*node_json_map)[anf_node] = node_json; } + return true; +} +void AkgKernelJsonGenerator::UpdateTensorName(const std::vector &anf_nodes, + std::map *node_json_map) { for (auto const &anf_node : anf_nodes) { auto dyn_input_sizes = GetDynInputSize(anf_node); bool is_dynamic_input = !dyn_input_sizes.empty(); @@ -519,11 +534,11 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector &anf size_t input_tensor_num = is_dynamic_input ? IntToSize(dyn_input_sizes[i]) : 1; for (size_t j = 0; j < input_tensor_num; ++j) { auto tmp_input = GetKernelInput(anf_node, real_input_index); - std::string tensor_name = GetTensorName(node_json_map[anf_node], kJsonKeyInputDesc, std::make_pair(i, j)); - if (node_json_map.find(tmp_input.first) != node_json_map.end()) { + std::string tensor_name = GetTensorName((*node_json_map)[anf_node], kJsonKeyInputDesc, std::make_pair(i, j)); + if (node_json_map->find(tmp_input.first) != node_json_map->end()) { std::string new_tensor_name = - GetTensorName(node_json_map[tmp_input.first], kJsonKeyOutputDesc, std::make_pair(0, tmp_input.second)); - SetTensorName(kJsonKeyInputDesc, new_tensor_name, std::make_pair(i, j), &(node_json_map[anf_node])); + GetTensorName((*node_json_map)[tmp_input.first], kJsonKeyOutputDesc, std::make_pair(0, tmp_input.second)); + SetTensorName(kJsonKeyInputDesc, new_tensor_name, std::make_pair(i, j), &((*node_json_map)[anf_node])); MS_LOG(DEBUG) << "Update [" << real_input_index << "] input [" << tensor_name << "] of [" << anf_node->fullname_with_scope() << "] to [" << tmp_input.second << "] output [" << new_tensor_name << "] of [" << tmp_input.first->fullname_with_scope() << "]."; @@ -535,12 +550,11 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector &anf } } } +} - std::vector node_json_desc; - std::transform(anf_nodes.begin(), anf_nodes.end(), std::back_inserter(node_json_desc), - [&node_json_map](const AnfNodePtr &anf_node) { return node_json_map[anf_node]; }); - (*kernel_json)[kJsonKeyOpDesc] = node_json_desc; - +nlohmann::json AkgKernelJsonGenerator::CreateInputsJson(const std::vector &anf_nodes, + const std::vector &input_list, + const std::map &node_json_map) { nlohmann::json inputs_json; auto input_index = GetInputIndex(anf_nodes, input_list); for (size_t i = 0; i < input_index.size(); ++i) { @@ -549,18 +563,22 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector &anf std::string dtype = TypeId2String(type_id, dump_option_.is_before_select_kernel); nlohmann::json input_desc_json; input_desc_json[kJsonKeyTensorName] = - GetTensorName(node_json_map[tmp_input.first], kJsonKeyInputDesc, tmp_input.second); + GetTensorName(node_json_map.at(tmp_input.first), kJsonKeyInputDesc, tmp_input.second); input_desc_json[kJsonKeyDataType] = dtype; input_desc_json[kJsonKeyFormat] = this->GetInputFormat(tmp_input.first, tmp_input.second.first); input_desc_json[kJsonKeyShape] = this->GetInputShape(tmp_input.first, tmp_input.second.first); inputs_json.emplace_back(std::vector{input_desc_json}); } - (*kernel_json)[kJsonKeyInputDesc] = inputs_json; + return inputs_json; +} +nlohmann::json AkgKernelJsonGenerator::CreateOutputsJson(const std::vector &anf_nodes, + const std::vector &input_list, + const std::vector &output_list, + const nlohmann::json &inputs_json, + const std::map &node_json_map) { nlohmann::json outputs_json; auto output_index = 
GetOutputIndex(anf_nodes, input_list, output_list); - std::map> sub_graphs; - std::map dim_infos; for (size_t i = 0; i < output_index.size(); ++i) { auto tmp_output = output_index[i]; bool found = false; @@ -576,7 +594,7 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector &anf auto type_id = this->GetOutputDataType(tmp_output.first, tmp_output.second); std::string dtype = TypeId2String(type_id, dump_option_.is_before_select_kernel); output_desc_json[kJsonKeyTensorName] = - GetTensorName(node_json_map[tmp_output.first], kJsonKeyOutputDesc, std::make_pair(0, tmp_output.second)); + GetTensorName(node_json_map.at(tmp_output.first), kJsonKeyOutputDesc, std::make_pair(0, tmp_output.second)); output_desc_json[kJsonKeyDataType] = dtype; output_desc_json[kJsonKeyFormat] = this->GetOutputFormat(tmp_output.first, tmp_output.second); auto output_shape = this->GetOutputShape(tmp_output.first, tmp_output.second); @@ -587,33 +605,7 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector &anf } outputs_json.emplace_back(output_desc_json); } - (*kernel_json)[kJsonKeyOutputDesc] = outputs_json; - - auto processor = GetProcessorStr(anf_nodes[0]); - - size_t hash_id = std::hash()(kernel_json->dump()); - kernel_name_ = "Fused_"; - auto fg = anf_nodes[0]->func_graph(); - MS_EXCEPTION_IF_NULL(fg); - auto attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); - if (attr_val != nullptr) { - auto fg_attr = GetValue(attr_val); - (void)kernel_name_.append(fg_attr).append("_"); - } - (void)kernel_name_.append(std::to_string(hash_id)); - (*kernel_json)[kJsonKeyId] = GetOpCntInc(); - (*kernel_json)[kJsonKeyOp] = kernel_name_; - (*kernel_json)[kJsonKeyPlatform] = "AKG"; - (*kernel_json)[kJsonKeyProcess] = processor; - (*kernel_json)[kJsonKeyComposite] = true; - (*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString(); - - if (!GetIOSize(*kernel_json, &input_size_list_, &output_size_list_)) { - MS_LOG(ERROR) << "Cal mem size failed."; - return false; - } - - return true; + return outputs_json; } bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h index 95e6e13453..30ea40ea03 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h @@ -94,6 +94,14 @@ class AkgKernelJsonGenerator { nlohmann::json *const attrs_json); bool GetIOSize(const nlohmann::json &node_json, std::vector *const input_size, std::vector *const output_size); + bool GenSingleJsons(const std::vector &anf_nodes, std::map *node_json_map); + void UpdateTensorName(const std::vector &anf_nodes, std::map *node_json_map); + nlohmann::json CreateInputsJson(const std::vector &anf_nodes, const std::vector &input_list, + const std::map &node_json_map); + nlohmann::json CreateOutputsJson(const std::vector &anf_nodes, const std::vector &input_list, + const std::vector &output_list, const nlohmann::json &inputs_json, + const std::map &node_json_map); + int GetOpCntInc(); size_t GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx); size_t GetOutputTensorIdxInc(); diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc index 98c6ce869c..0491f10655 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc +++ 
b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc
@@ -36,15 +36,23 @@
 namespace mindspore {
 namespace kernel {
+namespace {
 constexpr int32_t PROCESS_NUM = 16;
 constexpr int32_t TIME_OUT = 300;
 
-bool AkgAscendKernelBuilder::AkgOpParallelBuild(
-  const std::vector<std::pair<AkgKernelJsonGenerator, AnfNodePtr>> &build_args) {
+void SetKernelMod(const KernelPackPtr &kernel_pack, const AkgKernelJsonGenerator &json_generator,
+                  const AnfNodePtr &anf_node) {
+  auto kernel_mod_ptr = std::make_shared(kernel_pack);
+  kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
+  kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
+  AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
+}
+}  // namespace
+
+std::vector<std::string> AkgAscendKernelBuilder::GetNotCachedKernelJsons(const std::vector<JsonNodePair> &build_args) {
   // Remove cached nodes, gether unique nodes, and collect repeated nodes which need postprecess.
   std::vector<std::string> jsons;
   std::unordered_set<std::string> kernel_name_set;
-  std::vector<std::pair<AkgKernelJsonGenerator, AnfNodePtr>> repeat_nodes;
   for (const auto &[json_generator, anf_node] : build_args) {
     MS_EXCEPTION_IF_NULL(anf_node);
     auto kernel_name = json_generator.kernel_name();
@@ -53,15 +61,12 @@ bool AkgAscendKernelBuilder::AkgOpParallelBuild(
     if (cached_kernel_pack != nullptr) {
       MS_LOG(DEBUG) << "Use cached kernel, kernel_name[" << kernel_name << "], fullname_with_scope["
                     << anf_node->fullname_with_scope() << "].";
-      auto kernel_mod_ptr = std::make_shared(cached_kernel_pack);
-      kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
-      kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
-      AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
+      SetKernelMod(cached_kernel_pack, json_generator, anf_node);
       continue;
     }
     if (kernel_name_set.count(kernel_name) != 0) {
-      repeat_nodes.push_back({json_generator, anf_node});
+      repeat_nodes_.push_back({json_generator, anf_node});
       continue;
     }
     kernel_name_set.insert(kernel_name);
@@ -69,7 +74,43 @@ bool AkgAscendKernelBuilder::AkgOpParallelBuild(
     kernel::SaveJsonInfo(kernel_name, kernel_json);
     jsons.push_back(kernel_json);
   }
+  return jsons;
+}
+
+bool AkgAscendKernelBuilder::InsertToCache(const std::vector<JsonNodePair> &build_args) {
+  for (const auto &[json_generator, anf_node] : build_args) {
+    auto kernel_name = json_generator.kernel_name();
+    auto new_kernel_pack = tbe::TbeUtils::InsertCache(kernel_name, GetProcessorStr(anf_node));
+    if (new_kernel_pack == nullptr) {
+      MS_LOG(ERROR) << "Insert to cache failed, kernel_name[" << kernel_name << "], fullname_with_scope["
+                    << anf_node->fullname_with_scope() << "].";
+      return false;
+    }
+    SetKernelMod(new_kernel_pack, json_generator, anf_node);
+    MS_LOG(DEBUG) << "Akg compile " << kernel_name << " kernel and insert cache successfully!";
+  }
+  return true;
+}
+
+bool AkgAscendKernelBuilder::HandleRepeatNodes() {
+  for (const auto &[json_generator, anf_node] : repeat_nodes_) {
+    auto kernel_name = json_generator.kernel_name();
+    auto cached_kernel_pack = tbe::TbeUtils::SearchCache(kernel_name, GetProcessorStr(anf_node));
+    if (cached_kernel_pack == nullptr) {
+      MS_LOG(ERROR) << "Use cached kernel failed, kernel_name[" << kernel_name << "], fullname_with_scope["
+                    << anf_node->fullname_with_scope() << "].";
+      return false;
+    }
+    MS_LOG(INFO) << "Use just compiled kernel, kernel_name[" << kernel_name << "], fullname_with_scope["
+                 << anf_node->fullname_with_scope() << "].";
+    SetKernelMod(cached_kernel_pack, json_generator, anf_node);
+  }
+  return true;
+}
+bool AkgAscendKernelBuilder::AkgOpParallelBuild(const std::vector<JsonNodePair> &build_args) {
+  repeat_nodes_.clear();
+  auto jsons = GetNotCachedKernelJsons(build_args);
   if (jsons.empty()) {
     return true;
   }
@@ -89,56 +130,35 @@ bool AkgAscendKernelBuilder::AkgOpParallelBuild(
   }
 
   // All unique done here, cache them and set kernel.
-  for (const auto &[json_generator, anf_node] : build_args) {
-    auto kernel_name = json_generator.kernel_name();
-    auto new_kernel_pack = tbe::TbeUtils::InsertCache(kernel_name, GetProcessorStr(anf_node));
-    if (new_kernel_pack == nullptr) {
-      MS_LOG(ERROR) << "Insert to cache failed, kernel_name[" << kernel_name << "], fullname_with_scope["
-                    << anf_node->fullname_with_scope() << "].";
-      return false;
-    }
-    auto kernel_mod_ptr = std::make_shared(new_kernel_pack);
-    kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
-    kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
-    AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
-    MS_LOG(DEBUG) << "Akg compile " << kernel_name << " kernel and insert cache successfully!";
+  if (!InsertToCache(build_args)) {
+    MS_LOG(ERROR) << "Insert cache failed.";
+    return false;
   }
-  // Handle repeated nodes.
-  for (const auto &[json_generator, anf_node] : repeat_nodes) {
-    auto kernel_name = json_generator.kernel_name();
-    auto cached_kernel_pack = tbe::TbeUtils::SearchCache(kernel_name, GetProcessorStr(anf_node));
-    if (cached_kernel_pack == nullptr) return false;
-    MS_LOG(INFO) << "Use just compiled kernel, kernel_name[" << kernel_name << "], fullname_with_scope["
-                 << anf_node->fullname_with_scope() << "].";
-    auto kernel_mod_ptr = std::make_shared(cached_kernel_pack);
-    kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
-    kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
-    AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
+  if (!HandleRepeatNodes()) {
+    MS_LOG(ERROR) << "Handle repeat nodes failed.";
+    return false;
   }
   return true;
 }
 
 bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
-  std::vector<std::pair<AkgKernelJsonGenerator, AnfNodePtr>> json_and_node;
+  std::vector<JsonNodePair> json_and_node;
   for (const auto &anf_node : anf_nodes) {
     MS_EXCEPTION_IF_NULL(anf_node);
     AkgKernelJsonGenerator akg_kernel_json_generator;
-    KernelPackPtr kernel_pack = nullptr;
     auto cnode = anf_node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
     if (AnfAlgo::IsGraphKernel(cnode)) {
       auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
+      MS_EXCEPTION_IF_NULL(func_graph);
       auto mng = func_graph->manager();
       if (mng == nullptr) {
         mng = Manage(func_graph, true);
         func_graph->set_manager(mng);
       }
-      MS_EXCEPTION_IF_NULL(func_graph);
-      std::vector<AnfNodePtr> node_list;
-      std::vector<AnfNodePtr> input_list;
-      std::vector<AnfNodePtr> output_list;
+      std::vector<AnfNodePtr> node_list, input_list, output_list;
       MS_LOG(INFO) << "Akg start compile composite op[" << anf_node->fullname_with_scope() << "]";
       GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
       if (!akg_kernel_json_generator.CollectFusedJson(node_list, input_list, output_list)) {
@@ -146,7 +166,7 @@ bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
       }
     } else {
       if (!akg_kernel_json_generator.CollectJson(anf_node)) {
-        MS_EXCEPTION(UnknownError) << "Akg build failed op[" << anf_node->fullname_with_scope() << "].";
+        MS_EXCEPTION(UnknownError) << "Akg build failed basic op[" << anf_node->fullname_with_scope() << "].";
       }
     }
     json_and_node.push_back({akg_kernel_json_generator, anf_node});
diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h
index c3d246e380..e250486623 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h
@@ -27,12 +27,20 @@
 namespace mindspore {
 namespace kernel {
+using JsonNodePair = std::pair<AkgKernelJsonGenerator, AnfNodePtr>;
+
 class AkgAscendKernelBuilder {
  public:
  AkgAscendKernelBuilder() = default;
  ~AkgAscendKernelBuilder() = default;
+  bool AkgOpParallelBuild(const std::vector<JsonNodePair> &build_args);
+
+ private:
+  std::vector<std::string> GetNotCachedKernelJsons(const std::vector<JsonNodePair> &build_args);
+  bool InsertToCache(const std::vector<JsonNodePair> &build_args);
+  bool HandleRepeatNodes();
-  bool AkgOpParallelBuild(const std::vector<std::pair<AkgKernelJsonGenerator, AnfNodePtr>> &build_args);
+  std::vector<JsonNodePair> repeat_nodes_;
 };
 
 bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes);
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.h
index b03933eabc..73613e92af 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.h
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.h
@@ -1,4 +1,3 @@
-
 /**
  * Copyright 2020 Huawei Technologies Co., Ltd
  *
@@ -32,7 +31,6 @@ class BasicOpsFusion : public Pass {
   bool Run(const FuncGraphPtr &func_graph) override;
 };
 using FuseBasicPtr = std::shared_ptr<BasicOpsFusion>;
-
 }  // namespace opt
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_BASIC_OPS_FUSION_H_
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
index e75402a10d..6bae92054a 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
@@ -128,7 +128,7 @@ FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) {
   MS_LOG(DEBUG) << "CallPyFn: [" << kGetGraphKernelOpExpander << "] with input json:\n" << node_desc_str;
   auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGetGraphKernelOpExpander, node_desc_str);
   // parse result.
-  if (ret.is(py::none())) {
+  if (py::isinstance<py::none>(ret)) {
     MS_LOG(ERROR) << "CallPyFn: [" << kGetGraphKernelOpExpander << "] return invalid result, input json:\n"
                   << node_desc_str;
     return nullptr;
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
index 89894fdc84..5189bb030e 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
@@ -211,9 +211,9 @@ AnfNodePtr DeleteAttrInInput(const FuncGraphPtr &func_graph, const CNodePtr &cno
   return new_cnode;
 }
 
-AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr *fg, FuncGraphManagerPtr *mng) {
+AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) {
   AnfNodePtrList outs;
-  auto out_node = (*fg)->output();
+  auto out_node = fg->output();
   if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) {
     std::vector<AnfNodePtr> output_args;
     auto out_cnode = out_node->cast<CNodePtr>();
@@ -228,8 +228,8 @@ AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr *fg, FuncGraphManagerPtr *m
     }
   }
   if (output_args.size() != out_cnode->inputs().size()) {
-    auto new_out = (*fg)->NewCNode(output_args);
-    (*mng)->Replace(out_node, new_out);
+    auto new_out = fg->NewCNode(output_args);
+    mng->Replace(out_node, new_out);
   }
 
   for (size_t i = 1; i < output_args.size(); ++i) {
@@ -241,6 +241,27 @@ AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr *fg, FuncGraphManagerPtr *m
   outs.push_back(out_node);
   return outs;
 }
+
+bool GenJson(const AnfNodePtrList &op_nodes, const AnfNodePtrList &inputs, const AnfNodePtrList &outputs,
+             const DumpOption &dump_option, nlohmann::json *op_desc,
+             std::map<std::string, AnfNodePtr> *address_node_map) {
+  kernel::AkgKernelJsonGenerator akg_kernel_json_generator(dump_option);
+  if (!akg_kernel_json_generator.CollectFusedJson(op_nodes, inputs, outputs)) {
+    MS_LOG(ERROR) << "Collect json desc failed.";
+    return false;
+  }
+
+  *op_desc = akg_kernel_json_generator.kernel_json();
+  if (address_node_map != nullptr) {
+    *address_node_map = akg_kernel_json_generator.address_node_map();
+  }
+  std::string fused_name;
+  std::for_each(op_nodes.begin(), op_nodes.end(), [&fused_name](const AnfNodePtr &node) {
+    (void)fused_name.append(AnfAlgo::GetCNodeName(node)).append("_");
+  });
+  MS_LOG(INFO) << "Collect fusion json: " << fused_name;
+  return true;
+}
 }  // namespace
 
 void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
@@ -457,7 +478,7 @@ void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
     mng->Replace(n, out);
   }
 
-  EliminateMakeTuple(&fg, &mng);
+  EliminateMakeTuple(fg, mng);
   // set graphKernel attr
   std::string fuse_op_name = "";
   for (auto &fuse_node : fuse_nodes) {
@@ -476,50 +497,26 @@ void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
   fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
 }
 
-bool AnfToJsonDesc(const AnfNodePtrList &nodes, DumpOption dump_option, nlohmann::json *op_desc,
+bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
                    std::map<std::string, AnfNodePtr> *address_node_map) {
   MS_EXCEPTION_IF_NULL(op_desc);
   if (nodes.empty()) {
     MS_LOG(ERROR) << "Input nodes is empty.";
     return false;
   }
-  bool has_graph_kernel =
-    std::any_of(nodes.begin(), nodes.end(), [](const AnfNodePtr &node) { return AnfAlgo::IsGraphKernel(node); });
+  bool has_graph_kernel = std::any_of(nodes.begin(), nodes.end(), AnfAlgo::IsGraphKernel);
   bool is_single_graph_kernel = has_graph_kernel && nodes.size() == 1;
 
-  auto gen_json = [&dump_option, &op_desc, &address_node_map](const AnfNodePtrList &op_nodes,
-                                                              const AnfNodePtrList &inputs,
-                                                              const AnfNodePtrList &outputs) -> bool {
-    kernel::AkgKernelJsonGenerator akg_kernel_json_generator(dump_option);
-    if (!akg_kernel_json_generator.CollectFusedJson(op_nodes, inputs, outputs)) {
-      MS_LOG(ERROR) << "Collect json desc failed.";
-      return false;
-    }
-
-    *op_desc = akg_kernel_json_generator.kernel_json();
-    if (address_node_map != nullptr) {
-      *address_node_map = akg_kernel_json_generator.address_node_map();
-    }
-    std::string fused_name;
-    std::for_each(op_nodes.begin(), op_nodes.end(), [&fused_name](const AnfNodePtr &node) {
-      (void)fused_name.append(AnfAlgo::GetCNodeName(node)).append("_");
-    });
-    MS_LOG(INFO) << "Collect fusion json: " << fused_name;
-    return true;
-  };
-
   FuncGraphPtr fg;
-  AnfNodePtrList op_nodes;
-  AnfNodePtrList inputs;
-  AnfNodePtrList outputs;
+  AnfNodePtrList op_nodes, inputs, outputs;
   if (is_single_graph_kernel) {
     fg = AnfAlgo::GetCNodeFuncGraphPtr(nodes[0]);
     kernel::GetValidKernelNodes(fg, &op_nodes, &inputs, &outputs);
-    return gen_json(op_nodes, inputs, outputs);
+    return GenJson(op_nodes, inputs, outputs, dump_option, op_desc, address_node_map);
   } else if (!has_graph_kernel) {
     std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(nodes);
     op_nodes = nodes;
-    return gen_json(op_nodes, inputs, outputs);
+    return GenJson(op_nodes, inputs, outputs, dump_option, op_desc, address_node_map);
   }
 
   std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(nodes);
@@ -540,10 +537,10 @@ bool AnfToJsonDesc(const AnfNodePtrList &nodes, DumpOption dump_option, nlohmann
   inputs.clear();
   outputs.clear();
   kernel::GetValidKernelNodes(fg, &op_nodes, &inputs, &outputs);
-  return gen_json(op_nodes, inputs, outputs);
+  return GenJson(op_nodes, inputs, outputs, dump_option, op_desc, address_node_map);
 }
 
-bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, DumpOption dump_option, nlohmann::json *op_desc) {
+bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, const DumpOption &dump_option, nlohmann::json *op_desc) {
   MS_EXCEPTION_IF_NULL(op_desc);
   std::vector<nlohmann::json> graphs_desc;
   for (auto const &graph_nodes : graphs) {
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h
index 0af501838d..1bd74664fc 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h
@@ -46,9 +46,9 @@ void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new
 void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
                          const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix,
                          bool is_before_kernel_select);
-bool AnfToJsonDesc(const AnfNodePtrList &nodes, DumpOption dump_option, nlohmann::json *op_desc,
+bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
                    std::map<std::string, AnfNodePtr> *address_node_map = nullptr);
-bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, DumpOption dump_option, nlohmann::json *op_desc);
+bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, const DumpOption &dump_option, nlohmann::json *op_desc);
 FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector &inputs);
 bool JsonDescToAnf(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map,
                    std::vector<FuncGraphPtr> *res_graphs);
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc
index 9191602048..ac000d0739 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc
@@ -57,8 +57,6 @@ inline void TraverseFuncGraph(const FuncGraphPtr &root, std::function
   TraverseFuncGraph(root->get_return(), callback);
 }
 
-class AreaGraph;
-class Splitter;
 class Area {
  public:
   explicit Area(const AnfNodePtrList &anf_arr) {
@@ -73,6 +71,8 @@ class Area {
     }
   }
 
+  ~Area() = default;
+
   // Set the external inputs of spy as a Parameter.
   void CreateParameters(const FuncGraphPtr &func_graph, std::unordered_map *param_node_map) {
     std::unordered_map node_param_map;
@@ -148,8 +148,8 @@ class Area {
     }
   }
 
-  friend AreaGraph;
-  friend Splitter;
+  const std::unordered_set<AnfNodePtr> &nodes() const { return nodes_; }
+  const std::vector<AnfNodePtr> &spy_cnodes() const { return spy_cnodes_; }
 
  private:
  // This is a CNode that does not belong to this area.
@@ -170,9 +170,8 @@ class AreaGraph {
  // Build an area graph to maintain the relation between areas.
  // Input node_groups: A group list, each element is a AnfNode list representing the node set in this group.
  static AreaGraphPtr BuildAreaGraph(const std::vector<AnfNodePtrList> &node_groups) {
-    AreaGraph *area_graph_ptr = new (std::nothrow) AreaGraph(node_groups);
-    if (!area_graph_ptr) return nullptr;
-    auto area_graph = AreaGraphPtr(area_graph_ptr);
+    auto area_graph = AreaGraphPtr(new AreaGraph(node_groups));
+    if (area_graph == nullptr) return nullptr;
     if (!area_graph->TopoSort()) {
       MS_LOG(WARNING) << "The groups have a cycle.";
       return nullptr;
@@ -184,12 +183,12 @@ class AreaGraph {
   // The output `main_cnodes` is a topo-sorted cnode list in main graph, holding the new sub_func_graphs.
   // The output `cnode_group_id` represents the indices of main_cnodes before topo-sorting.
   void SplitGraph(const FuncGraphPtr &main_func_graph, std::vector<CNodePtr> *main_cnodes,
-                  std::vector<size_t> *cnode_group_id, std::function<void(Area *)> expand_callback) {
+                  std::vector<size_t> *cnode_group_id, std::function<void(const Area &)> expand_callback) {
     main_cnodes->clear();
     main_cnodes->resize(areas_.size(), nullptr);
 
     for (auto &area : this->areas_) {
-      expand_callback(&area);
+      expand_callback(area);
     }
 
     for (auto index : topo_order_) {
@@ -208,6 +207,8 @@ class AreaGraph {
     return;
   }
 
+  ~AreaGraph() = default;
+
  private:
   explicit AreaGraph(const std::vector<AnfNodePtrList> &node_groups) : edge_prev_(node_groups.size()) {
     for (size_t i = 0; i < node_groups.size(); ++i) {
@@ -217,7 +218,7 @@ class AreaGraph {
       }
     }
     for (auto &area : areas_) {
-      for (auto &spy : area.spy_cnodes_) {
+      for (auto &spy : area.spy_cnodes()) {
         auto cnode = spy->cast<CNodePtr>();
         MS_EXCEPTION_IF_NULL(cnode);
         size_t v = node_area_map_[spy];
@@ -333,8 +334,8 @@ class Splitter {
     // The output new_subgraph_cnodes are topo sorted, use a list to store its order in split_plan.
     std::vector<size_t> cnodes_group_id;
-    std::function<void(Area *)> expand_callback = std::bind(&Splitter::AreaExpand, this, std::placeholders::_1);
-    area_graph->SplitGraph(main_func_graph_, &new_subgraph_cnodes_, &cnodes_group_id, expand_callback);
+    area_graph->SplitGraph(main_func_graph_, &new_subgraph_cnodes_, &cnodes_group_id,
+                           [this](const Area &area) { this->AreaExpand(area); });
 
     RebuildGraph(cnodes_group_id);
 
@@ -348,6 +349,8 @@ class Splitter {
     return SplitterPtr(new Splitter(main_cnode, split_schemer));
   }
 
+  ~Splitter() = default;
+
  private:
   Splitter(const CNodePtr &main_cnode, SplitSchemerPtr split_schemer)
      : main_func_graph_(main_cnode->func_graph()), old_subgraph_cnode_(main_cnode), split_schemer_(split_schemer) {}
@@ -479,9 +482,9 @@ class Splitter {
   }
 
   // Copy all Parameter and ValueNode that the area used.
-  void AreaExpand(Area *area) {
+  void AreaExpand(const Area &area) {
     std::unordered_map old_valuenode_and_param_map;
-    for (auto sub_node : area->nodes_) {
+    for (auto sub_node : area.nodes()) {
       auto sub_cnode = sub_node->cast<CNodePtr>();
       if (sub_cnode == nullptr) continue;
       for (size_t i = 1; i < sub_cnode->inputs().size(); ++i) {
@@ -565,7 +568,7 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer {
     auto json_desc_str = json_desc.dump();
     MS_LOG(DEBUG) << "CallPyFn: [" << kGraphKernelSplitFunc << "] with input json:\n" << json_desc_str;
     auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGraphKernelSplitFunc, json_desc_str);
-    if (ret.is(py::none())) {
+    if (py::isinstance<py::none>(ret)) {
       MS_LOG(ERROR) << "CallPyFn: [" << kGraphKernelSplitFunc << "] return invalid result. input json:\n"
                     << json_desc_str;
       return false;
diff --git a/mindspore/ccsrc/debug/anf_ir_dump.cc b/mindspore/ccsrc/debug/anf_ir_dump.cc
index f316a1c61f..9c4f51ad76 100644
--- a/mindspore/ccsrc/debug/anf_ir_dump.cc
+++ b/mindspore/ccsrc/debug/anf_ir_dump.cc
@@ -463,13 +463,15 @@ std::string AddGlobalId(const std::string &filename) {
   static size_t g_id = 0;
   std::ostringstream s;
   auto i = filename.rfind('/');
-  if (i == string::npos) {
+  if (i >= filename.size()) {
     s << std::setfill('0') << std::setw(4) << g_id << "_";
     s << filename;
   } else {
     s << filename.substr(0, i + 1);
     s << std::setfill('0') << std::setw(4) << g_id << "_";
-    s << filename.substr(i + 1);
+    if (i + 1 < filename.size()) {
+      s << filename.substr(i + 1);
+    }
   }
   ++g_id;
   return s.str();
 }
diff --git a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc
index c4f67608d3..81832f64fd 100644
--- a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc
+++ b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc
@@ -236,12 +236,7 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector
-  std::vector<AnfNodePtr> node_list;
-  std::vector<AnfNodePtr> input_list;
-  std::vector<AnfNodePtr> output_list;
+  std::vector<AnfNodePtr> node_list, input_list, output_list;
   kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
 
   std::vector<std::string> graph_input_format;
@@ -295,6 +290,22 @@ void SetGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_gr
   AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, kernel_node.get());
   SetTensorDeviceInfo(*graph_selected_info, kernel_node);
 }
+
+void PrintUnsupportedTypeException(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type,
+                                   const std::vector<TypeId> &outputs_type) {
+  auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
+  std::string build_type = "in [";
+  std::for_each(std::begin(inputs_type), std::end(inputs_type),
+                [&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; });
+  build_type += "] out [";
+  std::for_each(std::begin(outputs_type), std::end(outputs_type),
+                [&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; });
+  build_type += "]";
+  auto supported_type_lists = SupportedTypeList(kernel_node);
+  MS_EXCEPTION(TypeError) << "Select GPU kernel op[" << kernel_name
+                          << "] fail! Incompatible data type!\nThe supported data types are " << supported_type_lists
+                          << ", but get " << build_type;
+}
 }  // namespace
 
 void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
@@ -329,7 +340,7 @@ void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
   if (AnfAlgo::IsGraphKernel(kernel_node)) {
-    auto func_graph = GetValueNode<FuncGraphPtr>(kernel_node->input(kAnfPrimitiveIndex));
+    auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(kernel_node);
     MS_EXCEPTION_IF_NULL(func_graph);
     SetGraphKernelInfo(kernel_node, func_graph);
     return;
@@ -351,8 +362,7 @@ void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
   if (IsNeedProcessFormatInfo(kernel_node, inputs_type)) {
     UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format);
   }
-  std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
-    std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
+  auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
   builder->SetOriginDataFormat(origin_data_format);
   builder->SetInputsFormat(inputs_format);
   builder->SetInputsDeviceType(inputs_type);
@@ -360,35 +370,23 @@ void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
   builder->SetOutputsDeviceType(outputs_type);
 
   bool result = false;
-  KernelType res_kernel_type = UNKNOWN_KERNEL_TYPE;
   if (kernel_type == UNKNOWN_KERNEL_TYPE) {
     result = kernel::GpuKernelFactory::GetInstance().SearchRegistered(AnfAlgo::GetCNodeName(kernel_node),
                                                                       builder->Build());
     if (!result) {
       result = SelectAkgKernel(kernel_node, builder->Build());
-      res_kernel_type = AKG_KERNEL;
+      kernel_type = AKG_KERNEL;
     }
   } else if (kernel_type == AKG_KERNEL) {
     result = SelectAkgKernel(kernel_node, builder->Build());
-    res_kernel_type = AKG_KERNEL;
   }
 
   if (!result) {
-    auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
-    std::string build_type = "in [";
-    std::for_each(std::begin(inputs_type), std::end(inputs_type),
-                  [&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; });
-    build_type += "] out [";
-    std::for_each(std::begin(outputs_type), std::end(outputs_type),
-                  [&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; });
-    build_type += "]";
-    auto supported_type_lists = SupportedTypeList(kernel_node);
-    MS_EXCEPTION(TypeError) << "Select GPU kernel op[" << kernel_name
-                            << "] fail! Incompatible data type!\nThe supported data types are " << supported_type_lists
-                            << ", but get " << build_type;
-  }
-  builder->SetKernelType(res_kernel_type);
+    PrintUnsupportedTypeException(kernel_node, inputs_type, outputs_type);
+    return;
+  }
+  builder->SetKernelType(kernel_type);
   builder->SetProcessor(kernel::Processor::CUDA);
   AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
   SetTensorDeviceInfo(*(builder->Build()), kernel_node);