|
|
@ -33,8 +33,10 @@ bool SigmoidCrossEntropyWithLogitsCPUKernel::Launch(const std::vector<kernel::Ad
|
|
|
|
const std::vector<kernel::AddressPtr> &outputs) {
|
|
|
|
const std::vector<kernel::AddressPtr> &outputs) {
|
|
|
|
if (dtype_ == kNumberTypeFloat16) {
|
|
|
|
if (dtype_ == kNumberTypeFloat16) {
|
|
|
|
LaunchKernel<float16>(inputs, outputs);
|
|
|
|
LaunchKernel<float16>(inputs, outputs);
|
|
|
|
} else if (dtype_ == kNumberTypeFloat32) {
|
|
|
|
} else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat64) {
|
|
|
|
LaunchKernel<float>(inputs, outputs);
|
|
|
|
LaunchKernel<float>(inputs, outputs);
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "input dtype only support float16, float32, float64";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|