fix unnecessary sync

pull/8849/head
baihuawei 4 years ago
parent 7125f1f8c8
commit e6fb4b9f69

@@ -33,8 +33,10 @@ bool SigmoidCrossEntropyWithLogitsCPUKernel::Launch(const std::vector<kernel::Ad
                                                     const std::vector<kernel::AddressPtr> &outputs) {
   if (dtype_ == kNumberTypeFloat16) {
     LaunchKernel<float16>(inputs, outputs);
-  } else if (dtype_ == kNumberTypeFloat32) {
+  } else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat64) {
     LaunchKernel<float>(inputs, outputs);
+  } else {
+    MS_LOG(EXCEPTION) << "input dtype only support float16, float32, float64";
   }
   return true;
 }
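Note on the kernel change above (and the identical grad-kernel change below): float64 inputs are now routed through the existing float path, and any other dtype fails fast instead of silently returning. A minimal standalone sketch of the same dispatch pattern; the identifiers here are illustrative stand-ins, not the actual MindSpore kernel API.

// Sketch only: dtype-based dispatch with an explicit failure branch.
// DType, LaunchKernelImpl, and Launch are hypothetical names for this example.
#include <stdexcept>

enum class DType { kFloat16, kFloat32, kFloat64, kInt32 };

template <typename T>
void LaunchKernelImpl() { /* typed element-wise compute path */ }

bool Launch(DType dtype) {
  if (dtype == DType::kFloat16) {
    LaunchKernelImpl<unsigned short>();  // stand-in for a real float16 type
  } else if (dtype == DType::kFloat32 || dtype == DType::kFloat64) {
    LaunchKernelImpl<float>();           // float64 reuses the float instantiation, mirroring the commit
  } else {
    throw std::invalid_argument("input dtype only supports float16, float32, float64");
  }
  return true;
}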

@@ -33,8 +33,10 @@ bool SigmoidCrossEntropyWithLogitsGradCPUKernel::Launch(const std::vector<kernel
                                                         const std::vector<kernel::AddressPtr> &outputs) {
   if (dtype_ == kNumberTypeFloat16) {
     LaunchKernel<float16>(inputs, outputs);
-  } else if (dtype_ == kNumberTypeFloat32) {
+  } else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat64) {
     LaunchKernel<float>(inputs, outputs);
+  } else {
+    MS_LOG(EXCEPTION) << "input dtype only support float16, float32, float64";
   }
   return true;
 }

@@ -242,7 +242,8 @@ void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &ker
     auto tensor_address = tensor->device_address();
     MS_EXCEPTION_IF_NULL(address);
     MS_EXCEPTION_IF_NULL(tensor);
-    if (tensor_address != nullptr && tensor_address != address) {
+    if (tensor_address != nullptr && tensor_address != address &&
+        std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address)->DeviceType() != DeviceAddressType::kCPU) {
       tensor->data_sync(false);
     }
     if (GetTypeByte(TypeIdToType(tensor->data_type())) == GetTypeByte(TypeIdToType(address->type_id_))) {
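The runtime hunk above is the "unnecessary sync" the commit title refers to: tensor->data_sync(false) is now skipped when the tensor's cached device address is already a CPU address, since that data is already host-resident. A minimal sketch of the intent; the class and member names here are hypothetical, not the actual MindSpore runtime types.

// Sketch only: skip the host sync for tensors whose cached address is already on CPU.
#include <memory>

enum class AddrType { kCPU, kGPU, kAscend };

struct DeviceAddress {
  AddrType device_type;
};

struct Tensor {
  std::shared_ptr<DeviceAddress> cached_address;
  void DataSync() { /* copy device memory back into the host buffer */ }
};

void BindInputAddress(Tensor *tensor, const std::shared_ptr<DeviceAddress> &kernel_address) {
  auto cached = tensor->cached_address;
  // Only sync when the tensor is bound to a different address that lives off-CPU;
  // a CPU-resident tensor is already readable on host, so the sync would be redundant.
  if (cached != nullptr && cached != kernel_address && cached->device_type != AddrType::kCPU) {
    tensor->DataSync();
  }
}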

@@ -234,7 +234,7 @@ void KernelNotSupportException(const AnfNodePtr &kernel_node, const std::vector<
     operator_info << ") ";
   }
   operator_info << "is not support.";
-  MS_LOG(EXCEPTION) << operator_info.str();
+  MS_EXCEPTION(TypeError) << operator_info.str();
 }
 }  // namespace
 bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr,
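The last hunk only changes how the "kernel not supported" error is raised: a typed TypeError instead of a generically logged exception, so callers can tell an unsupported input type apart from other runtime failures. A rough sketch of that distinction; the exception class and function below are placeholders, not MindSpore's MS_EXCEPTION machinery.

// Sketch only: raise an unsupported-kernel error as a dedicated exception type
// so it can be caught separately from generic runtime errors.
#include <sstream>
#include <stdexcept>
#include <string>

class TypeError : public std::runtime_error {
 public:
  explicit TypeError(const std::string &msg) : std::runtime_error(msg) {}
};

void KernelNotSupported(const std::string &op_name) {
  std::ostringstream operator_info;
  operator_info << "Operation " << op_name << " is not supported.";
  throw TypeError(operator_info.str());  // typed, rather than a generic exception
}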
