From bcef06b41b12d7129a0f4a12ac81e2352eec508a Mon Sep 17 00:00:00 2001 From: WangChengke Date: Tue, 23 Jun 2020 11:54:11 +0800 Subject: [PATCH] modify code review --- .../ascend/kernel_select_graph_kernel.cc | 111 ++++++++++-------- 1 file changed, 63 insertions(+), 48 deletions(-) diff --git a/mindspore/ccsrc/device/ascend/kernel_select_graph_kernel.cc b/mindspore/ccsrc/device/ascend/kernel_select_graph_kernel.cc index b57ed1cd1b..db31460d31 100644 --- a/mindspore/ccsrc/device/ascend/kernel_select_graph_kernel.cc +++ b/mindspore/ccsrc/device/ascend/kernel_select_graph_kernel.cc @@ -25,6 +25,18 @@ namespace mindspore { namespace device { namespace ascend { +namespace { +// sort format according the number of occurrences. +bool cmp_format_num(const std::pair &a, const std::pair &b) { + if (a.second != b.second) { + return a.second > b.second; + } else if (a.first == kOpFormat_DEFAULT) { + return a.second + 1 > b.second; + } else if (b.first == kOpFormat_DEFAULT) { + return a.second > b.second + 1; + } + return a.second > b.second; +} TypeId GetPrimitivePrecision(const CNodePtr &cnode) { auto primitive = AnfAlgo::GetCNodePrimitive(cnode); @@ -44,6 +56,7 @@ TypeId GetPrimitivePrecision(const CNodePtr &cnode) { return except_type; } +} // namespace void ResetKernelBuildInfo(const CNodePtr &kernel_node) { size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); @@ -185,15 +198,12 @@ void GetDefaultFormat(const CNodePtr &kernel_node, std::string *default_format, auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first; MS_EXCEPTION_IF_NULL(input_kernel_node); if (!input_kernel_node->isa()) { - auto pre_format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i); - ++all_input_formats[pre_format]; + ++all_input_formats[AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i)]; continue; } auto para = input_kernel_node->cast(); - MS_EXCEPTION_IF_NULL(para); if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) { - auto pre_format = AnfAlgo::GetOutputFormat(para, 0); - ++all_input_formats[pre_format]; + ++all_input_formats[AnfAlgo::GetOutputFormat(para, 0)]; continue; } *use_same_format = false; @@ -207,17 +217,8 @@ void GetDefaultFormat(const CNodePtr &kernel_node, std::string *default_format, for (auto iter = all_input_formats.begin(); iter != all_input_formats.end(); ++iter) { pairs.push_back(std::make_pair(iter->first, iter->second)); } - auto cmp_func = [](const std::pair &a, const std::pair &b) { - if (a.second != b.second) { - return a.second > b.second; - } else if (a.first == kOpFormat_DEFAULT) { - return a.second + 1 > b.second; - } else if (b.first == kOpFormat_DEFAULT) { - return a.second > b.second + 1; - } - return a.second > b.second; - }; - std::sort(pairs.begin(), pairs.end(), cmp_func); + + std::sort(pairs.begin(), pairs.end(), cmp_format_num); *default_format = pairs.begin()->first; } @@ -237,10 +238,9 @@ void GetDefaultFormat(const CNodePtr &kernel_node, std::string *default_format, } } -void UpdateGraphKernelInputsKernelInfo(const CNodePtr &kernel_node, const std::vector &input_list, - const std::string &default_format, bool use_same_format, - std::vector *graph_input_format, - std::vector *graph_input_type) { +void UpdateInputsKernelInfo(const CNodePtr &kernel_node, const std::vector &input_list, + const std::string &default_format, bool use_same_format, + std::vector *graph_input_format, std::vector *graph_input_type) { MS_EXCEPTION_IF_NULL(graph_input_format); MS_EXCEPTION_IF_NULL(graph_input_type); // We set same format to all inputs of graph kernel subgraph, and process this latter. @@ -338,21 +338,22 @@ void UpdateEquivFormat(const std::vector> &output_ } } -void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector &node_list, - const std::vector &input_list, const FuncGraphManagerPtr &mng, - const std::string &default_format, std::vector *graph_input_format, - std::vector *graph_input_type) { +void CheckFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector &input_list, + const FuncGraphManagerPtr &mng, const std::string &default_format, + std::vector *graph_input_format, std::vector *graph_input_type, + std::vector *need_update) { MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(mng); MS_EXCEPTION_IF_NULL(graph_input_format); MS_EXCEPTION_IF_NULL(graph_input_type); - // update graph input format and dtype use inner ops. + MS_EXCEPTION_IF_NULL(need_update); + // check graph input format and dtype use inner ops. size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (graph_input_format->size() != input_num) { + if (graph_input_format->size() != input_num || graph_input_type->size() != input_num || + need_update->size() != input_num) { MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString() - << "], [%" << graph_input_format->size() << "] != [%" << input_num << "]"; + << "], [" << graph_input_format->size() << "] != [" << input_num << "]"; } - std::vector need_update(input_num, false); auto &node_users = mng->node_users(); for (size_t i = 0; i < input_num; ++i) { auto &input = input_list[i]; @@ -372,36 +373,48 @@ void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vectorDebugString() << "] selected different format. we use defult: " << default_format; (*graph_input_format)[i] = default_format; - need_update[i] = true; + (*need_update)[i] = true; } - if (kernel_node->input(i + 1)->isa()) { - auto user_dtype = AnfAlgo::GetInputDeviceDataType(node_user.first, IntToSize(node_user.second - 1)); - if (user_dtype != (*graph_input_type)[i]) { - TypeId default_dtype = AnfAlgo::GetOutputInferDataType(input, 0); - MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << " of [" - << kernel_node->DebugString() - << "] selected different dtype. we use default: " << TypeIdLabel(default_dtype); - (*graph_input_type)[i] = default_dtype; - need_update[i] = true; - } + if (kernel_node->input(i + 1)->isa() || + AnfAlgo::GetInputDeviceDataType(node_user.first, IntToSize(node_user.second - 1)) == (*graph_input_type)[i]) { + continue; } + + TypeId default_dtype = AnfAlgo::GetOutputInferDataType(input, 0); + MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << " of [" + << kernel_node->DebugString() + << "] selected different dtype. we use default: " << TypeIdLabel(default_dtype); + (*graph_input_type)[i] = default_dtype; + (*need_update)[i] = true; } } +} +void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector &node_list, + const std::vector &input_list, const std::vector &need_update, + const std::vector &graph_input_format, + const std::vector &graph_input_type) { + MS_EXCEPTION_IF_NULL(kernel_node); + // update graph input format and dtype use inner ops. + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (graph_input_format.size() != input_num || graph_input_type.size() != input_num || + need_update.size() != input_num) { + MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString() + << "], [" << graph_input_format.size() << "] != [" << input_num << "]"; + } for (size_t i = 0; i < input_num; ++i) { if (!need_update[i]) { continue; } - need_update[i] = false; MS_LOG(DEBUG) << "Update input format: " << i << " of: [" << kernel_node->DebugString() - << "] to: " << (*graph_input_format)[i]; + << "] to: " << graph_input_format[i]; MS_LOG(DEBUG) << "Update input dtype: " << i << " of: [" << kernel_node->DebugString() - << "] to: " << TypeIdLabel((*graph_input_type)[i]); + << "] to: " << TypeIdLabel(graph_input_type[i]); kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - std::vector outputs_format = {(*graph_input_format)[i]}; - std::vector outputs_device_type = {(*graph_input_type)[i]}; + std::vector outputs_format = {graph_input_format[i]}; + std::vector outputs_device_type = {graph_input_type[i]}; builder.SetOutputsFormat(outputs_format); builder.SetOutputsDeviceType(outputs_device_type); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get()); @@ -487,8 +500,8 @@ void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func std::vector graph_input_format; std::vector graph_input_type; - UpdateGraphKernelInputsKernelInfo(kernel_node, input_list, default_format, use_same_format, &graph_input_format, - &graph_input_type); + UpdateInputsKernelInfo(kernel_node, input_list, default_format, use_same_format, &graph_input_format, + &graph_input_type); auto mng = func_graph->manager(); if (mng == nullptr) { @@ -502,8 +515,10 @@ void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); // update graph input format and dtype use inner ops. - UpdateFormatsAndDtypes(kernel_node, node_list, input_list, mng, default_format, &graph_input_format, - &graph_input_type); + std::vector need_update(AnfAlgo::GetInputTensorNum(kernel_node), false); + CheckFormatsAndDtypes(kernel_node, input_list, mng, default_format, &graph_input_format, &graph_input_type, + &need_update); + UpdateFormatsAndDtypes(kernel_node, node_list, input_list, need_update, graph_input_format, graph_input_type); // set fix_precision for kernel when the me prim has fix_precision attr UpdateKernelInfo(node_list);