|
|
@ -342,7 +342,7 @@ void AddNodeAndKernelDataType(const CNodePtr &kernel_node, const kernel::KernelB
|
|
|
|
std::vector<int> *node_mix_precision_datatype_index) {
|
|
|
|
std::vector<int> *node_mix_precision_datatype_index) {
|
|
|
|
MS_EXCEPTION_IF_NULL(node_mix_precision_datatype);
|
|
|
|
MS_EXCEPTION_IF_NULL(node_mix_precision_datatype);
|
|
|
|
bool add_node_datatype_flag = false;
|
|
|
|
bool add_node_datatype_flag = false;
|
|
|
|
if (node_mix_precision_datatype->size() == 0) {
|
|
|
|
if (node_mix_precision_datatype->empty()) {
|
|
|
|
add_node_datatype_flag = true;
|
|
|
|
add_node_datatype_flag = true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
|
|
|
|
for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
|
|
|
@ -464,8 +464,9 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecis
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} // namespace
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
|
|
void SelectKernelInfo(const CNodePtr &kernel_node) {
|
|
|
|
int SelectKernelInfo(const CNodePtr &kernel_node) {
|
|
|
|
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
|
|
|
|
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
|
|
|
|
|
|
|
|
int status = kStatusAllMatched;
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
|
|
bool precision_reduce = false;
|
|
|
|
bool precision_reduce = false;
|
|
|
|
std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr;
|
|
|
|
std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr;
|
|
|
@ -486,11 +487,13 @@ void SelectKernelInfo(const CNodePtr &kernel_node) {
|
|
|
|
<< "] cannot find valid kernel info, not supported the type" << buffer.str();
|
|
|
|
<< "] cannot find valid kernel info, not supported the type" << buffer.str();
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce);
|
|
|
|
PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce);
|
|
|
|
|
|
|
|
status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
|
|
|
|
// Set format and data type for input tensor.
|
|
|
|
// Set format and data type for input tensor.
|
|
|
|
SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
|
|
|
|
SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
|
|
|
|
|
|
|
|
return status;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bool CheckKernelAccuracySupported(const CNodePtr &kernel_node,
|
|
|
|
bool CheckKernelAccuracySupported(const CNodePtr &kernel_node,
|
|
|
|