|
|
|
@ -106,7 +106,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
|
|
|
|
|
bool PriorityChooseItem(const std::vector<int> &cur_item, std::vector<int> *best_item) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(best_item);
|
|
|
|
|
if (cur_item.size() != best_item->size()) {
|
|
|
|
|
MS_LOG(ERROR) << "item size should be same!";
|
|
|
|
|
MS_LOG(ERROR) << "Item size should be same!";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
// Update the best_item by comparing the cur_item and best_item
|
|
|
|
@ -280,8 +280,12 @@ bool RaiseDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_dat
|
|
|
|
|
|
|
|
|
|
bool CanDataTypeReduce(const std::vector<int> &datatype_indexes, int check_index,
|
|
|
|
|
const std::vector<int> &node_mix_precision_datatype_index) {
|
|
|
|
|
return datatype_indexes[check_index] != kUnSupportMixedDataTypeIndex &&
|
|
|
|
|
datatype_indexes[check_index] <= node_mix_precision_datatype_index[check_index];
|
|
|
|
|
auto check_index_tmp = IntToSize(check_index);
|
|
|
|
|
if (check_index_tmp < datatype_indexes.size() && check_index_tmp < node_mix_precision_datatype_index.size()) {
|
|
|
|
|
return datatype_indexes[check_index] != kUnSupportMixedDataTypeIndex &&
|
|
|
|
|
datatype_indexes[check_index] <= node_mix_precision_datatype_index[check_index];
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(EXCEPTION) << "Check index " << check_index << "is outof range";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool RaiseOrReduceDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_datatype_index,
|
|
|
|
@ -300,10 +304,10 @@ bool RaiseOrReduceDataTypePrecisionSelect(const std::vector<int> &node_mix_preci
|
|
|
|
|
if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) {
|
|
|
|
|
auto find_iter = kernel_support_datatypes.find(iter->first);
|
|
|
|
|
if (find_iter == kernel_support_datatypes.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "kernel datatype index:%lu can not be found " << iter->first;
|
|
|
|
|
MS_LOG(EXCEPTION) << "Kernel datatype index:%lu can not be found " << iter->first;
|
|
|
|
|
}
|
|
|
|
|
if (i >= find_iter->second.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "node index " << i << " >= kernel datatype size " << find_iter->second.size();
|
|
|
|
|
MS_LOG(EXCEPTION) << "Node index " << i << " >= kernel datatype size " << find_iter->second.size();
|
|
|
|
|
}
|
|
|
|
|
if (node_mix_precision_datatype[i] != find_iter->second[i]) {
|
|
|
|
|
iter = kernel_match_datatype_idx->erase(iter);
|
|
|
|
@ -314,7 +318,7 @@ bool RaiseOrReduceDataTypePrecisionSelect(const std::vector<int> &node_mix_preci
|
|
|
|
|
}
|
|
|
|
|
auto datatype_indexes = iter->second;
|
|
|
|
|
if (i >= datatype_indexes.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "index " << i << "> kernel datatype indexes size " << datatype_indexes.size();
|
|
|
|
|
MS_LOG(EXCEPTION) << "Index " << i << "> kernel datatype indexes size " << datatype_indexes.size();
|
|
|
|
|
}
|
|
|
|
|
if (!CanDataTypeReduce(datatype_indexes, i, node_mix_precision_datatype_index)) {
|
|
|
|
|
iter = kernel_match_datatype_idx->erase(iter);
|
|
|
|
@ -384,9 +388,9 @@ void PrintRaiseOrReducePrecisionSelectedInfo(const CNodePtr &cnode,
|
|
|
|
|
std::ostringstream buffer;
|
|
|
|
|
buffer << cnode->DebugString();
|
|
|
|
|
if (precision_reduce) {
|
|
|
|
|
buffer << " reduce precision, node datatype: \n";
|
|
|
|
|
buffer << " Reduce precision, node datatype: \n";
|
|
|
|
|
} else {
|
|
|
|
|
buffer << " raise precision, node datatype: \n";
|
|
|
|
|
buffer << " Raise precision, node datatype: \n";
|
|
|
|
|
}
|
|
|
|
|
PrintInputAndOutputInferType(buffer, cnode);
|
|
|
|
|
buffer << ", select kernel:" << selected_kernel_build_info->ToString();
|
|
|
|
@ -554,12 +558,12 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kern
|
|
|
|
|
if (select_status == kNoMatched) {
|
|
|
|
|
std::ostringstream buffer;
|
|
|
|
|
PrintInputAndOutputInferType(buffer, kernel_node);
|
|
|
|
|
MS_LOG(WARNING) << ">>> candidates kernel info list:";
|
|
|
|
|
MS_LOG(WARNING) << ">>> Candidates kernel info list:";
|
|
|
|
|
for (size_t index = 0; index < kernel_info_list.size(); ++index) {
|
|
|
|
|
MS_LOG(WARNING) << "kernel [" << index << "] :" << kernel_info_list[index]->ToString();
|
|
|
|
|
MS_LOG(WARNING) << "Kernel [" << index << "] :" << kernel_info_list[index]->ToString();
|
|
|
|
|
}
|
|
|
|
|
for (size_t index = 0; index < aicpu_kernel_info_list.size(); ++index) {
|
|
|
|
|
MS_LOG(WARNING) << "kernel [" << (kernel_info_list.size() + index)
|
|
|
|
|
MS_LOG(WARNING) << "Kernel [" << (kernel_info_list.size() + index)
|
|
|
|
|
<< "] :" << aicpu_kernel_info_list[index]->ToString();
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(WARNING) << " <<<";
|
|
|
|
|