diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h index 66e3f2ad33..ee0d837cee 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h @@ -37,11 +37,11 @@ class SupportedChecker { public: SupportedChecker() = default; virtual ~SupportedChecker() = default; - virtual bool CheckAiCoreSupported(const AnfNodePtr &anf_node, + virtual bool CheckAICoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) { return kernel::IsSupportedByAICore(anf_node, select_kernel_build_info); } - virtual bool CheckAiCpuSupported(const AnfNodePtr &anf_node, + virtual bool CheckAICPUSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) { return kernel::IsSupportedByAICPU(anf_node, select_kernel_build_info); } diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc index 5b5bf7e4fc..cfa4e42342 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc @@ -38,9 +38,9 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph return nullptr; } auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node); - if (supported_checker_->CheckAiCoreSupported(node, kernel_builder_info)) { - return node; - } else if (supported_checker_->CheckAiCpuSupported(node, kernel_builder_info)) { + if (supported_checker_->CheckAICoreSupported(node, kernel_builder_info)) { + return nullptr; + } else if (supported_checker_->CheckAICPUSupported(node, kernel_builder_info)) { auto builder = std::make_shared(kernel_builder_info); builder->SetKernelType(AICPU_KERNEL); AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); @@ -49,7 +49,7 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node [" << node->DebugString() << "]"; } - return node; + return nullptr; } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc index 9abef8fa70..95bcb9f210 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc @@ -148,7 +148,7 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod auto indices_const = CreateValueNode(new_cnode); new_cnode->add_input(indices_const); MS_EXCEPTION_IF_NULL(supported_checker_); - if (!supported_checker_->CheckAiCoreSupported(new_cnode, CreateKernelBuildInfo())) { + if (!supported_checker_->CheckAICoreSupported(new_cnode, CreateKernelBuildInfo())) { MS_LOG(INFO) << "split topk failed, check to aicpu."; return nullptr; } diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc index 1651718703..e45fc2637f 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc @@ -53,7 +53,7 @@ const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_grap new_transdata_builder->SetProcessor(transdata_kernel_build_info->processor()); auto new_fusion_transdata = std::make_shared(kTransDataOpName); - if (supported_checker_->CheckAiCoreSupported(transdata_cnode, new_transdata_builder->Build())) { + if (supported_checker_->CheckAICoreSupported(transdata_cnode, new_transdata_builder->Build())) { std::vector inputs = {NewValueNode(new_fusion_transdata), utils::cast((*equiv)[input_varptr_])}; auto new_node = func_graph->NewCNode(inputs); diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc index 6cc68457e5..09ea32becb 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc @@ -976,5 +976,21 @@ bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) { } MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString(); } + +bool AnfRuntimeAlgorithm::IsScalarInput(const CNodePtr &cnode, size_t index) { + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index); + if (shape.empty()) { + return true; + } + return shape.size() == kShape1dDims && shape[0] == 1; +} + +bool AnfRuntimeAlgorithm::IsScalarOutput(const CNodePtr &cnode, size_t index) { + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index); + if (shape.empty()) { + return true; + } + return shape.size() == kShape1dDims && shape[0] == 1; +} } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.h b/mindspore/ccsrc/session/anf_runtime_algorithm.h index 10ae5282e0..bab867a3ef 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.h @@ -185,6 +185,8 @@ class AnfRuntimeAlgorithm { static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); static std::vector GetCallNodeKernelGraph(const CNodePtr &call_node); static bool IsSwitchCall(const CNodePtr &call_node); + static bool IsScalarInput(const CNodePtr &cnode, size_t index); + static bool IsScalarOutput(const CNodePtr &cnode, size_t index); }; } // namespace session using AnfAlgo = session::AnfRuntimeAlgorithm; diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index ff2ba05c84..b2771f4b9b 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -207,7 +207,9 @@ constexpr auto kValueTargetOther = "target_other"; // some size const size_t kShape4dDims = 4; +const size_t kShape2dDims = 2; const size_t kShape5dDims = 5; +const size_t kShape1dDims = 1; const size_t kCubeSize = 16; const size_t kMemAlignSize = 512; const int kParameterDataTensorMask = 0; diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc index 4cee3577ed..b09268aa66 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc @@ -55,8 +55,7 @@ class MockSupportedChecker : public SupportedChecker { public: MockSupportedChecker() = default; ~MockSupportedChecker() override = default; - bool CheckAiCoreSupported(const AnfNodePtr &anf_node, - const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { + bool CheckAICoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { return true; } }; // namespace opt diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc index 8bb9de7c7d..98dc9e9efc 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc @@ -42,7 +42,7 @@ class MockSupportedChecker : public SupportedChecker { public: MockSupportedChecker() = default; ~MockSupportedChecker() override = default; - bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { + bool CheckAICoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { return true; } };