From 0363a7401a3a9c5df11c4471cef6d7d263db72d0 Mon Sep 17 00:00:00 2001 From: wilfChen Date: Wed, 2 Dec 2020 18:44:29 +0800 Subject: [PATCH] infer shape for noop --- .../device/ascend/ascend_kernel_runtime.cc | 38 ------------------ .../runtime/device/executor/dynamic_kernel.cc | 40 +++++++++++++++++++ .../runtime/device/executor/dynamic_kernel.h | 2 + 3 files changed, 42 insertions(+), 38 deletions(-) diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc index e8bcead66c..46d15c5d8d 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc @@ -41,7 +41,6 @@ #include "backend/session/anf_runtime_algorithm.h" #include "runtime/device/ascend/profiling/profiling_utils.h" #include "backend/kernel_compiler/tbe/tbe_utils.h" -#include "backend/optimizer/common/helper.h" #include "runtime/device/ascend/ascend_memory_manager.h" #include "debug/tensor_load.h" #include "debug/data_dump/dump_json_parser.h" @@ -114,34 +113,6 @@ std::string GetRankId() { } return rank_id_str; } - -void InferShapeForNopNode(AnfNodePtr *input_node) { - MS_EXCEPTION_IF_NULL(*input_node); - if (!opt::IsNopNode(*input_node)) { - MS_LOG(INFO) << "Input node is not a nop node, no need infer."; - return; - } - MS_LOG(INFO) << "Infer shape for nop node."; - std::stack nop_road; - nop_road.push(*input_node); - - while (true) { - auto input_node_with_idx = AnfAlgo::GetPrevNodeOutput(*input_node, 0); - auto in_node = input_node_with_idx.first; - MS_EXCEPTION_IF_NULL(in_node); - if (opt::IsNopNode(in_node)) { - nop_road.push(in_node); - *input_node = in_node; - } else { - break; - } - } - while (!nop_road.empty()) { - auto nop_node = nop_road.top(); - AnfAlgo::InferShape(nop_node->cast()); - nop_road.pop(); - } -} } // namespace std::vector AscendKernelRuntime::exception_infoes_; @@ -665,15 +636,6 @@ bool AscendKernelRuntime::RunDynamicKernelAsync(const session::KernelGraph *grap } if (dynamic_kernel->is_dynamic_shape()) { - auto kernel_node = dynamic_kernel->kernel_node(); - MS_EXCEPTION_IF_NULL(kernel_node); - auto input_size = AnfAlgo::GetInputTensorNum(kernel_node); - for (size_t i = 0; i < input_size; i++) { - auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(kernel_node, i); - auto input_node = input_node_with_index.first; - MS_EXCEPTION_IF_NULL(input_node); - InferShapeForNopNode(&input_node); - } dynamic_kernel->InferShape(); dynamic_kernel->UpdateArgs(); } diff --git a/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc b/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc index ec009b5604..218bc995bd 100644 --- a/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc +++ b/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc @@ -18,6 +18,7 @@ #include #include #include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/helper.h" #include "common/trans.h" #include "pipeline/jit/static_analysis/static_analysis.h" #include "abstract/dshape.h" @@ -73,6 +74,7 @@ void DynamicKernel::InferShape() { } MS_EXCEPTION_IF_NULL(cnode_ptr_); MS_LOG(INFO) << "InferShape start, node:" << cnode_ptr_->fullname_with_scope(); + InferShapeRecursive(); auto inputs = cnode_ptr_->inputs(); if (inputs.empty()) { @@ -124,5 +126,43 @@ void DynamicKernel::InferShape() { auto eval_result = abstract::CppInferShape(primitive, args_spec_list); cnode_ptr_->set_abstract(eval_result); } + +void DynamicKernel::InferShapeRecursive() { + auto input_size = AnfAlgo::GetInputTensorNum(cnode_ptr_); + for (size_t i = 0; i < input_size; i++) { + auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode_ptr_, i); + auto input_node = input_node_with_index.first; + MS_EXCEPTION_IF_NULL(input_node); + InferShapeForNopNode(&input_node); + } +} + +void DynamicKernel::InferShapeForNopNode(AnfNodePtr *input_node) { + MS_EXCEPTION_IF_NULL(*input_node); + if (!opt::IsNopNode(*input_node) || !AnfAlgo::IsDynamicShape(*input_node)) { + MS_LOG(INFO) << "Input node is not a nop node, no need infer."; + return; + } + MS_LOG(INFO) << "Infer shape for nop node."; + std::stack nop_road; + nop_road.push(*input_node); + + while (true) { + auto input_node_with_idx = AnfAlgo::GetPrevNodeOutput(*input_node, 0); + auto in_node = input_node_with_idx.first; + MS_EXCEPTION_IF_NULL(in_node); + if (opt::IsNopNode(in_node)) { + nop_road.push(in_node); + *input_node = in_node; + } else { + break; + } + } + while (!nop_road.empty()) { + auto nop_node = nop_road.top(); + AnfAlgo::InferShape(nop_node->cast()); + nop_road.pop(); + } +} } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.h b/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.h index 43a438b9fe..c70889c6b0 100644 --- a/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.h +++ b/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.h @@ -52,6 +52,8 @@ class DynamicKernel { protected: void RebuildDependTensor(); + void InferShapeRecursive(); + void InferShapeForNopNode(AnfNodePtr *input_node); void *stream_; const CNodePtr cnode_ptr_;