You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
152 lines
6.4 KiB
152 lines
6.4 KiB
/**
|
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "kernel/kernel_query.h"
|
|
#include <memory>
|
|
#include <algorithm>
|
|
#include "kernel/aicpu/aicpu_kernel_metadata.h"
|
|
#include "kernel/rts/rt_kernel_info.h"
|
|
#include "kernel/hccl/hccl_kernel_metadata.h"
|
|
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h"
|
|
#include "kernel/akg/akg_kernel_metadata.h"
|
|
#include "session/anf_runtime_algorithm.h"
|
|
|
|
namespace mindspore {
|
|
namespace kernel {
|
|
namespace {
|
|
void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
|
|
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
|
|
MS_EXCEPTION_IF_NULL(kernel_info_list);
|
|
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_list;
|
|
(void)std::copy_if(kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list),
|
|
[&kernel_node](const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info) {
|
|
return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() &&
|
|
AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum();
|
|
});
|
|
if (!filtered_list.empty()) {
|
|
kernel_info_list->clear();
|
|
(void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list));
|
|
} else {
|
|
MS_LOG(INFO) << "All kernel Info list does not match any kernel info ";
|
|
for (size_t index = 0; index < kernel_info_list->size(); ++index) {
|
|
std::ostringstream buffer;
|
|
auto kernel_info = kernel_info_list->at(index);
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
if (AnfAlgo::GetOutputTensorNum(kernel_node) != kernel_info->GetOutputNum()) {
|
|
buffer << "Kernel node's output size [" << AnfAlgo::GetOutputTensorNum(kernel_node) << "]"
|
|
<< " cannot match the kernel's output size [" << kernel_info->GetOutputNum() << "]";
|
|
} else {
|
|
buffer << "Kernel node's output size [" << AnfAlgo::GetInputTensorNum(kernel_node) << "]"
|
|
<< " cannot match the kernel's output size [" << kernel_info->GetInputNum() << "]";
|
|
}
|
|
MS_LOG(INFO) << "kernel [ " << index << " ] :" << kernel_info->ToString() << buffer.str();
|
|
}
|
|
kernel_info_list->clear();
|
|
MS_LOG(INFO) << "node" << kernel_node->DebugString() << "'s output size : ["
|
|
<< AnfAlgo::GetOutputTensorNum(kernel_node) << "]"
|
|
<< "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node) << "] cannot match any kernelInfo !";
|
|
}
|
|
}
|
|
} // namespace
|
|
|
|
void KernelQueryAll(const CNodePtr &kernel_node,
|
|
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
MS_EXCEPTION_IF_NULL(kernel_info_list);
|
|
|
|
TbeMetadataInfo(kernel_node, kernel_info_list);
|
|
|
|
if (kernel_info_list->empty()) {
|
|
AicpuMetadataInfo(kernel_node, kernel_info_list);
|
|
if (!kernel_info_list->empty()) {
|
|
MS_LOG(INFO) << "The node [" << kernel_node->DebugString()
|
|
<< "] cannot find valid TBE kernel info, try to get aicpu kernel info";
|
|
AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node);
|
|
}
|
|
}
|
|
|
|
if (kernel_info_list->empty()) {
|
|
GetRtKelInfo(kernel_node, kernel_info_list);
|
|
}
|
|
|
|
if (kernel_info_list->empty()) {
|
|
HcclMetadataInfo(kernel_node, kernel_info_list);
|
|
}
|
|
if (kernel_info_list->empty()) {
|
|
MS_LOG(EXCEPTION) << "Op " << kernel_node->DebugString() << "kernel query fail!";
|
|
}
|
|
}
|
|
|
|
void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list,
|
|
KernelType kernel_type) {
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
MS_EXCEPTION_IF_NULL(kernel_info_list);
|
|
|
|
std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
|
|
|
|
switch (kernel_type) {
|
|
case KernelType::AKG_KERNEL:
|
|
AkgMetadataInfo(kernel_node, kernel_info_list);
|
|
break;
|
|
default:
|
|
KernelQueryAll(kernel_node, kernel_info_list);
|
|
break;
|
|
}
|
|
|
|
if (kernel_info_list->empty()) {
|
|
MS_EXCEPTION(NotExistsError) << "Op[" << kernel_node->DebugString() << "] kernel query fail!";
|
|
}
|
|
// check output
|
|
FilterInvalidKernelInfo(kernel_node, kernel_info_list);
|
|
}
|
|
|
|
// Replaces the contents of *kernel_info_list with the AICPU kernel build infos available for
// kernel_node, keeping only those that pass FilterInvalidKernelInfo.
//
// kernel_node:      node to query AICPU kernel infos for; must not be null.
// kernel_info_list: output list; any previous content is discarded. Must not be null.
void AICPUQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(kernel_info_list);

  auto &candidates = *kernel_info_list;
  // Start from an empty list so stale entries from the caller cannot leak into the result.
  candidates.clear();
  AicpuMetadataInfo(kernel_node, &candidates);
  FilterInvalidKernelInfo(kernel_node, &candidates);
}
|
|
bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
MS_EXCEPTION_IF_NULL(select_kernel_build_info);
|
|
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
|
|
auto cnode = kernel_node->cast<CNodePtr>();
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
AICPUQuery(cnode, &kernel_info_list);
|
|
return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
|
|
[&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
|
|
MS_EXCEPTION_IF_NULL(item);
|
|
return *item == *select_kernel_build_info;
|
|
});
|
|
}
|
|
|
|
bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
MS_EXCEPTION_IF_NULL(select_kernel_build_info);
|
|
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
|
|
auto cnode = kernel_node->cast<CNodePtr>();
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
TbeMetadataInfo(cnode, &kernel_info_list);
|
|
return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
|
|
[&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
|
|
MS_EXCEPTION_IF_NULL(item);
|
|
return *item == *select_kernel_build_info;
|
|
});
|
|
}
|
|
} // namespace kernel
|
|
} // namespace mindspore
|