|
|
|
@ -134,8 +134,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
|
|
|
|
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
|
|
|
|
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
|
|
|
|
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
|
|
|
|
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
|
|
|
|
// ------------------- cudnn conv forward ---------------------
|
|
|
|
// ------------------- cudnn conv forward ---------------------
|
|
|
|
T alpha = static_cast<T>(1.0f);
|
|
|
|
typename platform::CudnnDataType<T>::ScalingParamType alpha = 1.0f,
|
|
|
|
T beta = static_cast<T>(0.0f);
|
|
|
|
beta = 0.0f;
|
|
|
|
for (int i = 0; i < groups; i++) {
|
|
|
|
for (int i = 0; i < groups; i++) {
|
|
|
|
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
|
|
|
|
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
|
|
|
|
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
|
|
|
|
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
|
|
|
|
@ -321,7 +321,7 @@ namespace plat = paddle::platform;
|
|
|
|
REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace,
|
|
|
|
REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace,
|
|
|
|
paddle::operators::CUDNNConvOpKernel<float>,
|
|
|
|
paddle::operators::CUDNNConvOpKernel<float>,
|
|
|
|
paddle::operators::CUDNNConvOpKernel<double>,
|
|
|
|
paddle::operators::CUDNNConvOpKernel<double>,
|
|
|
|
paddle::operators::CUDNNConvOpKernel < plat::float16);
|
|
|
|
paddle::operators::CUDNNConvOpKernel<plat::float16>);
|
|
|
|
REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace,
|
|
|
|
REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace,
|
|
|
|
paddle::operators::CUDNNConvGradOpKernel<float>,
|
|
|
|
paddle::operators::CUDNNConvGradOpKernel<float>,
|
|
|
|
paddle::operators::CUDNNConvGradOpKernel<double>);
|
|
|
|
paddle::operators::CUDNNConvGradOpKernel<double>);
|
|
|
|
|