|
|
|
@ -31,12 +31,16 @@ namespace {
|
|
|
|
|
void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
|
|
|
|
|
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info_list);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
|
|
|
size_t output_tensor_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
|
|
|
|
size_t input_tensor_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
|
|
|
|
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_list;
|
|
|
|
|
(void)std::copy_if(kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list),
|
|
|
|
|
[&kernel_node](const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info) {
|
|
|
|
|
return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() &&
|
|
|
|
|
AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum();
|
|
|
|
|
});
|
|
|
|
|
(void)std::copy_if(
|
|
|
|
|
kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list),
|
|
|
|
|
[output_tensor_num, input_tensor_num](const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info) {
|
|
|
|
|
return kernel_build_info->GetOutputNum() == output_tensor_num &&
|
|
|
|
|
kernel_build_info->GetInputNum() == input_tensor_num;
|
|
|
|
|
});
|
|
|
|
|
if (!filtered_list.empty()) {
|
|
|
|
|
kernel_info_list->clear();
|
|
|
|
|
(void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list));
|
|
|
|
@ -44,21 +48,20 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
|
|
|
|
|
MS_LOG(INFO) << "All kernel Info list does not match any kernel info ";
|
|
|
|
|
for (size_t index = 0; index < kernel_info_list->size(); ++index) {
|
|
|
|
|
std::ostringstream buffer;
|
|
|
|
|
auto kernel_info = kernel_info_list->at(index);
|
|
|
|
|
auto &kernel_info = kernel_info_list->at(index);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
if (AnfAlgo::GetOutputTensorNum(kernel_node) != kernel_info->GetOutputNum()) {
|
|
|
|
|
buffer << "Kernel node's output size [" << AnfAlgo::GetOutputTensorNum(kernel_node) << "]"
|
|
|
|
|
if (kernel_info->GetOutputNum() != output_tensor_num) {
|
|
|
|
|
buffer << "Kernel node's output size [" << output_tensor_num << "]"
|
|
|
|
|
<< " cannot match the kernel's output size [" << kernel_info->GetOutputNum() << "]";
|
|
|
|
|
} else {
|
|
|
|
|
buffer << "Kernel node's output size [" << AnfAlgo::GetInputTensorNum(kernel_node) << "]"
|
|
|
|
|
buffer << "Kernel node's output size [" << input_tensor_num << "]"
|
|
|
|
|
<< " cannot match the kernel's output size [" << kernel_info->GetInputNum() << "]";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "kernel [ " << index << " ] :" << kernel_info->ToString() << buffer.str();
|
|
|
|
|
}
|
|
|
|
|
kernel_info_list->clear();
|
|
|
|
|
MS_LOG(INFO) << "node" << kernel_node->DebugString() << "'s output size : ["
|
|
|
|
|
<< AnfAlgo::GetOutputTensorNum(kernel_node) << "]"
|
|
|
|
|
<< "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node) << "] cannot match any kernelInfo !";
|
|
|
|
|
MS_LOG(INFO) << "node" << kernel_node->DebugString() << "'s output size : [" << output_tensor_num << "]"
|
|
|
|
|
<< "input size : [" << input_tensor_num << "] cannot match any kernelInfo !";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|