support fp64 in huber_loss cuda kernel (#26583)

test_feature_precision_test_c
Guanghua Yu 5 years ago committed by GitHub
parent 90e6819cf2
commit 8645591d66
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -16,7 +16,9 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
huber_loss,
ops::HuberLossKernel<paddle::platform::CUDADeviceContext, float>);
ops::HuberLossKernel<paddle::platform::CUDADeviceContext, float>,
ops::HuberLossKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
huber_loss_grad,
ops::HuberLossGradKernel<paddle::platform::CUDADeviceContext, float>);
ops::HuberLossGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::HuberLossGradKernel<paddle::platform::CUDADeviceContext, double>);

Loading…
Cancel
Save