|
|
|
@ -224,10 +224,10 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilteredKernelInfoByDtype(
|
|
|
|
|
|
|
|
|
|
bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info, const CNodePtr &cnode,
|
|
|
|
|
const std::map<TypeId, TypeId> &type_map) {
|
|
|
|
|
// filte kernel info that unsupported raise or reduce datatype
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_build_info);
|
|
|
|
|
size_t flag_in = 0;
|
|
|
|
|
size_t flag_out = 0;
|
|
|
|
|
bool flag = false;
|
|
|
|
|
for (size_t input_index = 0; input_index < kernel_build_info->GetInputNum(); ++input_index) {
|
|
|
|
|
auto in_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
|
|
|
|
|
auto device_dtype = kernel_build_info->GetInputDeviceType(input_index);
|
|
|
|
@ -235,11 +235,17 @@ bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build
|
|
|
|
|
device_dtype = kNumberTypeFloat32;
|
|
|
|
|
}
|
|
|
|
|
auto iter = type_map.find(in_dtype);
|
|
|
|
|
// if infer dtype node in type_map and the infer dtype not equal kernel info dtype, return false
|
|
|
|
|
if (iter == type_map.end() && in_dtype != device_dtype) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
// infer dtype in type_map, but can not find dst dtype that supported raise or reduce,
|
|
|
|
|
// or infer dtype not equal kernel info dtype, return false
|
|
|
|
|
if (iter != type_map.end() && iter->second != device_dtype && in_dtype != device_dtype) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (iter == type_map.end() && in_dtype != device_dtype) {
|
|
|
|
|
flag_in += 1;
|
|
|
|
|
if (in_dtype == kNumberTypeInt64 && device_dtype == kNumberTypeInt32) {
|
|
|
|
|
flag = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -250,15 +256,22 @@ bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build
|
|
|
|
|
device_dtype = kNumberTypeFloat32;
|
|
|
|
|
}
|
|
|
|
|
auto iter = type_map.find(in_dtype);
|
|
|
|
|
// if infer dtype node in type_map and the infer dtype not equal kernel info dtype, return false
|
|
|
|
|
if (iter == type_map.end() && in_dtype != device_dtype) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
// infer dtype in type_map, but can not find dst dtype that supported raise or reduce,
|
|
|
|
|
// or infer dtype not equal kernel info dtype, return false
|
|
|
|
|
if (iter != type_map.end() && iter->second != device_dtype && in_dtype != device_dtype) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (iter == type_map.end() && in_dtype != device_dtype) {
|
|
|
|
|
flag_out += 1;
|
|
|
|
|
if (in_dtype == kNumberTypeInt64 && device_dtype == kNumberTypeInt32) {
|
|
|
|
|
flag = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (flag_in == kernel_build_info->GetInputNum() || flag_out == kernel_build_info->GetOutputNum()) {
|
|
|
|
|
return false;
|
|
|
|
|
if (flag) {
|
|
|
|
|
auto node_name = AnfAlgo::GetCNodeName(cnode);
|
|
|
|
|
MS_LOG(WARNING) << "node:[" << node_name << "]reduce precision from int64 to int32";
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|