diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc index 4563512c1d..15401a3f27 100644 --- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc @@ -425,7 +425,7 @@ std::shared_ptr ChooseMatchedKernelInfo( return kernel_info_list[selected_index]; } -std::vector> GetAllMatchedFilteredKernelInfo( +std::vector> FilteredKernelInfoByDtype( const CNodePtr &cnode, const std::vector> &kernel_info_list) { std::vector> result; for (const auto &kernel_build_info : kernel_info_list) { @@ -474,7 +474,7 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node, std::shared_ptr selected_kernel_info = nullptr; // Matched kernel info // Filter kernel info matched with me infered type - auto filtered_kernel_info_list = GetAllMatchedFilteredKernelInfo(kernel_node, kernel_info_list); + auto filtered_kernel_info_list = FilteredKernelInfoByDtype(kernel_node, kernel_info_list); if (!filtered_kernel_info_list.empty()) { selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list); select_status = kStatusAllMatched; @@ -508,6 +508,7 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) { << "] cannot find valid TBE kernel info, try to get aicpu kernel info"; kernel::AICpuQuery(kernel_node, &kernel_info_list); select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list); + AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node); } // The kernel info not finded both in the aicpu kernel list & aicore kernel list if (select_status == kNoMatched) { diff --git a/mindspore/ccsrc/kernel/kernel.h b/mindspore/ccsrc/kernel/kernel.h index 80d831269c..4217b56625 100644 --- a/mindspore/ccsrc/kernel/kernel.h +++ b/mindspore/ccsrc/kernel/kernel.h @@ -47,6 +47,13 @@ enum FusionType { OPAQUE, UNKNOWN_FUSION_TYPE = -1, }; +enum OpPattern { + kCommonPattern = 0, + kFormatAgnosticPattern = 1, + 
kBroadcastPattern = 2, + kReducePattern = 3, + kDynamicFormatPattern = 4, +}; // Backend processor enum Processor { diff --git a/mindspore/ccsrc/kernel/kernel_build_info.cc b/mindspore/ccsrc/kernel/kernel_build_info.cc index df855f5340..9c0272dd7a 100644 --- a/mindspore/ccsrc/kernel/kernel_build_info.cc +++ b/mindspore/ccsrc/kernel/kernel_build_info.cc @@ -162,5 +162,10 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType( MS_EXCEPTION_IF_NULL(kernel_build_info_); kernel_build_info_->output_reshape_type_ = output_reshape_type; } + +void KernelBuildInfo::KernelBuildInfoBuilder::SetOpPattern(OpPattern pattern) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + kernel_build_info_->op_pattern_ = pattern; +} } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/kernel_build_info.h b/mindspore/ccsrc/kernel/kernel_build_info.h index 779be057f6..d17b41a6fc 100644 --- a/mindspore/ccsrc/kernel/kernel_build_info.h +++ b/mindspore/ccsrc/kernel/kernel_build_info.h @@ -34,6 +34,7 @@ class KernelBuildInfo { kernel_type_ = AUTO_DIFF_KERNEL; fusion_type_ = OPAQUE; processor_ = AICORE; + op_pattern_ = kCommonPattern; input_reshape_type_ = {}; output_reshape_type_ = {}; inputs_format_ = {}; @@ -70,6 +71,8 @@ class KernelBuildInfo { std::vector GetAllOutputDeviceTypes() const; + OpPattern op_pattern() const { return op_pattern_; } + FusionType fusion_type() const { return fusion_type_; } Processor processor() const { return processor_; } @@ -88,6 +91,7 @@ class KernelBuildInfo { private: KernelType kernel_type_; std::vector inputs_format_; + OpPattern op_pattern_; std::vector outputs_format_; std::vector> input_reshape_type_; std::vector> output_reshape_type_; @@ -125,6 +129,8 @@ class KernelBuildInfo::KernelBuildInfoBuilder { void SetProcessor(Processor processor); + void SetOpPattern(OpPattern pattern); + std::shared_ptr Build(); private: diff --git a/mindspore/ccsrc/kernel/kernel_query.cc b/mindspore/ccsrc/kernel/kernel_query.cc index 
8cdf91fd9f..c1bfafc384 100755 --- a/mindspore/ccsrc/kernel/kernel_query.cc +++ b/mindspore/ccsrc/kernel/kernel_query.cc @@ -40,7 +40,7 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node, (void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list)); } else { MS_LOG(WARNING) << "All kernel Info list does not match any kernel info "; - for (size_t index; index < kernel_info_list->size(); ++index) { + for (size_t index = 0; index < kernel_info_list->size(); ++index) { MS_EXCEPTION_IF_NULL(kernel_info_list->at(index)); MS_LOG(WARNING) << "kernel [ " << index << " ] :" << kernel_info_list->at(index)->ToString(); } diff --git a/mindspore/ccsrc/kernel/oplib/opinfo.h b/mindspore/ccsrc/kernel/oplib/opinfo.h index 670830a8b1..8d7b543ea6 100644 --- a/mindspore/ccsrc/kernel/oplib/opinfo.h +++ b/mindspore/ccsrc/kernel/oplib/opinfo.h @@ -21,6 +21,7 @@ #include #include #include "ir/dtype.h" +#include "kernel/kernel.h" namespace mindspore { namespace kernel { @@ -100,7 +101,7 @@ class OpInfo { std::string kernel_name() const { return kernel_name_; } bool partial_flag() const { return partial_flag_; } bool dynamic_format() const { return dynamic_format_; } - std::string op_pattern() const { return op_pattern_; } + OpPattern op_pattern() const { return op_pattern_; } std::vector> attrs_ptr() const { return attrs_ptr_; } std::vector> inputs_ptr() const { return inputs_ptr_; } std::vector> outputs_ptr() const { return outputs_ptr_; } @@ -116,7 +117,7 @@ class OpInfo { void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; } void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; } void set_dynamic_format(const bool dynamic_format) { dynamic_format_ = dynamic_format; } - void set_op_pattern(const std::string op_pattern) { op_pattern_ = op_pattern; } + void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; } void add_attrs_ptr(const std::shared_ptr &attr) { 
attrs_ptr_.push_back(attr); } void add_inputs_ptr(const std::shared_ptr &input) { inputs_ptr_.push_back(input); } void add_outputs_ptr(const std::shared_ptr &output) { outputs_ptr_.push_back(output); } @@ -137,7 +138,7 @@ class OpInfo { std::string kernel_name_; bool partial_flag_ = false; bool dynamic_format_ = false; - std::string op_pattern_; + OpPattern op_pattern_ = kCommonPattern; std::vector> attrs_ptr_; std::vector> inputs_ptr_; std::vector> outputs_ptr_; diff --git a/mindspore/ccsrc/kernel/oplib/oplib.cc b/mindspore/ccsrc/kernel/oplib/oplib.cc index f5f2e1601b..b1bff36518 100644 --- a/mindspore/ccsrc/kernel/oplib/oplib.cc +++ b/mindspore/ccsrc/kernel/oplib/oplib.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include "utils/log_adapter.h" #include "utils/overload.h" #include "utils/context/ms_context.h" @@ -35,6 +36,9 @@ constexpr auto kPartialFlag = "partial_flag"; constexpr auto kReshapeType = "reshape_type"; constexpr auto kOpPattern = "op_pattern"; constexpr auto kDynamicFormat = "dynamic_format"; +constexpr auto kFormatAgnostic = "formatAgnostic"; +constexpr auto kBroadcast = "broadcast"; +constexpr auto kReduce = "reduce"; constexpr auto kDtypeFormat = "dtype_format"; constexpr auto kAttr = "attr"; constexpr auto kIputs = "inputs"; @@ -95,13 +99,21 @@ bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path) } void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr &op_info) { + // Lookup table from the json "op_pattern" string to the matching OpPattern enum value. + const std::map kOpPatternMap = {{kFormatAgnostic, kFormatAgnosticPattern}, + {kBroadcast, kBroadcastPattern}, + {kReduce, kReducePattern}, + {kDynamicFormat, kDynamicFormatPattern}}; op_info->set_async_flag(obj.at(kAsyncFlag)); op_info->set_binfile_name(obj.at(kBinfileName)); op_info->set_compute_cost(obj.at(kComputeCost)); op_info->set_kernel_name(obj.at(kKernelName)); op_info->set_partial_flag(obj.at(kPartialFlag)); if (obj.find(kOpPattern) != obj.end()) { - op_info->set_op_pattern(obj.at(kOpPattern)); + if 
(kOpPatternMap.find(obj.at(kOpPattern)) != kOpPatternMap.end()) { + // Store the enum value mapped from the json string; unknown strings keep the default pattern. + op_info->set_op_pattern(kOpPatternMap.at(obj.at(kOpPattern))); + } } if (obj.find(kDynamicFormat) != obj.end()) { op_info->set_dynamic_format(obj.at(kDynamicFormat)); diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc index 33743b3175..8b1a1548bc 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc @@ -492,6 +492,7 @@ void SetKernelBuildCommonInfo(const std::shared_ptrSetFusionType(tbe::GetFusionType(fusion_type)); } + builder->SetOpPattern(op_info_ptr->op_pattern()); builder->SetKernelType(TBE_KERNEL); } @@ -509,7 +510,7 @@ bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptrGetAttr("dyn_input_sizes") != nullptr) { dyn_input_sizes = GetValue>(primitive->GetAttr("dyn_input_sizes")); } - if (inputs.size() > 0) { + if (!inputs.empty()) { MS_EXCEPTION_IF_NULL(inputs[0]); size_t kernel_info_cnt = inputs[0]->dtypes().size(); for (size_t j = 0; j < kernel_info_cnt; j++) { @@ -624,21 +625,17 @@ void TbeMetadataInfo(const CNodePtr &kernel_node, std::vectorexecution_mode() == kPynativeMode) { - kernel_info_list->push_back(parse_info); - } else { - if (IsValidKernelInfo(kernel_node, *(parse_info))) { - if (CheckSupported(kernel_node, parse_info)) { - kernel_info_list->push_back(parse_info); - } else { - MS_LOG(INFO) << "CheckSupported Failed for TBE op" << op_name << " kernel info."; - } + for (const auto &parse_info : parse_info_list) { + if (IsValidKernelInfo(kernel_node, *(parse_info))) { + if (CheckSupported(kernel_node, parse_info)) { + kernel_info_list->push_back(parse_info); + } else { + MS_LOG(INFO) << "CheckSupported Failed for TBE op" << op_name << " kernel info."; } } - } - if (kernel_info_list->empty()) { - MS_LOG(DEBUG) << "Tbe dose not have op [" << op_name << "]."; + if (kernel_info_list->empty()) { + MS_LOG(DEBUG) << "Tbe dose not have op [" << op_name << "]."; + } } } } // 
namespace kernel diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc index 120462fd53..5b5bf7e4fc 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc @@ -44,6 +44,7 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph auto builder = std::make_shared(kernel_builder_info); builder->SetKernelType(AICPU_KERNEL); AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); + AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), node); } else { MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node [" << node->DebugString() << "]"; diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc index f141ebce6b..8b45ad7d32 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc @@ -653,6 +653,16 @@ void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_ to_node->set_abstract(from_node->abstract()); } +kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = node->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + // select_kernel_build_info() has checked whether return pointer is null + auto build_info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(build_info); + return build_info->op_pattern(); +} + // get KernelBuildType of node, such as ATT,RT,FWK and so on KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.h b/mindspore/ccsrc/session/anf_runtime_algorithm.h 
index 76b57482dd..be88075f4b 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.h @@ -138,6 +138,8 @@ class AnfRuntimeAlgorithm { static void SetOutputInferTypeAndShape(const std::vector &types, const std::vector> &shapes, AnfNode *node); static void CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node); + // get op pattern of the node + static kernel::OpPattern GetOpPattern(const AnfNodePtr &node); // get KernelBuildType of node ,such as ATT,RT,FWK and so on static KernelType GetKernelType(const AnfNodePtr &node); // get processor type:AICORE,AICPU... diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 409c4b726d..5b8a8b178e 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -142,6 +142,7 @@ constexpr auto kLabelGotoOpName = "LabelGoto"; // attr key name constexpr auto kAttrInputNames = "input_names"; +constexpr auto kAttrIsAICPUKernel = "is_ai_cpu_kernel"; constexpr auto kIsBackendCast = "is_backed_cast"; constexpr auto kAttrOutputNames = "output_names"; constexpr auto kAttrVisited = "visited"; @@ -215,10 +216,11 @@ constexpr auto kOpFormat_FRAC_NZ = "FRACTAL_NZ"; constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0"; constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04"; constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04"; -const std::set kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, - kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN, - kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, - kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04}; +constexpr auto kOpFormat_NDHWC = "NDHWC"; +const std::set kOpFormatList = { + kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, + kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ, + kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDHWC}; const std::set 
kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN}; const std::set kOptOperatorSet = { kMomentumOpName, kApplyMomentumOpName, kApplyAdadeltaOpName,