diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 490e4379c8..9ac46eaa82 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -51,26 +51,55 @@ static std::shared_ptr> python_paras; void ClearPythonParasMap() { python_paras = nullptr; } namespace { const int kSummaryGetItem = 2; -bool IsUsedByRealKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) { +const size_t max_depth = 128; +bool RecursiveCheck(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, size_t *idx, bool *check_dynamic) { MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(node); - auto node_users = manager->node_users()[node]; - for (auto item : node_users) { - if (AnfAlgo::IsRealKernel(item.first)) { + if (*check_dynamic) { + if (node->isa() && AnfAlgo::IsNodeDynamicShape(node->cast())) { + return true; + } + } else if (AnfAlgo::IsRealKernel(node)) { + return true; + } + (*idx) += 1; + // max recursion depth + if (*idx <= max_depth) { + auto users = manager->node_users()[node]; + if (std::any_of(users.begin(), users.end(), [&](const std::pair &kernel) { + return RecursiveCheck(manager, kernel.first, idx, check_dynamic); + })) { return true; } } return false; } +bool IsUsedByRealKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(manager); + MS_EXCEPTION_IF_NULL(node); + auto node_users = manager->node_users()[node]; + size_t idx = 0; + bool check_dynamic = false; + if (std::any_of(node_users.begin(), node_users.end(), [&](const std::pair &kernel) { + return RecursiveCheck(manager, kernel.first, &idx, &check_dynamic); + })) { + return true; + } + + return false; +} + bool IsUsedByDynamicKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(node); auto node_users = manager->node_users()[node]; - for (auto item : node_users) { - if (item.first->isa() && AnfAlgo::IsNodeDynamicShape(item.first->cast())) { - return true; - } + size_t idx = 0; + bool check_dynamic = true; + if (std::any_of(node_users.begin(), node_users.end(), [&](const std::pair &kernel) { + return RecursiveCheck(manager, kernel.first, &idx, &check_dynamic); + })) { + return true; } return false; } diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.cc b/mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.cc index 59b82717c5..93d89dcc73 100644 --- a/mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.cc +++ b/mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.cc @@ -17,6 +17,7 @@ #include "runtime/device/ascend/executor/ai_cpu_dynamic_kernel.h" #include #include +#include #include #include "runtime/mem.h" #include "runtime/kernel.h" @@ -27,6 +28,7 @@ namespace mindspore { namespace device { namespace ascend { +std::set kComputeDepend = {"Unique"}; AiCpuDynamicKernel::~AiCpuDynamicKernel() { // free dev ptr if (ext_info_addr_dev_ == nullptr) { @@ -67,9 +69,11 @@ void AiCpuDynamicKernel::Initialize() { output_num_ = AnfAlgo::GetOutputTensorNum(cnode_ptr_); UnknowShapeOpType shape_type = UnknowShapeOpType::DEPEND_IN_SHAPE; - if (AnfAlgo::GetCNodeName(cnode_ptr_) == "Unique") { + auto op_name = AnfAlgo::GetCNodeName(cnode_ptr_); + if (kComputeDepend.find(op_name) != kComputeDepend.end()) { shape_type = UnknowShapeOpType::DEPEND_COMPUTE; } + unknow_type_ = shape_type; // Parse aicpu ext info if (is_dynamic_shape_) { MS_EXCEPTION_IF_NULL(cnode_ptr_); @@ -141,7 +145,7 @@ bool AiCpuDynamicKernel::UpdateExtInfo() { ext_info_handler_->UpdateInputShapeAndType(i, NOT_NULL(cnode_ptr_)); } - if (unknow_type_ != DEPEND_COMPUTE) { + if (AnfAlgo::IsDynamicShape(cnode_ptr_) && unknow_type_ != DEPEND_COMPUTE) { for (size_t i = 0; i < output_num_; ++i) { ext_info_handler_->UpdateOutputShapeAndType(i, NOT_NULL(cnode_ptr_)); } @@ -198,6 +202,9 @@ bool AiCpuDynamicKernel::UpdateOutputShapeFromExtInfo() { void AiCpuDynamicKernel::PostExecute() { MS_LOG(INFO) << "Aicpu " << cnode_ptr_->fullname_with_scope() << " PostExecute"; + if (unknow_type_ != DEPEND_COMPUTE) { + return; + } if (RT_ERROR_NONE != rtStreamSynchronize(stream_)) { MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error."; return; diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.h b/mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.h index 5eedf097b8..58f1659c21 100644 --- a/mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.h +++ b/mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.h @@ -40,7 +40,7 @@ class AiCpuDynamicKernel : public DynamicKernel { ext_info_size_(0), input_num_(0), output_num_(0), - unknow_type_(DEPEND_COMPUTE) {} + unknow_type_(DEPEND_IN_SHAPE) {} ~AiCpuDynamicKernel() override;