From 341200ab97fc4c5240d3ac70c4c1e4043b3d3be3 Mon Sep 17 00:00:00 2001 From: limingqi107 Date: Sat, 29 Aug 2020 16:44:51 +0800 Subject: [PATCH] gpu kernel_info_setter code review --- .../ccsrc/backend/session/gpu_session.cc | 14 +++++++------- mindspore/ccsrc/backend/session/gpu_session.h | 2 +- .../runtime/device/gpu/gpu_kernel_runtime.cc | 4 +++- .../runtime/device/gpu/kernel_info_setter.cc | 19 ++++++++++++++----- .../runtime/device/gpu/kernel_info_setter.h | 2 +- 5 files changed, 26 insertions(+), 15 deletions(-) diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc index 85d042746d..41493f2f64 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -49,10 +49,10 @@ using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; void GPUSession::SelectKernel(const std::shared_ptr &kernel_graph) const { MS_EXCEPTION_IF_NULL(kernel_graph); - bool in_black_list = CheckInModeBlackList(kernel_graph); + bool graph_format_transform = IsSupportFormatTransform(kernel_graph); for (const auto &kernel_node : kernel_graph->execution_order()) { MS_EXCEPTION_IF_NULL(kernel_node); - device::gpu::SetKernelInfo(kernel_node, in_black_list); + device::gpu::SetKernelInfo(kernel_node, graph_format_transform); } } @@ -76,7 +76,7 @@ void GPUSession::Optimize(const std::shared_ptr &kernel_graph) { pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); - if (!CheckInModeBlackList(kernel_graph) && context_ptr->get_param(MS_CTX_EXECUTION_MODE) != kPynativeMode) { + if (IsSupportFormatTransform(kernel_graph) && context_ptr->get_param(MS_CTX_EXECUTION_MODE) != kPynativeMode) { pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); @@ -193,14 +193,14 @@ void GPUSession::Execute(const std::shared_ptr &kernel_graph) const } } -bool GPUSession::CheckInModeBlackList(const std::shared_ptr &kernel_graph) const { +bool GPUSession::IsSupportFormatTransform(const std::shared_ptr &kernel_graph) const { auto kernels = kernel_graph->execution_order(); size_t conv_cnt = 0; size_t bn_cnt = 0; for (const auto &kernel : kernels) { auto kernel_name = AnfAlgo::GetCNodeName(kernel); if (kernel_name == prim::kPrimLayerNorm->name()) { - return true; + return false; } if (kernel_name == prim::kPrimConv2D->name()) { conv_cnt++; @@ -210,9 +210,9 @@ bool GPUSession::CheckInModeBlackList(const std::shared_ptr &kernel } } if (conv_cnt == kConv2dCount && bn_cnt == kFusedBatchNormCount) { - return true; + return false; } - return false; + return true; } GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { diff --git a/mindspore/ccsrc/backend/session/gpu_session.h b/mindspore/ccsrc/backend/session/gpu_session.h index bd28e97b08..0a43472f4f 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.h +++ b/mindspore/ccsrc/backend/session/gpu_session.h @@ -67,7 +67,7 @@ class GPUSession : public SessionBasic { void Execute(const std::shared_ptr &kernel_graph) const; - bool CheckInModeBlackList(const std::shared_ptr &kernel_graph) const; + bool IsSupportFormatTransform(const std::shared_ptr &kernel_graph) const; #ifdef ENABLE_DEBUGGER void Dump(const std::shared_ptr &kernel_graph) const; diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc index 4fae7d9058..cc6fa74b9f 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc @@ -404,7 +404,9 @@ void GPUKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::v // Release the kernel resource. for (const auto &kernel : execution_order) { auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); + if (kernel_mod == nullptr) { + continue; + } kernel_mod->ReleaseResource(); } } diff --git a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc index 1b16fafa03..ec7f4e4906 100644 --- a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc +++ b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc @@ -176,9 +176,18 @@ bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vectorsecond.first; + // If input position is empty, then insert all the input positions, because the input numbers of this op are variable. + if (inputs_format_position.size() == 0) { + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); input_index++) { + inputs_format_position.push_back(input_index); + } + } + for (const auto &input_format_position : inputs_format_position) { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, input_format_position); + if (input_shape.size() != 4) { + return false; + } } return true; } @@ -223,7 +232,7 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector inputs_format; std::vector inputs_type; for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { @@ -237,7 +246,7 @@ void SetKernelInfo(const CNodePtr &kernel_node, bool in_black_list) { outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); } std::string origin_data_format = kOpFormat_DEFAULT; - if (!in_black_list && IsNeedProcessFormatInfo(kernel_node, inputs_type)) { + if (graph_format_transform && IsNeedProcessFormatInfo(kernel_node, inputs_type)) { UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format); } std::shared_ptr builder = diff --git a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h index 0f64527b4d..48f03ce5c4 100644 --- a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h +++ b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h @@ -53,7 +53,7 @@ static std::map, std::vector> {prim::kPrimAddN->name(), {{}, {0}}}, }; -void SetKernelInfo(const CNodePtr &kernel_node, bool in_black_list = false); +void SetKernelInfo(const CNodePtr &kernel_node, bool graph_format_transform = false); class KernelAttr { public: