|
|
|
@ -270,9 +270,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
namespace plat = paddle::platform;
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL(
|
|
|
|
|
batch_norm,
|
|
|
|
|
ops::BatchNormKernel<paddle::platform::CUDADeviceContext, float>);
|
|
|
|
|
batch_norm, ops::BatchNormKernel<plat::CUDADeviceContext, float>,
|
|
|
|
|
ops::BatchNormKernel<plat::CUDADeviceContext, plat::float16>);
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL(
|
|
|
|
|
batch_norm_grad,
|
|
|
|
|
ops::BatchNormGradKernel<paddle::platform::CUDADeviceContext, float>);
|
|
|
|
|
batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>);
|
|
|
|
|