From d96044fbd9fbe13d1d77b4c2e8bf8b917cc7be55 Mon Sep 17 00:00:00 2001
From: yujianfeng
Date: Thu, 24 Sep 2020 16:27:45 +0800
Subject: [PATCH] Add max match rule for cpu kernel selection

---
Review notes (below the '---' marker, so not part of the commit message):

* What the patch does: IsInputFormatDtypeMatched (all-or-nothing bool) is
  replaced by GetInputDtypeFormatMatchedNum, which returns the pair
  (matched dtype count, matched format count) for one KernelAttr.
  SetKernelInfo now keeps the candidate with the most dtype matches, using
  the format match count as tie-breaker, and stops early on a full match.
  The KernelBuildInfo construction is moved into SetKernelBuildInfo.
* This copy of the patch had its line breaks collapsed and every '<...>'
  span removed. Line breaks were restored so that each hunk matches its
  '@@' line counts. Template arguments (std::pair<int, int>,
  std::vector<std::string>, std::vector<TypeId>, std::vector<size_t>,
  std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>) were
  re-inserted by inference - TODO confirm against the upstream commit.
  The author e-mail in 'From:' is missing and was not reconstructed.
* NOTE(review): GetInputDtypeFormatMatchedNum returns (0, 0) when the input
  count differs, which still beats the initial -1 in SetKernelInfo, so such
  an attr can become selected_kernel_attr. When every input is a non-CNode
  it is then indexed with GetInputAttr(input_index); the same branch is also
  reachable with a default-constructed selected_kernel_attr if kernel_attrs
  is empty. Confirm that both cases are safe.

 .../runtime/device/cpu/kernel_select_cpu.cc   | 92 ++++++++++++-------
 1 file changed, 61 insertions(+), 31 deletions(-)

diff --git a/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc b/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc
index b9496318dc..192c4ba51d 100644
--- a/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc
+++ b/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc
@@ -78,33 +78,40 @@ void GetOutputFormatsAndDtypes(const CNodePtr &kernel_node, const KernelAttr &ke
   }
 }
 
-bool IsInputFormatDtypeMatched(const KernelAttr &kernel_attr, const std::vector<std::string> &input_formats,
-                               const std::vector<TypeId> &input_types,
-                               const std::vector<size_t> &input_not_cnode_indexes) {
+std::pair<int, int> GetInputDtypeFormatMatchedNum(const KernelAttr &kernel_attr,
+                                                  const std::vector<std::string> &input_formats,
+                                                  const std::vector<TypeId> &input_types,
+                                                  const std::vector<size_t> &input_not_cnode_indexes) {
   if (kernel_attr.GetInputSize() != input_types.size()) {
     MS_LOG(DEBUG) << "required input num:" << kernel_attr.GetInputSize() << ", actual input num:" << input_types.size();
-    return false;
+    return std::make_pair(0, 0);
   }
+  int data_type_matched_num = 0;
+  int format_matched_num = 0;
   auto input_num = input_types.size();
   for (size_t i = 0; i < input_num; ++i) {
     bool is_not_cnode_idx = std::any_of(input_not_cnode_indexes.begin(), input_not_cnode_indexes.end(),
                                         [i](size_t index) { return index == i; });
     bool have_cnode_input = (input_types.size() != input_not_cnode_indexes.size());
     if (have_cnode_input && is_not_cnode_idx) {
+      data_type_matched_num++;
+      format_matched_num++;
       continue;
     }
     if (kernel_attr.GetInputAttr(i).first != input_types[i]) {
       MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetInputAttr(i).first
                     << ", actual input dtype:" << input_types[i];
-      return false;
+    } else {
+      data_type_matched_num++;
     }
     if (kernel_attr.GetInputAttr(i).second != input_formats[i]) {
       MS_LOG(DEBUG) << "required format:" << kernel_attr.GetInputAttr(i).second
                     << ", actual input format:" << input_formats[i];
-      return false;
+    } else {
+      format_matched_num++;
     }
   }
-  return true;
+  return std::make_pair(data_type_matched_num, format_matched_num);
 }
 
 void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) {
@@ -121,6 +128,18 @@ void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) {
     kernel_attr->AddOutputAttr(output_dtype);
   }
 }
+
+void SetKernelBuildInfo(const std::vector<std::string> &input_formats, const std::vector<TypeId> &input_types,
+                        const std::vector<std::string> &output_formats, const std::vector<TypeId> &output_types,
+                        AnfNode *kernel_node) {
+  auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
+  MS_EXCEPTION_IF_NULL(builder);
+  builder->SetInputsFormat(input_formats);
+  builder->SetInputsDeviceType(input_types);
+  builder->SetOutputsFormat(output_formats);
+  builder->SetOutputsDeviceType(output_types);
+  AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node);
+}
 }  // namespace
 
 void SetKernelInfo(const CNodePtr &kernel_node) {
@@ -136,38 +155,49 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
 
   auto kernel_attrs =
     kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node));
-  for (size_t index = 0; index < kernel_attrs.size(); ++index) {
-    auto kernel_attr = kernel_attrs[index];
+  int max_type_matched_num = -1;
+  int max_format_matched_num = -1;
+  KernelAttr selected_kernel_attr;
+  for (auto kernel_attr : kernel_attrs) {
     if (kernel_attr.GetAllSame()) {
       ExpandKernelAttr(kernel_node, &kernel_attr);
     }
-    bool ignore_check = false;
-    if (index == kernel_attrs.size() - 1 && input_types.size() == input_not_cnode_indexes.size()) {
-      ignore_check = true;
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (kernel_attr.GetOutputSize() != output_num) {
+      MS_LOG(DEBUG) << "Output num is not equal!";
+      continue;
+    }
+    std::pair<int, int> input_type_format_matched_num =
+      GetInputDtypeFormatMatchedNum(kernel_attr, input_formats, input_types, input_not_cnode_indexes);
+    // Data type first
+    if (input_type_format_matched_num.first > max_type_matched_num) {
+      max_type_matched_num = input_type_format_matched_num.first;
+      max_format_matched_num = input_type_format_matched_num.second;
+      selected_kernel_attr = kernel_attr;
+    } else if (input_type_format_matched_num.first == max_type_matched_num &&
+               input_type_format_matched_num.second > max_format_matched_num) {
+      max_format_matched_num = input_type_format_matched_num.second;
+      selected_kernel_attr = kernel_attr;
     }
-    if (ignore_check || IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) {
-      size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
-      if (kernel_attr.GetOutputSize() != output_num) {
-        MS_LOG(DEBUG) << "Output num is not equal!";
-        continue;
-      }
-      MS_LOG(INFO) << "Input format and dtype is matched, index: " << index;
-      GetOutputFormatsAndDtypes(kernel_node, kernel_attr, &output_formats, &output_types);
-      UpdatePrevNotCNodeFormatDtype(kernel_attr, input_not_cnode_indexes, kernel_node);
-      for (auto &input_index : input_not_cnode_indexes) {
-        input_types[input_index] = kernel_attr.GetInputAttr(input_index).first;
-      }
+    // All formats and data types matched
+    if (max_type_matched_num == SizeToInt(input_types.size()) &&
+        max_format_matched_num == SizeToInt(input_types.size())) {
       break;
     }
   }
 
-  auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
-  MS_EXCEPTION_IF_NULL(builder);
-  builder->SetInputsFormat(input_formats);
-  builder->SetInputsDeviceType(input_types);
-  builder->SetOutputsFormat(output_formats);
-  builder->SetOutputsDeviceType(output_types);
-  AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
+  if ((max_type_matched_num == SizeToInt(input_types.size()) &&
+       max_format_matched_num == SizeToInt(input_types.size())) ||
+      input_types.size() == input_not_cnode_indexes.size()) {
+    MS_LOG(INFO) << "Input format and dtype is matched, max_type_matched_num: " << max_type_matched_num
+                 << ", max_format_matched_num: " << max_format_matched_num;
+    GetOutputFormatsAndDtypes(kernel_node, selected_kernel_attr, &output_formats, &output_types);
+    UpdatePrevNotCNodeFormatDtype(selected_kernel_attr, input_not_cnode_indexes, kernel_node);
+    for (auto &input_index : input_not_cnode_indexes) {
+      input_types[input_index] = selected_kernel_attr.GetInputAttr(input_index).first;
+    }
+  }
+  SetKernelBuildInfo(input_formats, input_types, output_formats, output_types, kernel_node.get());
 }
 }  // namespace cpu
 }  // namespace device